started data split

This commit is contained in:
Patrick 2024-07-07 11:56:19 +02:00
parent b80b24e07e
commit 57bd0c798b
1 changed file with 11 additions and 5 deletions

View File

@ -11,6 +11,9 @@ from AImageDataset import AImagesDataset
from AsyncDataLoader import AsyncDataLoader from AsyncDataLoader import AsyncDataLoader
def split_data(data):
    """Split *data* into a training half and an evaluation half.

    The caller unpacks the result as ``train_data, eval_data``, so a
    two-element sequence must be returned; the previous ``pass`` stub
    returned ``None`` and made that unpacking fail with ``TypeError``.

    Until a dedicated splitting strategy is written, this keeps the random
    50/50 split that the call site used before it was switched over to
    this function. The shuffle draws from the global torch RNG, so a
    manual seed set by the caller makes the split reproducible.

    Args:
        data: A sized, indexable dataset accepted by
            ``torch.utils.data.random_split``.

    Returns:
        A list of two ``torch.utils.data.Subset`` objects:
        ``[train_data, eval_data]``.
    """
    return torch.utils.data.random_split(data, [0.5, 0.5])
def train_model(accuracies, def train_model(accuracies,
losses, losses,
progress_epoch, progress_epoch,
@ -21,6 +24,7 @@ def train_model(accuracies,
batch_size, batch_size,
optimizer, optimizer,
loss_function, loss_function,
augment_data,
device, device,
start_time): start_time):
torch.random.manual_seed(42) torch.random.manual_seed(42)
@ -31,9 +35,9 @@ def train_model(accuracies,
# dataset = torch.utils.data.Subset(dataset, range(0, 1024)) # dataset = torch.utils.data.Subset(dataset, range(0, 1024))
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5]) train_data, eval_data = split_data(dataset)
augmented_train_data = AImagesDataset(train_data, True) augmented_train_data = AImagesDataset(train_data, augment_data)
train_loader = AsyncDataLoader(augmented_train_data, train_loader = AsyncDataLoader(augmented_train_data,
batch_size=batch_size, batch_size=batch_size,
num_workers=3, num_workers=3,
@ -115,7 +119,7 @@ def train_model(accuracies,
f'{train_positives};{eval_positives}\n') f'{train_positives};{eval_positives}\n')
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss): def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("GPU not available") raise RuntimeError("GPU not available")
@ -128,12 +132,12 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
batch_size = 64 batch_size = 64
optimizer = torch.optim.Adam(model.parameters(), optimizer = torch.optim.Adam(model.parameters(),
lr=0.00005, lr=0.0001,
# weight_decay=0.1, # weight_decay=0.1,
fused=True) fused=True)
loss_function = torch.nn.CrossEntropyLoss() loss_function = torch.nn.CrossEntropyLoss()
start_time = datetime.now() augment_data = False
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv' file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
with open(file_name.replace(".csv", ".txt"), 'a') as file: with open(file_name.replace(".csv", ".txt"), 'a') as file:
@ -141,6 +145,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
file.write(f"batch_size: {batch_size}\n") file.write(f"batch_size: {batch_size}\n")
file.write(f"optimizer: {optimizer}\n") file.write(f"optimizer: {optimizer}\n")
file.write(f"loss_function: {loss_function}\n") file.write(f"loss_function: {loss_function}\n")
file.write(f"augment_data: {augment_data}\n")
file.write(f"model: {model}") file.write(f"model: {model}")
train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval, train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
@ -149,5 +154,6 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
batch_size, batch_size,
optimizer, optimizer,
loss_function, loss_function,
augment_data,
device, device,
start_time) start_time)