AI-Project/cnn_train.py
from datetime import datetime
import os
import numpy.random
import torch.utils.data
import torch.cuda
import torch.multiprocessing
from architecture import model
from dataset import ImagesDataset
from AImageDataset import AImagesDataset
from AsyncDataLoader import AsyncDataLoader
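
# Training entry points for the CNN classifier: split_data builds a stratified
# train/eval split, train_model runs the epoch loop, and train_worker wires up
# the model, optimizer and logging. ImagesDataset, AImagesDataset and
# AsyncDataLoader are project-local modules.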


def split_data(data: ImagesDataset):
    """Split a dataset into two class-balanced halves (train, eval)."""
    # Number of samples per class id.
    class_nums = torch.bincount(
        torch.tensor([data.classnames_to_ids[name]
                      for _, name in data.filenames_classnames]).long()).tolist()
    indices = ([], [])
    for class_id in range(len(class_nums)):
        class_num = class_nums[class_id]
        index_perm = torch.randperm(class_num).tolist()
        class_indices = [i for i, e in enumerate(data.filenames_classnames)
                         if data.classnames_to_ids[e[1]] == class_id]
        # Alternate the shuffled samples: even positions go to train, odd to eval.
        indices[0].extend([class_indices[i] for i in index_perm[0::2]])
        indices[1].extend([class_indices[i] for i in index_perm[1::2]])
    return (torch.utils.data.Subset(data, indices[0]),
            torch.utils.data.Subset(data, indices[1]))


def train_model(accuracies, losses, progress_epoch, progress_train_data, progress_eval_data,
                model, num_epochs, batch_size, optimizer, loss_function, augment_data,
                device, start_time):
    # Fix the seeds so the split and the shuffling are reproducible.
    torch.random.manual_seed(42)
    numpy.random.seed(42)
    # CUDA tensors require spawned (not forked) worker processes.
    torch.multiprocessing.set_start_method('spawn', force=True)
    dataset = ImagesDataset("training_data")
    train_data, eval_data = split_data(dataset)
    augmented_train_data = AImagesDataset(train_data, augment_data)
    train_loader = AsyncDataLoader(augmented_train_data,
                                   batch_size=batch_size,
                                   num_workers=3,
                                   pin_memory=True,
                                   shuffle=True)
    eval_loader = AsyncDataLoader(eval_data,
                                  batch_size=batch_size,
                                  num_workers=3,
                                  pin_memory=True)
    for epoch in progress_epoch.range(num_epochs):
        train_positives = torch.tensor(0, device=device)
        eval_positives = torch.tensor(0, device=device)
        train_loss = torch.tensor(0.0, device=device)
        eval_loss = torch.tensor(0.0, device=device)

        # Training pass
        progress_train_data.reset()
        model.train()
        for image_t, transforms, img_index, class_ids, labels, paths \
                in progress_train_data.iter(train_loader):
            image_t = image_t.to(device)
            class_ids = class_ids.to(device)
            outputs = model(image_t)
            optimizer.zero_grad(set_to_none=True)
            loss = loss_function(outputs, class_ids)
            loss.backward()
            optimizer.step()
            # Detach so the graph of each batch can be freed, and weight by the
            # batch size so the sum divides out to a per-sample mean below.
            train_loss += loss.detach() * image_t.size(0)
            classes = outputs.argmax(dim=1)
            train_positives += torch.sum(torch.eq(classes, class_ids))
        accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
        losses.append('train_loss', train_loss.item() / len(augmented_train_data))
        print(f"Train: {train_positives.item()} / {len(augmented_train_data)}"
              f" = {train_positives.item() / len(augmented_train_data)}")
        # Evaluation pass
        progress_eval_data.reset()
        model.eval()
        with torch.no_grad():
            for image_t, class_ids, labels, paths in progress_eval_data.iter(eval_loader):
                image_t = image_t.to(device)
                class_ids = class_ids.to(device)
                outputs = model(image_t)
                classes = outputs.argmax(dim=1)
                eval_positives += torch.sum(torch.eq(classes, class_ids))
                eval_loss += loss_function(outputs, class_ids) * image_t.size(0)
        accuracies.append('eval_acc', eval_positives.item() / len(eval_data))
        losses.append('eval_loss', eval_loss.item() / len(eval_data))
        print(f"Eval: {eval_positives.item()} / {len(eval_data)}"
              f" = {eval_positives.item() / len(eval_data)}")
        # Checkpoint only once the model beats 50% evaluation accuracy.
        if eval_positives.item() / len(eval_data) > 0.5:
            torch.save(model.state_dict(),
                       f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pth')
        # One CSV row per epoch:
        # epoch;train_size;eval_size;train_loss;eval_loss;train_positives;eval_positives
        with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
            file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};'
                       f'{train_loss.item()};{eval_loss.item()};'
                       f'{train_positives.item()};{eval_positives.item()}\n')


def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    device = 'cuda'
    model.to(device)
    num_epochs = 1000000  # effectively "train until stopped manually"
    batch_size = 64
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.0001,
                                 fused=True)
    loss_function = torch.nn.CrossEntropyLoss()
    augment_data = True
    os.makedirs('models', exist_ok=True)  # checkpoints and logs go here
    # Record the run configuration next to the CSV metrics file.
    file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
    with open(file_name.replace(".csv", ".txt"), 'a') as file:
        file.write(f"device: {device}\n")
        file.write(f"batch_size: {batch_size}\n")
        file.write(f"optimizer: {optimizer}\n")
        file.write(f"loss_function: {loss_function}\n")
        file.write(f"augment_data: {augment_data}\n")
        file.write(f"model: {model}\n")
    train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
                model, num_epochs, batch_size, optimizer, loss_function,
                augment_data, device, start_time)
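

# Hypothetical usage sketch (an assumption, not part of the original project):
# train_worker expects progress trackers exposing range()/reset()/iter() and
# plotters exposing append(); the minimal stubs below merely stand in for the
# project's real classes so the script can be smoke-tested on its own.
if __name__ == "__main__":
    class _Progress:
        def range(self, n):
            return range(n)

        def reset(self):
            pass

        def iter(self, loader):
            return iter(loader)

    class _Plotter:
        def append(self, key, value):
            print(f"{key}: {value:.4f}")

    train_worker(_Progress(), _Progress(), _Progress(),
                 _Plotter(), _Plotter(), datetime.now())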