started data split
This commit is contained in:
parent
b80b24e07e
commit
57bd0c798b
16
cnn_train.py
16
cnn_train.py
|
|
@ -11,6 +11,9 @@ from AImageDataset import AImagesDataset
|
|||
from AsyncDataLoader import AsyncDataLoader
|
||||
|
||||
|
||||
def split_data(data):
|
||||
pass
|
||||
|
||||
def train_model(accuracies,
|
||||
losses,
|
||||
progress_epoch,
|
||||
|
|
@ -21,6 +24,7 @@ def train_model(accuracies,
|
|||
batch_size,
|
||||
optimizer,
|
||||
loss_function,
|
||||
augment_data,
|
||||
device,
|
||||
start_time):
|
||||
torch.random.manual_seed(42)
|
||||
|
|
@ -31,9 +35,9 @@ def train_model(accuracies,
|
|||
|
||||
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
|
||||
|
||||
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||
train_data, eval_data = split_data(dataset)
|
||||
|
||||
augmented_train_data = AImagesDataset(train_data, True)
|
||||
augmented_train_data = AImagesDataset(train_data, augment_data)
|
||||
train_loader = AsyncDataLoader(augmented_train_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
|
|
@ -115,7 +119,7 @@ def train_model(accuracies,
|
|||
f'{train_positives};{eval_positives}\n')
|
||||
|
||||
|
||||
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
||||
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU not available")
|
||||
|
||||
|
|
@ -128,12 +132,12 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
|||
|
||||
batch_size = 64
|
||||
optimizer = torch.optim.Adam(model.parameters(),
|
||||
lr=0.00005,
|
||||
lr=0.0001,
|
||||
# weight_decay=0.1,
|
||||
fused=True)
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
start_time = datetime.now()
|
||||
augment_data = False
|
||||
|
||||
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
||||
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
||||
|
|
@ -141,6 +145,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
|||
file.write(f"batch_size: {batch_size}\n")
|
||||
file.write(f"optimizer: {optimizer}\n")
|
||||
file.write(f"loss_function: {loss_function}\n")
|
||||
file.write(f"augment_data: {augment_data}\n")
|
||||
file.write(f"model: {model}")
|
||||
|
||||
train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
|
||||
|
|
@ -149,5 +154,6 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
|||
batch_size,
|
||||
optimizer,
|
||||
loss_function,
|
||||
augment_data,
|
||||
device,
|
||||
start_time)
|
||||
|
|
|
|||
Loading…
Reference in New Issue