added plotter and start of training
This commit is contained in:
parent
b10f09c3e0
commit
79817340dd
26
main.py
26
main.py
|
|
@ -1,7 +1,24 @@
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
import tkinter.ttk as ttk
|
import tkinter.ttk as ttk
|
||||||
|
|
||||||
from AsyncProgress import AsyncProgress
|
from AsyncProgress import AsyncProgress
|
||||||
|
from Plotter import Plotter
|
||||||
|
|
||||||
|
import cnn_train
|
||||||
|
|
||||||
|
|
||||||
|
def cancel_train(plot):
|
||||||
|
def cancel_func():
|
||||||
|
try:
|
||||||
|
if len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50:
|
||||||
|
print("stopped")
|
||||||
|
return len(plot.data['train_loss']) > 7 and sum(plot.data['train_loss'][-7:]) < 50
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return cancel_func
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
@ -19,5 +36,14 @@ if __name__ == '__main__':
|
||||||
pbar_train.grid(row=1, column=0)
|
pbar_train.grid(row=1, column=0)
|
||||||
pbar_eval.grid(row=2, column=0)
|
pbar_eval.grid(row=2, column=0)
|
||||||
|
|
||||||
|
plotter_acc = Plotter(root, percent_lim=True)
|
||||||
|
plotter_acc.grid(row=3, column=0)
|
||||||
|
plotter_loss = Plotter(root)
|
||||||
|
plotter_loss.grid(row=3, column=1)
|
||||||
|
|
||||||
|
root.after(0, lambda: threading.Thread(target=cnn_train.train_worker,
|
||||||
|
args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss),
|
||||||
|
daemon=True).start())
|
||||||
|
|
||||||
root.focus_set()
|
root.focus_set()
|
||||||
root.mainloop()
|
root.mainloop()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue