# AI-Project/cnn_train.py
# (175 lines, 6.1 KiB, Python)
from datetime import datetime
import numpy.random
import torch.utils.data
import torch.cuda
from architecture import MyCNN
from dataset import ImagesDataset
from AImageDataset import AImagesDataset
from AsyncDataLoader import AsyncDataLoader
def split_data(data: "ImagesDataset"):
    """Split *data* into two stratified halves.

    Every class contributes alternating elements of a random permutation of
    its samples, so each half keeps (roughly) the class distribution of the
    full dataset.  The first subset receives ceil(n/2) samples of each class,
    the second floor(n/2).

    :param data: dataset exposing ``filenames_classnames`` (iterable of
        ``(filename, classname)``) and ``classnames_to_ids`` (dict).
    :return: tuple of two ``torch.utils.data.Subset`` objects.
    """
    # Group sample indices by class id in ONE pass.  The previous version
    # re-scanned all of filenames_classnames once per class (O(n * classes)).
    class_to_indices = {}
    for idx, (_, classname) in enumerate(data.filenames_classnames):
        class_to_indices.setdefault(data.classnames_to_ids[classname], []).append(idx)
    first_half, second_half = [], []
    for class_id in sorted(class_to_indices):
        class_indices = class_to_indices[class_id]
        # Random permutation, then deal out even/odd positions to the halves.
        perm = torch.randperm(len(class_indices)).tolist()
        first_half.extend(class_indices[i] for i in perm[0::2])
        second_half.extend(class_indices[i] for i in perm[1::2])
    return (torch.utils.data.Subset(data, first_half),
            torch.utils.data.Subset(data, second_half))
def train_model(accuracies,
                losses,
                progress_epoch,
                progress_train_data,
                progress_eval_data,
                model,
                num_epochs,
                batch_size,
                optimizer,
                loss_function,
                augment_data,
                device,
                start_time):
    """Run the train/eval loop for *model* and log per-epoch metrics.

    :param accuracies: collector with ``append(name, value)`` for accuracy curves.
    :param losses: collector with ``append(name, value)`` for loss curves.
    :param progress_epoch: progress reporter with ``range(n)``.
    :param progress_train_data: progress reporter with ``reset()`` / ``iter(it)``.
    :param progress_eval_data: progress reporter with ``reset()`` / ``iter(it)``.
    :param model: the network to train (already on *device*).
    :param num_epochs: number of epochs to run.
    :param batch_size: mini-batch size for both loaders.
    :param optimizer: optimizer for ``model.parameters()``.
    :param loss_function: criterion returning a per-batch mean loss.
    :param augment_data: forwarded to ``AImagesDataset`` for train-time augmentation.
    :param device: torch device string, e.g. ``'cuda'``.
    :param start_time: datetime used to name checkpoint / log files.

    Side effects: saves a checkpoint to ``models/`` whenever eval accuracy
    exceeds 0.5, and appends one semicolon-separated metrics row per epoch to
    ``models/model-<stamp>.csv``.
    """
    # Fixed seeds so the data split and augmentation order are reproducible.
    torch.random.manual_seed(42)
    numpy.random.seed(42)
    # NOTE(review): presumably the loader workers touch CUDA, which does not
    # survive fork — hence the 'spawn' start method. Confirm against AsyncDataLoader.
    torch.multiprocessing.set_start_method('spawn', force=True)
    dataset = ImagesDataset("training_data")
    train_data, eval_data = split_data(dataset)
    augmented_train_data = AImagesDataset(train_data, augment_data)
    train_loader = AsyncDataLoader(augmented_train_data,
                                   batch_size=batch_size,
                                   num_workers=3,
                                   pin_memory=True,
                                   shuffle=True)
    eval_loader = AsyncDataLoader(eval_data,
                                  batch_size=batch_size,
                                  num_workers=3,
                                  pin_memory=True)
    # One timestamp for every file of this run (hoisted out of the loop).
    run_stamp = start_time.strftime("%Y%m%d-%H%M%S")
    for epoch in progress_epoch.range(num_epochs):
        train_positives = torch.tensor(0, device=device)
        eval_positives = torch.tensor(0, device=device)
        train_loss = torch.tensor(0.0, device=device)
        eval_loss = torch.tensor(0.0, device=device)
        train_batches = 0
        eval_batches = 0
        # --- training pass ---
        progress_train_data.reset()
        model.train()
        for (image_t, transforms, img_index, class_ids, labels, paths) \
                in progress_train_data.iter(train_loader):
            image_t = image_t.to(device)
            class_ids = class_ids.to(device)
            outputs = model(image_t)
            optimizer.zero_grad(set_to_none=True)
            loss = loss_function(outputs, class_ids)
            loss.backward()
            optimizer.step()
            # detach(): accumulating the raw loss tensor would keep every
            # batch's autograd graph alive for the whole epoch (memory leak).
            train_loss += loss.detach()
            train_batches += 1
            classes = outputs.argmax(dim=1)
            train_positives += torch.sum(torch.eq(classes, class_ids))
        accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
        # loss_function returns a per-batch mean, so the epoch mean is
        # (sum of batch means) / (number of batches).  The previous code
        # divided by the dataset size, under-reporting by ~batch_size.
        losses.append('train_loss', train_loss.item() / max(train_batches, 1))
        print("Train: ", train_positives.item(), "/ ", len(augmented_train_data),
              " = ", train_positives.item() / len(augmented_train_data))
        # --- evaluation pass ---
        progress_eval_data.reset()
        model.eval()
        with torch.no_grad():
            for (image_t, class_ids, labels, paths) in progress_eval_data.iter(eval_loader):
                image_t = image_t.to(device)
                class_ids = class_ids.to(device)
                outputs = model(image_t)
                classes = outputs.argmax(dim=1)
                eval_positives += torch.sum(torch.eq(classes, class_ids))
                eval_loss += loss_function(outputs, class_ids)
                eval_batches += 1
        accuracies.append('eval_acc', eval_positives.item() / len(eval_data))
        losses.append('eval_loss', eval_loss.item() / max(eval_batches, 1))
        print("Eval: ", eval_positives.item(), "/ ", len(eval_data), " = ", eval_positives.item() / len(eval_data))
        # Checkpoint only when the model is clearly better than chance.
        if eval_positives.item() / len(eval_data) > 0.5:
            torch.save(model.state_dict(), f'models/model-{run_stamp}-epoch-{epoch}.pt')
        with open(f'models/model-{run_stamp}.csv', 'a') as file:
            # .item(): writing the tensors directly would put
            # "tensor(123, device='cuda:0')" into the CSV row.
            file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
                       f'{train_positives.item()};{eval_positives.item()}\n')
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
    """Build model, optimizer and hyper-parameters, log the run config, train.

    :param p_epoch: progress reporter for the epoch loop.
    :param p_train: progress reporter for training batches.
    :param p_eval: progress reporter for evaluation batches.
    :param plotter_accuracies: collector receiving accuracy curves.
    :param plotter_loss: collector receiving loss curves.
    :param start_time: datetime used to name the run's output files.
    :raises RuntimeError: when no CUDA device is available.
    """
    # Training is GPU-only here (the optimizer below uses fused=True).
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    device = 'cuda'
    model = MyCNN(input_channels=1, input_size=(100, 100)).to(device)
    num_epochs = 1000000
    batch_size = 64
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, fused=True)
    loss_function = torch.nn.CrossEntropyLoss()
    augment_data = True
    file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
    # Record the run configuration in a .txt next to the metrics .csv.
    with open(file_name.replace(".csv", ".txt"), 'a') as file:
        file.writelines([
            f"device: {device}\n",
            f"batch_size: {batch_size}\n",
            f"optimizer: {optimizer}\n",
            f"loss_function: {loss_function}\n",
            f"augment_data: {augment_data}\n",
            f"model: {model}",
        ])
    train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
                model, num_epochs, batch_size, optimizer, loss_function,
                augment_data, device, start_time)