diff --git a/cnn_train.py b/cnn_train.py
index ed95677..3275c18 100644
--- a/cnn_train.py
+++ b/cnn_train.py
@@ -11,6 +11,9 @@
 from AImageDataset import AImagesDataset
 from AsyncDataLoader import AsyncDataLoader
 
+def split_data(data):
+    return torch.utils.data.random_split(data, [0.5, 0.5])  # 50/50 split, matching the inline random_split this replaces
+
 def train_model(accuracies,
                 losses,
                 progress_epoch,
@@ -21,6 +24,7 @@ def train_model(accuracies,
                 batch_size,
                 optimizer,
                 loss_function,
+                augment_data,
                 device,
                 start_time):
     torch.random.manual_seed(42)
@@ -31,9 +35,9 @@ def train_model(accuracies,
     # dataset = torch.utils.data.Subset(dataset, range(0, 1024))
 
-    train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
+    train_data, eval_data = split_data(dataset)
 
-    augmented_train_data = AImagesDataset(train_data, True)
+    augmented_train_data = AImagesDataset(train_data, augment_data)
 
     train_loader = AsyncDataLoader(augmented_train_data,
                                    batch_size=batch_size,
                                    num_workers=3,
@@ -115,7 +119,7 @@ def train_model(accuracies,
                        f'{train_positives};{eval_positives}\n')
 
 
-def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
+def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
     if not torch.cuda.is_available():
         raise RuntimeError("GPU not available")
 
@@ -128,12 +132,12 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
 
     batch_size = 64
     optimizer = torch.optim.Adam(model.parameters(),
-                                 lr=0.00005,
+                                 lr=0.0001,
                                  # weight_decay=0.1,
                                  fused=True)
     loss_function = torch.nn.CrossEntropyLoss()
-    start_time = datetime.now()
+    augment_data = False
 
     file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
     with open(file_name.replace(".csv", ".txt"), 'a') as file:
@@ -141,6 +145,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
         file.write(f"batch_size: {batch_size}\n")
         file.write(f"optimizer: {optimizer}\n")
         file.write(f"loss_function: {loss_function}\n")
+        file.write(f"augment_data: {augment_data}\n")
         file.write(f"model: {model}")
 
     train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
@@ -149,5 +154,6 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
                 batch_size,
                 optimizer,
                 loss_function,
+                augment_data,
                 device,
                 start_time)