revamped training process

This commit is contained in:
Patrick 2024-07-06 18:20:21 +02:00
parent 79817340dd
commit e9d86f3309
1 changed file with 104 additions and 71 deletions

View File

@@ -1,119 +1,152 @@
import tkinter as tk
import warnings
from datetime import datetime
import numpy.random
import torch.utils.data
import torch.cuda
from tqdm.tk import tqdm
from architecture import MyCNN
from dataset import ImagesDataset
from AImageDataset import AImagesDataset
# NOTE(review): this file is a rendered diff (see the "@ -1,119 +1,152 @@"
# hunk header above) — old and new revision lines are interleaved and all
# indentation has been lost. These module-level definitions appear to be the
# pre-revamp configuration: train_worker further below rebuilds the model and
# optimizer locally with different settings. Confirm against the working tree.
# 1-channel 100x100 input, 4 conv stages, 20 output classes.
model = MyCNN(input_channels=1,
input_size=(100, 100),
hidden_channels=[500, 250, 100, 50],
output_channels=20,
use_batchnorm=True,
kernel_size=[9, 5, 3, 3, 1],
stride=[1, 1, 1, 1, 1],
activation_function=torch.nn.ReLU())
# Training hyperparameters (old revision; train_worker below uses Adam at
# lr=5e-5 and num_epochs=1000000 instead).
num_epochs = 100
batch_size = 64
# Averaged SGD with default-ish ASGD settings; weight decay disabled.
optimizer = torch.optim.ASGD(model.parameters(),
lr=0.001,
lambd=1e-4,
alpha=0.75,
t0=1000000.0,
weight_decay=0)
loss_function = torch.nn.CrossEntropyLoss()
# NOTE(review): from here on the diff interleaves the OLD script body (which
# lived directly under the `if __name__` guard) with the NEW `train_model`
# function that replaces it. Lines from both revisions appear back to back
# with no +/- markers and no indentation; do not treat this span as runnable.
if __name__ == '__main__':
# New revision: the training loop becomes a function taking the UI handles
# (tqdm progress bars, accuracy/loss plotters) plus model/optimizer config.
def train_model(accuracies,
losses,
progress_epoch,
progress_train_data,
progress_eval_data,
model,
num_epochs,
batch_size,
optimizer,
loss_function,
device,
start_time):
# Fixed seeds for reproducible splits/initialization.
torch.random.manual_seed(42)
numpy.random.seed(42)
# NOTE(review): reassigns the `start_time` parameter — likely an OLD-revision
# line the diff kept; the new caller already passes start_time in.
start_time = datetime.now()
# 'spawn' is required for CUDA tensors with multiprocessing DataLoader workers.
torch.multiprocessing.set_start_method('spawn', force=True)
dataset = ImagesDataset("training_data")
# dataset = torch.utils.data.Subset(dataset, range(0, 20))
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
# 50/50 train/eval split of the full dataset.
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
# NOTE(review): two DataLoader setups follow — the next two lines are the OLD
# single-line loaders; the multi-line loaders after them (augmentation wrapper,
# 3 workers, pinned memory, shuffling) are the NEW revision.
train_loader = torch.utils.data.DataLoader(AImagesDataset(train_data), batch_size=batch_size)
eval_loader = torch.utils.data.DataLoader(eval_data, batch_size=1)
augmented_train_data = AImagesDataset(train_data, False)
train_loader = torch.utils.data.DataLoader(augmented_train_data,
batch_size=batch_size,
num_workers=3,
pin_memory=True,
shuffle=True)
eval_loader = torch.utils.data.DataLoader(eval_data,
batch_size=batch_size,
num_workers=3,
pin_memory=True)
if torch.cuda.is_available():
# print("GPU available")
model = model.cuda()
else:
warnings.warn("GPU not available")
# NOTE(review): two epoch loops follow — `progress_epoch.range(...)` is one
# revision, the plain `for epoch in progress_epoch:` at the bottom of this
# run is the other. Only one of them belongs in the final file.
for epoch in progress_epoch.range(num_epochs):
train_losses = []
eval_losses = []
# Correct-prediction counters kept on-device to avoid per-batch syncs.
train_positives = torch.tensor(0, device=device)
eval_positives = torch.tensor(0, device=device)
# NOTE(review): `root_window` is not defined anywhere in this view —
# presumably a tk root created by the GUI caller; confirm.
progress_epoch = tqdm(range(num_epochs), position=0, tk_parent=root_window)
progress_epoch.set_description("Epoch")
progress_train_data = tqdm(train_loader, position=1, tk_parent=root_window)
progress_eval_data = tqdm(eval_loader, position=2, tk_parent=root_window)
progress_train_data.set_description("Training progress")
progress_eval_data.set_description("Evaluation progress")
for epoch in progress_epoch:
# OLD revision: plain float accumulators.
train_loss = 0
eval_loss = 0
progress_train_data.reset()
progress_eval_data.reset()
# NEW revision: tensor accumulators on `device` (summed without .item(),
# presumably to avoid a GPU sync per batch — confirm intent).
train_loss = torch.tensor(0.0, device=device)
eval_loss = torch.tensor(0.0, device=device)
# Start training of model
progress_train_data.reset()
model.train()
# NOTE(review): OLD loop head (camelCase names, hard-coded 'cuda') followed
# by the NEW loop head (snake_case, `.to(device)`, reusable progress bar).
for batch_nr, (imageT, transforms, img_index, classIDs, labels, paths) in enumerate(progress_train_data):
imageT = imageT.to('cuda')
classIDs = classIDs.to('cuda')
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
in enumerate(progress_train_data.iter(train_loader)):
image_t = image_t.to(device)
class_ids = class_ids.to(device)
# progress_train_data.set_postfix_str("Running model...")
outputs = model(imageT)
outputs = model(image_t)
optimizer.zero_grad()
# progress_train_data.set_postfix_str("calculating loss...")
loss = loss_function(outputs, classIDs)
# progress_train_data.set_postfix_str("propagating loss...")
optimizer.zero_grad(set_to_none=True)
loss = loss_function(outputs, class_ids)
loss.backward()
# progress_train_data.set_postfix_str("optimizing...")
optimizer.step()
# OLD: accumulate python float; NEW: accumulate tensor.
train_loss += loss.item()
train_loss += loss
mean_loss = train_loss / len(train_loader)
train_losses.append(mean_loss)
# NOTE(review): `outputs.flatten()` returns a new tensor that is discarded,
# and `argmax()` without a dim argument reduces over ALL elements of the
# batch — comparing that scalar to `class_ids` broadcasts. This looks
# wrong; expected `outputs.argmax(dim=1)`. Flagging only (diff residue).
outputs.flatten()
classes = outputs.argmax()
train_positives += torch.sum(torch.eq(classes, class_ids))
# Report per-sample accuracy/loss to the UI plotters and stdout.
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
losses.append('train_loss', train_loss.item() / len(augmented_train_data))
print("Train: ", train_positives.item(), "/ ", len(augmented_train_data),
" = ", train_positives.item() / len(augmented_train_data))
# evaluation of model
progress_eval_data.reset()
model.eval()
with torch.no_grad():
# OLD eval loop head (4-tuple, hard-coded 'cuda') then NEW head (reused bar).
for (imageT, classIDs, labels, paths) in progress_eval_data:
imageT = imageT.to('cuda')
classIDs = classIDs.to('cuda')
for (image_t, class_ids, labels, paths) in progress_eval_data.iter(eval_loader):
image_t = image_t.to(device)
class_ids = class_ids.to(device)
outputs = model(imageT)
loss = loss_function(outputs, classIDs)
# OLD: overwrite with last batch's loss; NEW (below): accumulate.
eval_loss = loss.item()
outputs = model(image_t)
# NOTE(review): same discarded-flatten / dim-less argmax concern as the
# training loop above.
outputs.flatten()
classes = outputs.argmax()
eval_losses.append(eval_loss)
eval_positives += torch.sum(torch.eq(classes, class_ids))
eval_loss += loss_function(outputs, class_ids)
accuracies.append('eval_acc', eval_positives.item() / len(eval_data))
losses.append('eval_loss', eval_loss.item() / len(eval_data))
print("Eval: ", eval_positives.item(), "/ ", len(eval_data), " = ", eval_positives.item() / len(eval_data))
# print epoch summary
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
# Checkpoint whenever eval accuracy exceeds 30%, and always append the
# epoch's metrics to the per-run CSV (semicolon-separated).
if eval_positives.item() / len(eval_data) > 0.3:
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
f'{train_positives};{eval_positives}\n')
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
    """Build model, optimizer and loss, log the run config, then train.

    The p_* arguments are the UI progress handles and plotters created by
    the caller; they are forwarded untouched to train_model.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    # Hard requirement: the fused Adam path below needs CUDA.
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    device = 'cuda'

    net = MyCNN(input_channels=1,
                input_size=(100, 100)).to(device)
    epochs = 1000000
    batch = 64
    opt = torch.optim.Adam(net.parameters(),
                           lr=0.00005,
                           # weight_decay=0.1,
                           fused=True)
    criterion = torch.nn.CrossEntropyLoss()

    # Timestamp names both the metrics CSV and its sibling config .txt,
    # so every run's artifacts can be matched up later.
    started = datetime.now()
    file_name = f'models/model-{started.strftime("%Y%m%d-%H%M%S")}.csv'
    with open(file_name.replace(".csv", ".txt"), 'a') as file:
        for entry in (f"device: {device}\n",
                      f"batch_size: {batch}\n",
                      f"optimizer: {opt}\n",
                      f"loss_function: {criterion}\n",
                      f"model: {net}"):
            file.write(entry)

    train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
                net,
                epochs,
                batch,
                opt,
                criterion,
                device,
                started)