"""Tk front-end that runs CNN training on a background thread and shows progress.

Three progress bars (epoch / train / evaluation) and two plots (accuracy and
loss) are built on the Tk main thread and handed to ``cnn_train.train_worker``,
which runs on a daemon thread. Closing the window saves both plots as JPEGs
under ``models/`` and then destroys the root window.
"""
import datetime
import math  # NOTE(review): not used in this file — kept in case other code relies on it.
import os
import threading
import tkinter as tk
import tkinter.ttk as ttk

from AsyncProgress import AsyncProgress
from Plotter import Plotter
import cnn_train


def cancel_train(plot, window=7, threshold=50):
    """Build a stop predicate that fires once the recent training loss is low.

    The returned callable is not invoked in this file; presumably the training
    code polls it to decide whether to stop early (TODO confirm).

    Args:
        plot: Object exposing a ``data`` mapping. ``plot.data['train_loss']``
            is read as a sequence of numbers (assumed to be the loss history
            appended by the training code — TODO confirm).
        window: How many of the most recent loss values to sum. Must be >= 1.
            Defaults to 7.
        threshold: The predicate fires when the sum of the last ``window``
            losses is strictly below this value. Defaults to 50.

    Returns:
        A zero-argument callable. It returns ``True`` (and prints
        ``"stopped"``) when more than ``window`` losses have been recorded
        and the sum of the last ``window`` is below ``threshold``. Otherwise
        it returns ``False``, including while no ``'train_loss'`` entry exists.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        # A zero/negative window would make ``losses[-window:]`` select the
        # whole history (or the wrong tail) instead of the recent values.
        raise ValueError(f"window must be >= 1, got {window}")

    def cancel_func():
        try:
            losses = plot.data['train_loss']
        except KeyError:
            # No loss has been recorded yet, so there is nothing to judge.
            return False
        should_stop = len(losses) > window and sum(losses[-window:]) < threshold
        if should_stop:
            print("stopped")
        return should_stop

    return cancel_func


def on_exit(root, plotter_acc, plotter_loss, start_time, out_dir='models'):
    """Build the window-close handler that saves both plots and closes the app.

    Args:
        root: The Tk root window to destroy.
        plotter_acc: Accuracy plot; its ``save(path)`` is called.
        plotter_loss: Loss plot; its ``save(path)`` is called.
        start_time: ``datetime`` used to timestamp the saved file names.
        out_dir: Directory the images are written to. It is created if
            missing. Defaults to ``'models'``.

    Returns:
        A zero-argument callable suitable for
        ``root.protocol("WM_DELETE_WINDOW", ...)``. It saves
        ``<out_dir>/model-<YYYYmmdd-HHMMSS>-acc.jpg`` and ``...-loss.jpg``,
        then destroys ``root``. The window is destroyed even if saving
        raises, and the save error still propagates afterwards.
    """
    def func():
        stamp = start_time.strftime("%Y%m%d-%H%M%S")
        try:
            # Saving into a missing directory would fail (assuming Plotter.save
            # does not create parent directories itself — TODO confirm).
            os.makedirs(out_dir, exist_ok=True)
            plotter_acc.save(f'{out_dir}/model-{stamp}-acc.jpg')
            plotter_loss.save(f'{out_dir}/model-{stamp}-loss.jpg')
        finally:
            # Always close the window. Previously, an exception in save()
            # skipped destroy() and left a window that could not be closed.
            root.destroy()

    return func


if __name__ == '__main__':
    root = tk.Tk()
    root.title("AI Project")

    # Top-left cell: a padded frame holding three stacked progress bars.
    progress_frame = ttk.Frame(root, padding=20, width=500)
    progress_frame.grid(row=0, column=0)

    pbar_epoch = AsyncProgress(progress_frame, label=" Epoch: ")
    pbar_train = AsyncProgress(progress_frame, label=" Train: ")
    pbar_eval = AsyncProgress(progress_frame, label="Evaluation: ")
    pbar_epoch.grid(row=0, column=0)
    pbar_train.grid(row=1, column=0)
    pbar_eval.grid(row=2, column=0)

    # Accuracy and loss plots side by side below the progress bars.
    plotter_acc = Plotter(root, percent_lim=True)
    plotter_acc.grid(row=3, column=0)
    plotter_loss = Plotter(root)
    plotter_loss.grid(row=3, column=1)

    start_time = datetime.datetime.now()

    def _start_training():
        # daemon=True: the worker thread does not keep the process alive
        # after the window is closed and mainloop() returns.
        threading.Thread(
            target=cnn_train.train_worker,
            args=(pbar_epoch, pbar_train, pbar_eval,
                  plotter_acc, plotter_loss, start_time),
            daemon=True,
        ).start()

    # Scheduled through the event loop, so the worker starts only once
    # mainloop() is processing events.
    root.after(0, _start_training)
    root.protocol("WM_DELETE_WINDOW",
                  on_exit(root, plotter_acc, plotter_loss, start_time))
    root.focus_set()
    root.mainloop()