From b80b24e07e4abb208b3a5d594727145b73e06c57 Mon Sep 17 00:00:00 2001 From: Patrick Date: Sun, 7 Jul 2024 11:55:50 +0200 Subject: [PATCH] added saving of plots --- Plotter.py | 3 +++ main.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Plotter.py b/Plotter.py index f5acb43..3c934fa 100644 --- a/Plotter.py +++ b/Plotter.py @@ -25,6 +25,9 @@ class Plotter(tk.Frame): def append(self, name: str, value): self.__queue.put((name, value)) + def save(self, filename): + self.figure.savefig(filename) + def __update(self): to_update = False while not self.__queue.empty(): diff --git a/main.py b/main.py index 83f97f1..fc460d1 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import math import threading import tkinter as tk import tkinter.ttk as ttk +import datetime from AsyncProgress import AsyncProgress from Plotter import Plotter @@ -21,6 +22,16 @@ def cancel_train(plot): return cancel_func +def on_exit(root, plotter_acc, plotter_loss, start_time): + def func(): + plotter_acc.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-acc.jpg') + plotter_loss.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-loss.jpg') + + root.destroy() + + return func + + if __name__ == '__main__': root = tk.Tk() @@ -41,9 +52,13 @@ if __name__ == '__main__': plotter_loss = Plotter(root) plotter_loss.grid(row=3, column=1) + start_time = datetime.datetime.now() + root.after(0, lambda: threading.Thread(target=cnn_train.train_worker, - args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss), + args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss, start_time), daemon=True).start()) + root.protocol("WM_DELETE_WINDOW", on_exit(root, plotter_acc, plotter_loss, start_time)) + root.focus_set() root.mainloop()