added better splitting, fixed accuracy bug, set save threshold to 50%
This commit is contained in:
parent
7466017933
commit
990d9a84b4
33
cnn_train.py
33
cnn_train.py
|
|
@ -11,8 +11,24 @@ from AImageDataset import AImagesDataset
|
|||
from AsyncDataLoader import AsyncDataLoader
|
||||
|
||||
|
||||
def split_data(data):
|
||||
pass
|
||||
def split_data(data: ImagesDataset):
|
||||
class_nums = (torch.bincount(
|
||||
torch.Tensor([data.classnames_to_ids[name] for _, name in data.filenames_classnames]).long())
|
||||
.tolist())
|
||||
|
||||
indices = ([], [])
|
||||
for class_id in range(len(class_nums)):
|
||||
class_num = class_nums[class_id]
|
||||
index_perm = torch.randperm(class_num).tolist()
|
||||
|
||||
class_indices = [i for i, e in enumerate(data.filenames_classnames) if data.classnames_to_ids[e[1]] == class_id]
|
||||
|
||||
indices[0].extend([class_indices[i] for i in index_perm[0::2]])
|
||||
indices[1].extend([class_indices[i] for i in index_perm[1::2]])
|
||||
|
||||
return (torch.utils.data.Subset(data, indices[0]),
|
||||
torch.utils.data.Subset(data, indices[1]))
|
||||
|
||||
|
||||
def train_model(accuracies,
|
||||
losses,
|
||||
|
|
@ -33,9 +49,10 @@ def train_model(accuracies,
|
|||
|
||||
dataset = ImagesDataset("training_data")
|
||||
|
||||
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
|
||||
# dataset = torch.utils.data.Subset(dataset, range(0, 32))
|
||||
|
||||
train_data, eval_data = split_data(dataset)
|
||||
# train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||
|
||||
augmented_train_data = AImagesDataset(train_data, augment_data)
|
||||
train_loader = AsyncDataLoader(augmented_train_data,
|
||||
|
|
@ -77,8 +94,7 @@ def train_model(accuracies,
|
|||
|
||||
train_loss += loss
|
||||
|
||||
outputs.flatten()
|
||||
classes = outputs.argmax()
|
||||
classes = outputs.argmax(dim=1)
|
||||
train_positives += torch.sum(torch.eq(classes, class_ids))
|
||||
|
||||
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
||||
|
|
@ -98,8 +114,7 @@ def train_model(accuracies,
|
|||
class_ids = class_ids.to(device)
|
||||
|
||||
outputs = model(image_t)
|
||||
outputs.flatten()
|
||||
classes = outputs.argmax()
|
||||
classes = outputs.argmax(dim=1)
|
||||
|
||||
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
||||
eval_loss += loss_function(outputs, class_ids)
|
||||
|
|
@ -112,7 +127,7 @@ def train_model(accuracies,
|
|||
|
||||
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
|
||||
|
||||
if eval_positives.item() / len(eval_data) > 0.3:
|
||||
if eval_positives.item() / len(eval_data) > 0.5:
|
||||
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
|
||||
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
||||
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
||||
|
|
@ -137,7 +152,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, sta
|
|||
fused=True)
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
augment_data = False
|
||||
augment_data = True
|
||||
|
||||
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
||||
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
||||
|
|
|
|||
Loading…
Reference in New Issue