import os
from datetime import datetime

import numpy.random
import torch
import torch.cuda
import torch.multiprocessing
import torch.utils.data

from architecture import model
from dataset import ImagesDataset
from AImageDataset import AImagesDataset
from AsyncDataLoader import AsyncDataLoader


def split_data(data: ImagesDataset):
    """Split the dataset 50/50 into train and eval subsets, stratified by class."""
    class_nums = torch.bincount(
        torch.tensor([data.classnames_to_ids[name] for _, name in data.filenames_classnames],
                     dtype=torch.long)).tolist()
    indices = ([], [])
    for class_id in range(len(class_nums)):
        class_num = class_nums[class_id]
        # Shuffle this class's samples, then send every other one to the
        # train split and the remainder to the eval split.
        index_perm = torch.randperm(class_num).tolist()
        class_indices = [i for i, e in enumerate(data.filenames_classnames)
                         if data.classnames_to_ids[e[1]] == class_id]
        indices[0].extend([class_indices[i] for i in index_perm[0::2]])
        indices[1].extend([class_indices[i] for i in index_perm[1::2]])
    return (torch.utils.data.Subset(data, indices[0]),
            torch.utils.data.Subset(data, indices[1]))


def train_model(accuracies, losses, progress_epoch, progress_train_data, progress_eval_data,
                model, num_epochs, batch_size, optimizer, loss_function, augment_data,
                device, start_time):
    # Fixed seeds so the stratified split and the shuffling are reproducible.
    torch.random.manual_seed(42)
    numpy.random.seed(42)
    torch.multiprocessing.set_start_method('spawn', force=True)

    dataset = ImagesDataset("training_data")
    train_data, eval_data = split_data(dataset)
    augmented_train_data = AImagesDataset(train_data, augment_data)
    train_loader = AsyncDataLoader(augmented_train_data, batch_size=batch_size,
                                   num_workers=3, pin_memory=True, shuffle=True)
    eval_loader = AsyncDataLoader(eval_data, batch_size=batch_size,
                                  num_workers=3, pin_memory=True)

    for epoch in progress_epoch.range(num_epochs):
        train_positives = torch.tensor(0, device=device)
        eval_positives = torch.tensor(0, device=device)
        train_loss = torch.tensor(0.0, device=device)
        eval_loss = torch.tensor(0.0, device=device)

        # Training pass
        progress_train_data.reset()
        model.train()
        for (image_t, transforms, img_index, class_ids, labels, paths) \
                in progress_train_data.iter(train_loader):
            image_t = image_t.to(device)
            class_ids = class_ids.to(device)
            outputs = model(image_t)
            optimizer.zero_grad(set_to_none=True)
            loss = loss_function(outputs, class_ids)
            loss.backward()
            optimizer.step()
            # Detach before accumulating, otherwise every batch's autograd
            # graph is kept alive. Scale by the batch size so that dividing
            # the sum by the dataset length below yields the mean per-sample
            # loss (CrossEntropyLoss averages over the batch by default).
            train_loss += loss.detach() * image_t.size(0)
            classes = outputs.argmax(dim=1)
            train_positives += (classes == class_ids).sum()

        accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
        losses.append('train_loss', train_loss.item() / len(augmented_train_data))
        print(f"Train: {train_positives.item()} / {len(augmented_train_data)} = "
              f"{train_positives.item() / len(augmented_train_data)}")

        # Evaluation pass
        progress_eval_data.reset()
        model.eval()
        with torch.no_grad():
            for (image_t, class_ids, labels, paths) in progress_eval_data.iter(eval_loader):
                image_t = image_t.to(device)
                class_ids = class_ids.to(device)
                outputs = model(image_t)
                classes = outputs.argmax(dim=1)
                eval_positives += (classes == class_ids).sum()
                eval_loss += loss_function(outputs, class_ids) * image_t.size(0)

        accuracies.append('eval_acc', eval_positives.item() / len(eval_data))
        losses.append('eval_loss', eval_loss.item() / len(eval_data))
        print(f"Eval: {eval_positives.item()} / {len(eval_data)} = "
              f"{eval_positives.item() / len(eval_data)}")

        # Checkpoint whenever eval accuracy exceeds 50 %.
        if eval_positives.item() / len(eval_data) > 0.5:
            torch.save(model.state_dict(),
                       f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pth')
        with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
            file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};'
                       f'{train_loss.item()};{eval_loss.item()};'
                       f'{train_positives.item()};{eval_positives.item()}\n')


def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available")
    device = 'cuda'
    model.to(device)
    num_epochs = 1000000
    batch_size = 64
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, fused=True)
    loss_function = torch.nn.CrossEntropyLoss()
    augment_data = True

    # Log the hyperparameters next to the CSV of per-epoch metrics.
    os.makedirs('models', exist_ok=True)
    file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
    with open(file_name.replace(".csv", ".txt"), 'a') as file:
        file.write(f"device: {device}\n")
        file.write(f"batch_size: {batch_size}\n")
        file.write(f"optimizer: {optimizer}\n")
        file.write(f"loss_function: {loss_function}\n")
        file.write(f"augment_data: {augment_data}\n")
        file.write(f"model: {model}\n")

    train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval, model,
                num_epochs, batch_size, optimizer, loss_function, augment_data,
                device, start_time)
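

# --- Usage sketch (illustrative, not part of the project code above) ---
# A minimal, assumed entry point showing how train_worker() might be driven.
# The _Progress and _Plotter classes below are hypothetical stand-ins that
# only satisfy the interfaces used in train_model() (.range()/.reset()/.iter()
# and .append(name, value)); the real project presumably supplies richer
# implementations (progress bars, live plots).


class _Progress:
    """No-op progress reporter matching the interface used in train_model()."""

    def range(self, n):
        return range(n)

    def reset(self):
        pass

    def iter(self, iterable):
        return iter(iterable)


class _Plotter:
    """Collects named metric series in memory instead of plotting them."""

    def __init__(self):
        self.series = {}

    def append(self, name, value):
        self.series.setdefault(name, []).append(value)


if __name__ == '__main__':
    train_worker(_Progress(), _Progress(), _Progress(),
                 _Plotter(), _Plotter(), start_time=datetime.now())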