From 79817340dd49a81643b323f1ec79e68aa43a9eda Mon Sep 17 00:00:00 2001 From: Patrick Date: Sat, 6 Jul 2024 18:19:42 +0200 Subject: [PATCH] added plotter and start of training --- main.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/main.py b/main.py index 4329931..83f97f1 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,24 @@ +import math +import threading import tkinter as tk import tkinter.ttk as ttk from AsyncProgress import AsyncProgress +from Plotter import Plotter + +import cnn_train + + +def cancel_train(plot): + def cancel_func(): + try: + if len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50: + print("stopped") + return len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50 + except KeyError: + return False + + return cancel_func if __name__ == '__main__': @@ -19,5 +36,14 @@ if __name__ == '__main__': pbar_train.grid(row=1, column=0) pbar_eval.grid(row=2, column=0) + plotter_acc = Plotter(root, percent_lim=True) + plotter_acc.grid(row=3, column=0) + plotter_loss = Plotter(root) + plotter_loss.grid(row=3, column=1) + + root.after(0, lambda: threading.Thread(target=cnn_train.train_worker, + args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss), + daemon=True).start()) + root.focus_set() root.mainloop()