def split_data(data: "ImagesDataset", *, generator=None):
    """Split a dataset into two class-stratified halves.

    Every class is shuffled independently and its samples are dealt
    alternately to the two halves, so each half receives (almost) the same
    number of samples of every class. For a class with an odd number of
    samples the first half receives the extra one.

    Args:
        data: Dataset exposing ``filenames_classnames`` (a sequence of
            ``(filename, classname)`` pairs) and ``classnames_to_ids``
            (a mapping from class name to a non-negative integer id).
        generator: Optional ``torch.Generator`` used for the per-class
            shuffles. ``None`` (the default) uses torch's global RNG.

    Returns:
        A tuple ``(first_half, second_half)`` of ``torch.utils.data.Subset``
        objects over ``data``.

    Raises:
        KeyError: If a class name is missing from ``classnames_to_ids``.
        ValueError: If a class id is negative.
    """
    class_ids = [data.classnames_to_ids[name]
                 for _, name in data.filenames_classnames]

    # Single pass over the dataset: bucket sample positions by class id
    # instead of rescanning every sample once per class.
    members_by_class = {}
    for position, class_id in enumerate(class_ids):
        members_by_class.setdefault(class_id, []).append(position)

    if any(class_id < 0 for class_id in members_by_class):
        raise ValueError("class ids must be non-negative")

    first_half, second_half = [], []
    # Walk the ids in ascending order, including ids that have no samples,
    # so the order of the resulting indices depends only on the class ids.
    num_classes = max(members_by_class, default=-1) + 1
    for class_id in range(num_classes):
        members = members_by_class.get(class_id, [])
        order = torch.randperm(len(members), generator=generator).tolist()

        # Deal the shuffled samples alternately to the two halves.
        first_half.extend(members[i] for i in order[0::2])
        second_half.extend(members[i] for i in order[1::2])

    return (torch.utils.data.Subset(data, first_half),
            torch.utils.data.Subset(data, second_half))
-112,7 +127,7 @@ def train_model(accuracies, # print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}") - if eval_positives.item() / len(eval_data) > 0.3: + if eval_positives.item() / len(eval_data) > 0.5: torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt') with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file: file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};' @@ -137,7 +152,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, sta fused=True) loss_function = torch.nn.CrossEntropyLoss() - augment_data = False + augment_data = True file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv' with open(file_name.replace(".csv", ".txt"), 'a') as file: