65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
import math
|
|
import threading
|
|
import tkinter as tk
|
|
import tkinter.ttk as ttk
|
|
import datetime
|
|
|
|
from AsyncProgress import AsyncProgress
|
|
from Plotter import Plotter
|
|
|
|
import cnn_train
|
|
|
|
|
|
def cancel_train(plot, window=7, threshold=50):
    """Build a cancellation predicate driven by the training-loss history of *plot*.

    The returned zero-argument callable reads ``plot.data['train_loss']`` every
    time it is invoked. It requests a stop when there are strictly more than
    ``window`` recorded values and the sum of the last ``window`` of them is
    strictly below ``threshold``. When it requests a stop it prints
    ``"stopped"``.

    Args:
        plot: Object exposing a ``data`` mapping. Presumably a ``Plotter`` whose
            ``'train_loss'`` entry is a list of loss values (TODO confirm).
        window: Number of trailing loss values to sum. Must be >= 1.
            Defaults to 7.
        threshold: The trailing sum must be strictly below this value to stop.
            Defaults to 50.

    Returns:
        A callable returning ``True`` to request cancellation and ``False``
        otherwise. It returns ``False`` while ``'train_loss'`` is absent.

    Raises:
        ValueError: If ``window`` is less than 1. A zero window would make
            ``losses[-0:]`` select the whole history.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window!r}")

    def cancel_func():
        try:
            losses = plot.data['train_loss']
        except KeyError:
            # Nothing has been recorded yet, so keep training.
            return False
        # Evaluate the decision exactly once, so the printed message and the
        # returned value cannot disagree if the history grows between checks.
        # The strict `>` matches the original `> 7` length requirement.
        should_stop = len(losses) > window and sum(losses[-window:]) < threshold
        if should_stop:
            print("stopped")
        return should_stop

    return cancel_func
|
|
|
|
|
|
def on_exit(root, plotter_acc, plotter_loss, start_time):
    """Build the window-close handler that saves both plots and closes *root*.

    The returned zero-argument callable saves the accuracy plot to
    ``models/model-<YYYYmmdd-HHMMSS>-acc.jpg`` and the loss plot to
    ``models/model-<YYYYmmdd-HHMMSS>-loss.jpg``. The timestamp is taken from
    *start_time*. The handler then destroys *root*.

    Args:
        root: The Tk root window to destroy.
        plotter_acc: Object with a ``save(path)`` method (the accuracy plot).
        plotter_loss: Object with a ``save(path)`` method (the loss plot).
        start_time: ``datetime`` used to build the shared file-name stamp.

    Returns:
        A callable suitable for ``root.protocol("WM_DELETE_WINDOW", ...)``.
        If a save raises, the window is still destroyed and the exception
        propagates to the caller (Tk reports it).
    """

    def func():
        # Both images share one prefix so the pair can be matched to the run.
        prefix = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}'
        try:
            plotter_acc.save(f'{prefix}-acc.jpg')
            plotter_loss.save(f'{prefix}-loss.jpg')
        finally:
            # Always close the window, even if saving failed (e.g. the target
            # directory does not exist). Otherwise the close button would
            # appear to do nothing.
            root.destroy()

    return func
|
|
|
|
|
|
if __name__ == '__main__':
    # --- Main window --------------------------------------------------------
    root = tk.Tk()
    root.title("AI Project")

    # --- Progress bars on row 0: epoch, train, evaluation (stacked) ----------
    progress_frame = ttk.Frame(root, padding=20, width=500)
    progress_frame.grid(row=0, column=0)

    # All three bars are constructed first, then gridded, in the same order.
    pbar_epoch, pbar_train, pbar_eval = (
        AsyncProgress(progress_frame, label=caption)
        for caption in (" Epoch: ", " Train: ", "Evaluation: ")
    )
    for bar_row, bar in enumerate((pbar_epoch, pbar_train, pbar_eval)):
        bar.grid(row=bar_row, column=0)

    # --- Plots on row 3: accuracy (column 0) and loss (column 1) -------------
    plotter_acc = Plotter(root, percent_lim=True)
    plotter_acc.grid(row=3, column=0)
    plotter_loss = Plotter(root)
    plotter_loss.grid(row=3, column=1)

    # Run timestamp. on_exit uses it to name the saved plot images, and it is
    # also handed to the training worker.
    start_time = datetime.datetime.now()

    def _launch_training():
        # Scheduled via after(0, ...), so the worker starts from inside the Tk
        # event loop. daemon=True means the worker does not keep the process
        # alive once the GUI has shut down.
        trainer = threading.Thread(
            target=cnn_train.train_worker,
            args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss, start_time),
            daemon=True,
        )
        trainer.start()

    root.after(0, _launch_training)

    # Closing the window is routed through on_exit (save plots, then destroy).
    root.protocol("WM_DELETE_WINDOW", on_exit(root, plotter_acc, plotter_loss, start_time))

    root.focus_set()
    root.mainloop()
|