added plotter and start of training

This commit is contained in:
Patrick 2024-07-06 18:19:42 +02:00
parent b10f09c3e0
commit 79817340dd
1 changed files with 26 additions and 0 deletions

26
main.py
View File

@ -1,7 +1,24 @@
import math
import threading
import tkinter as tk import tkinter as tk
import tkinter.ttk as ttk import tkinter.ttk as ttk
from AsyncProgress import AsyncProgress from AsyncProgress import AsyncProgress
from Plotter import Plotter
import cnn_train
def cancel_train(plot):
def cancel_func():
try:
if len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50:
print("stopped")
return len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50
except KeyError:
return False
return cancel_func
if __name__ == '__main__': if __name__ == '__main__':
@ -19,5 +36,14 @@ if __name__ == '__main__':
pbar_train.grid(row=1, column=0) pbar_train.grid(row=1, column=0)
pbar_eval.grid(row=2, column=0) pbar_eval.grid(row=2, column=0)
plotter_acc = Plotter(root, percent_lim=True)
plotter_acc.grid(row=3, column=0)
plotter_loss = Plotter(root)
plotter_loss.grid(row=3, column=1)
root.after(0, lambda: threading.Thread(target=cnn_train.train_worker,
args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss),
daemon=True).start())
root.focus_set() root.focus_set()
root.mainloop() root.mainloop()