added better splitting, fixed accuracy bug, set save threshold to 50%
This commit is contained in:
parent
7466017933
commit
990d9a84b4
33
cnn_train.py
33
cnn_train.py
|
|
@ -11,8 +11,24 @@ from AImageDataset import AImagesDataset
|
||||||
from AsyncDataLoader import AsyncDataLoader
|
from AsyncDataLoader import AsyncDataLoader
|
||||||
|
|
||||||
|
|
||||||
def split_data(data: ImagesDataset, *, generator=None):
    """Split ``data`` into two class-balanced halves.

    The sample indices of every class are shuffled independently and then
    dealt out alternately: positions 0, 2, 4, ... of the shuffled order go
    to the first subset and positions 1, 3, 5, ... to the second. A class
    with an odd number of samples therefore contributes its extra sample to
    the first subset.

    Args:
        data: Dataset exposing ``filenames_classnames`` (a sequence of
            ``(filename, classname)`` pairs) and ``classnames_to_ids`` (a
            mapping from class name to class id). Class ids are assumed to
            be non-negative integers -- TODO confirm against ImagesDataset.
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.randperm`` for a reproducible split. ``None`` (the
            default) uses PyTorch's global random state.

    Returns:
        A ``(first, second)`` tuple of ``torch.utils.data.Subset`` objects
        wrapping ``data``; callers in this file use them as the training
        and evaluation sets respectively.

    Raises:
        ValueError: If a class id is negative.
    """
    # Group sample indices by class id in a single pass instead of
    # rescanning the whole dataset once per class. Enumeration order keeps
    # each per-class list sorted by dataset index.
    indices_by_class = {}
    for sample_index, (_, class_name) in enumerate(data.filenames_classnames):
        class_id = data.classnames_to_ids[class_name]
        if class_id < 0:
            raise ValueError(f"negative class id {class_id!r} for class {class_name!r}")
        indices_by_class.setdefault(class_id, []).append(sample_index)

    first_indices = []
    second_indices = []
    # Walk every id from 0 up to the largest one seen (ids without samples
    # yield an empty permutation) so classes are processed in ascending order.
    class_count = max(indices_by_class, default=-1) + 1
    for class_id in range(class_count):
        class_indices = indices_by_class.get(class_id, [])
        shuffled = torch.randperm(len(class_indices), generator=generator).tolist()

        first_indices.extend(class_indices[i] for i in shuffled[0::2])
        second_indices.extend(class_indices[i] for i in shuffled[1::2])

    return (torch.utils.data.Subset(data, first_indices),
            torch.utils.data.Subset(data, second_indices))
|
||||||
|
|
||||||
|
|
||||||
def train_model(accuracies,
|
def train_model(accuracies,
|
||||||
losses,
|
losses,
|
||||||
|
|
@ -33,9 +49,10 @@ def train_model(accuracies,
|
||||||
|
|
||||||
dataset = ImagesDataset("training_data")
|
dataset = ImagesDataset("training_data")
|
||||||
|
|
||||||
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
|
# dataset = torch.utils.data.Subset(dataset, range(0, 32))
|
||||||
|
|
||||||
train_data, eval_data = split_data(dataset)
|
train_data, eval_data = split_data(dataset)
|
||||||
|
# train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||||
|
|
||||||
augmented_train_data = AImagesDataset(train_data, augment_data)
|
augmented_train_data = AImagesDataset(train_data, augment_data)
|
||||||
train_loader = AsyncDataLoader(augmented_train_data,
|
train_loader = AsyncDataLoader(augmented_train_data,
|
||||||
|
|
@ -77,8 +94,7 @@ def train_model(accuracies,
|
||||||
|
|
||||||
train_loss += loss
|
train_loss += loss
|
||||||
|
|
||||||
outputs.flatten()
|
classes = outputs.argmax(dim=1)
|
||||||
classes = outputs.argmax()
|
|
||||||
train_positives += torch.sum(torch.eq(classes, class_ids))
|
train_positives += torch.sum(torch.eq(classes, class_ids))
|
||||||
|
|
||||||
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
||||||
|
|
@ -98,8 +114,7 @@ def train_model(accuracies,
|
||||||
class_ids = class_ids.to(device)
|
class_ids = class_ids.to(device)
|
||||||
|
|
||||||
outputs = model(image_t)
|
outputs = model(image_t)
|
||||||
outputs.flatten()
|
classes = outputs.argmax(dim=1)
|
||||||
classes = outputs.argmax()
|
|
||||||
|
|
||||||
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
||||||
eval_loss += loss_function(outputs, class_ids)
|
eval_loss += loss_function(outputs, class_ids)
|
||||||
|
|
@ -112,7 +127,7 @@ def train_model(accuracies,
|
||||||
|
|
||||||
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
|
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
|
||||||
|
|
||||||
if eval_positives.item() / len(eval_data) > 0.3:
|
if eval_positives.item() / len(eval_data) > 0.5:
|
||||||
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
|
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
|
||||||
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
||||||
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
||||||
|
|
@ -137,7 +152,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, sta
|
||||||
fused=True)
|
fused=True)
|
||||||
loss_function = torch.nn.CrossEntropyLoss()
|
loss_function = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
augment_data = False
|
augment_data = True
|
||||||
|
|
||||||
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
||||||
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue