added better splitting, fixed accuracy bug, set save threshold to 50%

Patrick 2024-07-15 23:15:30 +02:00
parent 7466017933
commit 990d9a84b4
1 changed file with 24 additions and 9 deletions


@@ -11,8 +11,24 @@ from AImageDataset import AImagesDataset
 from AsyncDataLoader import AsyncDataLoader
 
-def split_data(data):
-    pass
+def split_data(data: ImagesDataset):
+    class_nums = (torch.bincount(
+        torch.Tensor([data.classnames_to_ids[name] for _, name in data.filenames_classnames]).long())
+                  .tolist())
+    indices = ([], [])
+    for class_id in range(len(class_nums)):
+        class_num = class_nums[class_id]
+        index_perm = torch.randperm(class_num).tolist()
+        class_indices = [i for i, e in enumerate(data.filenames_classnames) if data.classnames_to_ids[e[1]] == class_id]
+        indices[0].extend([class_indices[i] for i in index_perm[0::2]])
+        indices[1].extend([class_indices[i] for i in index_perm[1::2]])
+    return (torch.utils.data.Subset(data, indices[0]),
+            torch.utils.data.Subset(data, indices[1]))
 
 
 def train_model(accuracies,
                 losses,
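The new split_data is a stratified 50/50 split: for each class it shuffles that class's sample positions and deals them alternately into the two subsets, so both halves keep roughly the full dataset's class balance (which a plain random_split would not guarantee). A minimal sketch of the same idea on a hypothetical toy label list, independent of the repo's ImagesDataset:

    import torch

    # Toy stand-in for the dataset metadata: 10 "cat" samples, 6 "dog" samples (made-up data).
    filenames_classnames = [(f"img_{i}.jpg", "cat" if i < 10 else "dog") for i in range(16)]
    classnames_to_ids = {"cat": 0, "dog": 1}

    class_ids = torch.tensor([classnames_to_ids[name] for _, name in filenames_classnames])
    class_nums = torch.bincount(class_ids).tolist()

    indices = ([], [])
    for class_id, class_num in enumerate(class_nums):
        # Shuffle the positions of this class and alternate them between the two halves.
        class_indices = [i for i, cid in enumerate(class_ids.tolist()) if cid == class_id]
        perm = torch.randperm(class_num).tolist()
        indices[0].extend(class_indices[i] for i in perm[0::2])
        indices[1].extend(class_indices[i] for i in perm[1::2])

    # Each half ends up with ~5 "cat" and ~3 "dog" samples, i.e. the split is stratified.
    print(len(indices[0]), len(indices[1]))

With perm[0::2] / perm[1::2], the first subset gets the ceil-half of each class and the second the floor-half.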
@@ -33,9 +49,10 @@ def train_model(accuracies,
     dataset = ImagesDataset("training_data")
-    # dataset = torch.utils.data.Subset(dataset, range(0, 1024))
+    # dataset = torch.utils.data.Subset(dataset, range(0, 32))
     train_data, eval_data = split_data(dataset)
+    # train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
     augmented_train_data = AImagesDataset(train_data, augment_data)
     train_loader = AsyncDataLoader(augmented_train_data,
@@ -77,8 +94,7 @@ def train_model(accuracies,
             train_loss += loss
-            outputs.flatten()
-            classes = outputs.argmax()
+            classes = outputs.argmax(dim=1)
             train_positives += torch.sum(torch.eq(classes, class_ids))
         accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
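This is the accuracy bug from the commit message: outputs.flatten() returned a new tensor that was discarded, and outputs.argmax() with no dim argument reduces over the entire flattened batch, yielding a single scalar index that is then broadcast against class_ids in torch.eq. argmax(dim=1) gives one predicted class per sample. A small illustration with made-up logits:

    import torch

    # Batch of 3 samples, 4 classes (hypothetical logits).
    outputs = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                            [1.5, 0.2, 0.1, 0.0],
                            [0.0, 0.1, 0.2, 3.0]])
    class_ids = torch.tensor([1, 0, 3])

    print(outputs.argmax())        # tensor(11): index into the flattened 3x4 tensor
    print(outputs.argmax(dim=1))   # tensor([1, 0, 3]): one predicted class per sample

    # Old code compared a scalar flat index against class_ids, so the positives count was garbage;
    # with dim=1 the comparison is element-wise and counts correct predictions.
    print(torch.sum(torch.eq(outputs.argmax(dim=1), class_ids)))  # tensor(3)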
@@ -98,8 +114,7 @@ def train_model(accuracies,
             class_ids = class_ids.to(device)
             outputs = model(image_t)
-            outputs.flatten()
-            classes = outputs.argmax()
+            classes = outputs.argmax(dim=1)
             eval_positives += torch.sum(torch.eq(classes, class_ids))
             eval_loss += loss_function(outputs, class_ids)
@@ -112,7 +127,7 @@ def train_model(accuracies,
         # print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
-        if eval_positives.item() / len(eval_data) > 0.3:
+        if eval_positives.item() / len(eval_data) > 0.5:
             torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
             with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
                 file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
@@ -137,7 +152,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, sta
                                   fused=True)
     loss_function = torch.nn.CrossEntropyLoss()
-    augment_data = False
+    augment_data = True
     file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
     with open(file_name.replace(".csv", ".txt"), 'a') as file: