"""Train a MyCNN image classifier on ImagesDataset with tk-based tqdm progress bars."""
import os
import tkinter as tk
import warnings
from datetime import datetime

import numpy.random
import torch.cuda
import torch.utils.data
from tqdm.tk import tqdm

from AImageDataset import AImagesDataset
from architecture import MyCNN
from dataset import ImagesDataset
# --- module-level configuration: model, hyperparameters, optimizer, loss ---

# CNN classifier for 1-channel 100x100 images with 20 output classes.
# NOTE(review): 5 kernel sizes / strides vs. 4 hidden channel counts —
# presumably the last entry belongs to the output layer; confirm against MyCNN.
model = MyCNN(input_channels=1,
              input_size=(100, 100),
              hidden_channels=[500, 250, 100, 50],
              output_channels=20,
              use_batchnorm=True,
              kernel_size=[9, 5, 3, 3, 1],
              stride=[1, 1, 1, 1, 1],
              activation_function=torch.nn.ReLU())

num_epochs = 100

batch_size = 64

# Averaged SGD; lambd/alpha/t0 control the averaging schedule.
optimizer = torch.optim.ASGD(model.parameters(),
                             lr=0.001,
                             lambd=1e-4,
                             alpha=0.75,
                             t0=1000000.0,
                             weight_decay=0)

# Multi-class classification loss; expects raw logits and integer class IDs.
loss_function = torch.nn.CrossEntropyLoss()
if __name__ == '__main__':
    # Reproducibility: seed both torch and numpy RNGs.
    torch.random.manual_seed(42)
    numpy.random.seed(42)

    # Timestamp used to name this run's checkpoint files.
    start_time = datetime.now()

    dataset = ImagesDataset("training_data")

    # dataset = torch.utils.data.Subset(dataset, range(0, 20))  # debug subset

    # 50/50 train/eval split.
    train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])

    # Training iterates the augmented wrapper dataset (yields 6-tuples);
    # evaluation iterates the raw split (yields 4-tuples) one sample at a time.
    train_loader = torch.utils.data.DataLoader(AImagesDataset(train_data), batch_size=batch_size)
    eval_loader = torch.utils.data.DataLoader(eval_data, batch_size=1)

    # Fix: fall back to CPU instead of warning and then crashing on .to('cuda').
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        warnings.warn("GPU not available")
        device = 'cpu'
    model = model.to(device)

    train_losses = []
    eval_losses = []

    # Fix: the tk-backed tqdm bars need a Tk root window; `root_window`
    # was previously referenced without ever being defined (NameError).
    root_window = tk.Tk()
    root_window.withdraw()  # the progress bars open their own windows

    progress_epoch = tqdm(range(num_epochs), position=0, tk_parent=root_window)
    progress_epoch.set_description("Epoch")
    progress_train_data = tqdm(train_loader, position=1, tk_parent=root_window)
    progress_eval_data = tqdm(eval_loader, position=2, tk_parent=root_window)
    progress_train_data.set_description("Training progress")
    progress_eval_data.set_description("Evaluation progress")

    # Fix: create the checkpoint directory once up front so torch.save
    # does not fail on a missing 'models/' folder.
    os.makedirs('models', exist_ok=True)

    for epoch in progress_epoch:
        train_loss = 0
        eval_loss = 0

        progress_train_data.reset()
        progress_eval_data.reset()

        # ---- training ----
        model.train()
        for batch_nr, (imageT, transforms, img_index, classIDs, labels, paths) in enumerate(progress_train_data):
            imageT = imageT.to(device)
            classIDs = classIDs.to(device)

            outputs = model(imageT)

            optimizer.zero_grad()
            loss = loss_function(outputs, classIDs)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        mean_loss = train_loss / len(train_loader)
        train_losses.append(mean_loss)

        # ---- evaluation ----
        model.eval()
        with torch.no_grad():
            for (imageT, classIDs, labels, paths) in progress_eval_data:
                imageT = imageT.to(device)
                classIDs = classIDs.to(device)

                outputs = model(imageT)
                loss = loss_function(outputs, classIDs)
                # Fix: accumulate the per-batch loss — previously this was
                # overwritten each iteration, so only the LAST eval batch's
                # loss was ever recorded.
                eval_loss += loss.item()

        # Record the mean eval loss, mirroring the train-side computation.
        eval_losses.append(eval_loss / len(eval_loader))

        # print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")

        # Checkpoint after every epoch, tagged with run start time and epoch.
        torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')