added saving of plots

This commit is contained in:
Patrick 2024-07-07 11:55:50 +02:00
parent 78f08419ba
commit b80b24e07e
2 changed files with 19 additions and 1 deletions

View File

@ -25,6 +25,9 @@ class Plotter(tk.Frame):
def append(self, name: str, value): def append(self, name: str, value):
self.__queue.put((name, value)) self.__queue.put((name, value))
def save(self, filename):
self.figure.savefig(filename)
def __update(self): def __update(self):
to_update = False to_update = False
while not self.__queue.empty(): while not self.__queue.empty():

17
main.py
View File

@ -2,6 +2,7 @@ import math
import threading import threading
import tkinter as tk import tkinter as tk
import tkinter.ttk as ttk import tkinter.ttk as ttk
import datetime
from AsyncProgress import AsyncProgress from AsyncProgress import AsyncProgress
from Plotter import Plotter from Plotter import Plotter
@ -21,6 +22,16 @@ def cancel_train(plot):
return cancel_func return cancel_func
def on_exit(root, plotter_acc, plotter_loss, start_time):
def func():
plotter_acc.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-acc.jpg')
plotter_loss.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-loss.jpg')
root.destroy()
return func
if __name__ == '__main__': if __name__ == '__main__':
root = tk.Tk() root = tk.Tk()
@ -41,9 +52,13 @@ if __name__ == '__main__':
plotter_loss = Plotter(root) plotter_loss = Plotter(root)
plotter_loss.grid(row=3, column=1) plotter_loss.grid(row=3, column=1)
start_time = datetime.datetime.now()
root.after(0, lambda: threading.Thread(target=cnn_train.train_worker, root.after(0, lambda: threading.Thread(target=cnn_train.train_worker,
args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss), args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss, start_time),
daemon=True).start()) daemon=True).start())
root.protocol("WM_DELETE_WINDOW", on_exit(root, plotter_acc, plotter_loss, start_time))
root.focus_set() root.focus_set()
root.mainloop() root.mainloop()