AI-Project/main.py

50 lines
1.4 KiB
Python

import math
import threading
import tkinter as tk
import tkinter.ttk as ttk
from AsyncProgress import AsyncProgress
from Plotter import Plotter
import cnn_train
def cancel_train(plot, window=7, threshold=50):
    """Build a stop-condition callback that watches a plot's training-loss series.

    The returned zero-argument callable reads ``plot.data['train_loss']`` each
    time it is invoked, so it always judges the most recently recorded values.

    Args:
        plot: Object exposing a ``data`` mapping. Its ``'train_loss'`` entry is
            expected to be a sliceable sequence of numbers (presumably one loss
            value per recorded step/epoch -- TODO confirm against ``Plotter``).
        window: How many of the most recent loss values are summed. Must be
            at least 1. Defaults to 7.
        threshold: The callback signals "stop" when the sum of the last
            ``window`` losses is strictly below this value. Defaults to 50.

    Returns:
        A callable returning ``True`` (after printing ``"stopped"``) when more
        than ``window`` losses are recorded and the sum of the last ``window``
        of them is below ``threshold``. Otherwise it returns ``False``,
        including while ``plot.data`` has no ``'train_loss'`` key yet.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        # losses[-0:] would be the whole list, silently changing the meaning.
        raise ValueError(f"window must be >= 1, got {window!r}")

    def cancel_func():
        try:
            losses = plot.data['train_loss']
        except KeyError:
            # Nothing recorded yet: no basis for stopping.
            return False
        # Evaluate once so the printed message and the return value always
        # agree, even if the series is appended to concurrently.
        # NOTE(review): ``>`` (not ``>=``) requires window + 1 values before the
        # check can fire; kept to preserve the original behavior.
        should_stop = len(losses) > window and sum(losses[-window:]) < threshold
        if should_stop:
            print("stopped")
        return should_stop

    return cancel_func
if __name__ == '__main__':
    # Top-level window for the training dashboard.
    root = tk.Tk()
    root.title("AI Project")

    # Panel in the top-left cell that holds the stacked progress bars.
    progress_frame = ttk.Frame(root, padding=20, width=500)
    progress_frame.grid(row=0, column=0)

    # Create all three bars first, then place them top-to-bottom in that order.
    bar_labels = (" Epoch: ", " Train: ", "Evaluation: ")
    progress_bars = [AsyncProgress(progress_frame, label=text) for text in bar_labels]
    for row_index, bar in enumerate(progress_bars):
        bar.grid(row=row_index, column=0)
    pbar_epoch, pbar_train, pbar_eval = progress_bars

    # Two charts side by side below the progress panel; the first is built
    # with percent_lim=True, the second with Plotter's defaults.
    plotter_acc = Plotter(root, percent_lim=True)
    plotter_acc.grid(row=3, column=0)
    plotter_loss = Plotter(root)
    plotter_loss.grid(row=3, column=1)

    def _launch_training():
        """Run cnn_train.train_worker on a daemon thread, handing it every widget."""
        # NOTE(review): Tk widgets are passed to a worker thread; Tk is not
        # generally thread-safe -- confirm AsyncProgress/Plotter marshal their
        # updates back onto the Tk thread.
        widgets = (pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss)
        worker = threading.Thread(target=cnn_train.train_worker, args=widgets, daemon=True)
        worker.start()

    # Schedule the launch on the event loop so it happens once mainloop runs.
    root.after(0, _launch_training)
    root.focus_set()
    root.mainloop()