diff --git a/main.py b/main.py index 4329931..83f97f1 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,24 @@ +import math +import threading import tkinter as tk import tkinter.ttk as ttk from AsyncProgress import AsyncProgress +from Plotter import Plotter + +import cnn_train + + +def cancel_train(plot): + def cancel_func(): + try: + if len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50: + print("stopped") + return len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50 + except KeyError: + return False + + return cancel_func if __name__ == '__main__': @@ -19,5 +36,14 @@ if __name__ == '__main__': pbar_train.grid(row=1, column=0) pbar_eval.grid(row=2, column=0) + plotter_acc = Plotter(root, percent_lim=True) + plotter_acc.grid(row=3, column=0) + plotter_loss = Plotter(root) + plotter_loss.grid(row=3, column=1) + + root.after(0, lambda: threading.Thread(target=cnn_train.train_worker, + args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss), + daemon=True).start()) + root.focus_set() root.mainloop()