# AI-Project/cnn_train.py
#
# 120 lines
# 3.6 KiB
# Python

import tkinter as tk
import warnings
from datetime import datetime
import numpy.random
import torch.utils.data
import torch.cuda
from tqdm.tk import tqdm
from architecture import MyCNN
from dataset import ImagesDataset
from AImageDataset import AImagesDataset
# --- Training hyper-parameters ---------------------------------------------
num_epochs = 100
batch_size = 64

# Loss used by both the training and the evaluation passes.
loss_function = torch.nn.CrossEntropyLoss()

# --- Network ---------------------------------------------------------------
# NOTE(review): the model is constructed at import time, i.e. *before* the
# RNG seeds in the `__main__` section are set, so its initial weights are not
# governed by those seeds.
# NOTE(review): `kernel_size` and `stride` each have one more entry than
# `hidden_channels`; presumably the extra entry belongs to the output layer —
# confirm against MyCNN.
model = MyCNN(
    input_channels=1,
    input_size=(100, 100),
    hidden_channels=[500, 250, 100, 50],
    output_channels=20,
    use_batchnorm=True,
    kernel_size=[9, 5, 3, 3, 1],
    stride=[1, 1, 1, 1, 1],
    activation_function=torch.nn.ReLU(),
)

# --- Optimizer -------------------------------------------------------------
# Averaged SGD. Apart from `lr`, every argument below equals the
# torch.optim.ASGD default; they are spelled out to make the setup explicit.
optimizer = torch.optim.ASGD(
    model.parameters(),
    lr=0.001,
    lambd=1e-4,
    alpha=0.75,
    t0=1e6,
    weight_decay=0,
)
if __name__ == '__main__':
    import os

    # Seed the global RNGs so the train/eval split below is reproducible.
    # NOTE(review): `model` was already constructed at import time, before
    # these calls, so its initial weights are not covered by the seeds.
    torch.random.manual_seed(42)
    numpy.random.seed(42)
    start_time = datetime.now()

    dataset = ImagesDataset("training_data")
    # Uncomment to train on a tiny subset for quick debugging:
    # dataset = torch.utils.data.Subset(dataset, range(0, 20))

    # 50/50 split; only the training half is wrapped in AImagesDataset, whose
    # items carry extra fields (see the unpacking in the training loop).
    train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
    train_loader = torch.utils.data.DataLoader(AImagesDataset(train_data), batch_size=batch_size)
    eval_loader = torch.utils.data.DataLoader(eval_data, batch_size=1)

    # Use the GPU when present, otherwise fall back to the CPU. All tensors are
    # moved to `device` (never a hard-coded 'cuda') so the fallback works.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        warnings.warn("GPU not available")
        device = torch.device('cpu')
    model = model.to(device)

    train_losses = []  # mean training loss per epoch
    eval_losses = []   # mean evaluation loss per epoch

    # Parent window for the tqdm Tk progress bars.
    root_window = tk.Tk()

    progress_epoch = tqdm(range(num_epochs), position=0, tk_parent=root_window)
    progress_epoch.set_description("Epoch")
    # The per-epoch bars are driven manually via update() rather than by
    # wrapping the loaders: iterating a tqdm object closes it once the
    # iterable is exhausted, so a wrapped bar would stop updating after the
    # first epoch even though reset() is called.
    progress_train_data = tqdm(total=len(train_loader), position=1, tk_parent=root_window)
    progress_eval_data = tqdm(total=len(eval_loader), position=2, tk_parent=root_window)
    progress_train_data.set_description("Training progress")
    progress_eval_data.set_description("Evaluation progress")

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)

    for epoch in progress_epoch:
        train_loss = 0.0
        eval_loss = 0.0
        progress_train_data.reset()
        progress_eval_data.reset()

        # --- Training pass ---
        model.train()
        for imageT, _transforms, _img_index, classIDs, _labels, _paths in train_loader:
            imageT = imageT.to(device)
            classIDs = classIDs.to(device)
            outputs = model(imageT)
            optimizer.zero_grad()
            loss = loss_function(outputs, classIDs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            progress_train_data.update(1)
        # max(..., 1) guards against an empty split.
        mean_train_loss = train_loss / max(len(train_loader), 1)
        train_losses.append(mean_train_loss)

        # --- Evaluation pass ---
        model.eval()
        with torch.no_grad():
            for imageT, classIDs, _labels, _paths in eval_loader:
                imageT = imageT.to(device)
                classIDs = classIDs.to(device)
                outputs = model(imageT)
                loss = loss_function(outputs, classIDs)
                # Accumulate across batches so the epoch's figure is a mean,
                # matching how the training loss is computed.
                eval_loss += loss.item()
                progress_eval_data.update(1)
        mean_eval_loss = eval_loss / max(len(eval_loader), 1)
        eval_losses.append(mean_eval_loss)

        # Epoch summary shown on the epoch bar (stdout is not used with tqdm.tk).
        progress_epoch.set_postfix(train_loss=f"{mean_train_loss:7.4f}",
                                   eval_loss=f"{mean_eval_loss:7.4f}")
        torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')

    progress_train_data.close()
    progress_eval_data.close()
    root_window.destroy()