revamped training process
This commit is contained in:
parent
79817340dd
commit
e9d86f3309
173
cnn_train.py
173
cnn_train.py
|
|
@ -1,119 +1,152 @@
|
|||
import tkinter as tk
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
|
||||
import numpy.random
|
||||
import torch.utils.data
|
||||
import torch.cuda
|
||||
from tqdm.tk import tqdm
|
||||
|
||||
from architecture import MyCNN
|
||||
from dataset import ImagesDataset
|
||||
|
||||
from AImageDataset import AImagesDataset
|
||||
|
||||
model = MyCNN(input_channels=1,
|
||||
input_size=(100, 100),
|
||||
hidden_channels=[500, 250, 100, 50],
|
||||
output_channels=20,
|
||||
use_batchnorm=True,
|
||||
kernel_size=[9, 5, 3, 3, 1],
|
||||
stride=[1, 1, 1, 1, 1],
|
||||
activation_function=torch.nn.ReLU())
|
||||
|
||||
num_epochs = 100
|
||||
|
||||
batch_size = 64
|
||||
optimizer = torch.optim.ASGD(model.parameters(),
|
||||
lr=0.001,
|
||||
lambd=1e-4,
|
||||
alpha=0.75,
|
||||
t0=1000000.0,
|
||||
weight_decay=0)
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def train_model(accuracies,
|
||||
losses,
|
||||
progress_epoch,
|
||||
progress_train_data,
|
||||
progress_eval_data,
|
||||
model,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
optimizer,
|
||||
loss_function,
|
||||
device,
|
||||
start_time):
|
||||
|
||||
torch.random.manual_seed(42)
|
||||
numpy.random.seed(42)
|
||||
|
||||
start_time = datetime.now()
|
||||
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
dataset = ImagesDataset("training_data")
|
||||
|
||||
# dataset = torch.utils.data.Subset(dataset, range(0, 20))
|
||||
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
|
||||
|
||||
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(AImagesDataset(train_data), batch_size=batch_size)
|
||||
eval_loader = torch.utils.data.DataLoader(eval_data, batch_size=1)
|
||||
augmented_train_data = AImagesDataset(train_data, False)
|
||||
train_loader = torch.utils.data.DataLoader(augmented_train_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True,
|
||||
shuffle=True)
|
||||
eval_loader = torch.utils.data.DataLoader(eval_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# print("GPU available")
|
||||
model = model.cuda()
|
||||
else:
|
||||
warnings.warn("GPU not available")
|
||||
for epoch in progress_epoch.range(num_epochs):
|
||||
|
||||
train_losses = []
|
||||
eval_losses = []
|
||||
train_positives = torch.tensor(0, device=device)
|
||||
eval_positives = torch.tensor(0, device=device)
|
||||
|
||||
progress_epoch = tqdm(range(num_epochs), position=0, tk_parent=root_window)
|
||||
progress_epoch.set_description("Epoch")
|
||||
progress_train_data = tqdm(train_loader, position=1, tk_parent=root_window)
|
||||
progress_eval_data = tqdm(eval_loader, position=2, tk_parent=root_window)
|
||||
progress_train_data.set_description("Training progress")
|
||||
progress_eval_data.set_description("Evaluation progress")
|
||||
|
||||
for epoch in progress_epoch:
|
||||
|
||||
train_loss = 0
|
||||
eval_loss = 0
|
||||
|
||||
progress_train_data.reset()
|
||||
progress_eval_data.reset()
|
||||
train_loss = torch.tensor(0.0, device=device)
|
||||
eval_loss = torch.tensor(0.0, device=device)
|
||||
|
||||
# Start training of model
|
||||
|
||||
progress_train_data.reset()
|
||||
|
||||
model.train()
|
||||
|
||||
for batch_nr, (imageT, transforms, img_index, classIDs, labels, paths) in enumerate(progress_train_data):
|
||||
imageT = imageT.to('cuda')
|
||||
classIDs = classIDs.to('cuda')
|
||||
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
||||
in enumerate(progress_train_data.iter(train_loader)):
|
||||
image_t = image_t.to(device)
|
||||
class_ids = class_ids.to(device)
|
||||
|
||||
# progress_train_data.set_postfix_str("Running model...")
|
||||
outputs = model(imageT)
|
||||
outputs = model(image_t)
|
||||
|
||||
optimizer.zero_grad()
|
||||
# progress_train_data.set_postfix_str("calculating loss...")
|
||||
loss = loss_function(outputs, classIDs)
|
||||
# progress_train_data.set_postfix_str("propagating loss...")
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss = loss_function(outputs, class_ids)
|
||||
loss.backward()
|
||||
# progress_train_data.set_postfix_str("optimizing...")
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
train_loss += loss
|
||||
|
||||
mean_loss = train_loss / len(train_loader)
|
||||
train_losses.append(mean_loss)
|
||||
outputs.flatten()
|
||||
classes = outputs.argmax()
|
||||
train_positives += torch.sum(torch.eq(classes, class_ids))
|
||||
|
||||
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
||||
losses.append('train_loss', train_loss.item() / len(augmented_train_data))
|
||||
print("Train: ", train_positives.item(), "/ ", len(augmented_train_data),
|
||||
" = ", train_positives.item() / len(augmented_train_data))
|
||||
|
||||
# evaluation of model
|
||||
|
||||
progress_eval_data.reset()
|
||||
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for (imageT, classIDs, labels, paths) in progress_eval_data:
|
||||
imageT = imageT.to('cuda')
|
||||
classIDs = classIDs.to('cuda')
|
||||
for (image_t, class_ids, labels, paths) in progress_eval_data.iter(eval_loader):
|
||||
image_t = image_t.to(device)
|
||||
class_ids = class_ids.to(device)
|
||||
|
||||
outputs = model(imageT)
|
||||
loss = loss_function(outputs, classIDs)
|
||||
eval_loss = loss.item()
|
||||
outputs = model(image_t)
|
||||
outputs.flatten()
|
||||
classes = outputs.argmax()
|
||||
|
||||
eval_losses.append(eval_loss)
|
||||
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
||||
eval_loss += loss_function(outputs, class_ids)
|
||||
|
||||
accuracies.append('eval_acc', eval_positives.item() / len(eval_data))
|
||||
losses.append('eval_loss', eval_loss.item() / len(eval_data))
|
||||
print("Eval: ", eval_positives.item(), "/ ", len(eval_data), " = ", eval_positives.item() / len(eval_data))
|
||||
|
||||
# print epoch summary
|
||||
|
||||
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
|
||||
|
||||
if eval_positives.item() / len(eval_data) > 0.3:
|
||||
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
|
||||
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
||||
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
||||
f'{train_positives};{eval_positives}\n')
|
||||
|
||||
|
||||
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU not available")
|
||||
|
||||
device = 'cuda'
|
||||
|
||||
model = MyCNN(input_channels=1,
|
||||
input_size=(100, 100)).to(device)
|
||||
|
||||
num_epochs = 1000000
|
||||
|
||||
batch_size = 64
|
||||
optimizer = torch.optim.Adam(model.parameters(),
|
||||
lr=0.00005,
|
||||
# weight_decay=0.1,
|
||||
fused=True)
|
||||
loss_function = torch.nn.CrossEntropyLoss()
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
||||
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
||||
file.write(f"device: {device}\n")
|
||||
file.write(f"batch_size: {batch_size}\n")
|
||||
file.write(f"optimizer: {optimizer}\n")
|
||||
file.write(f"loss_function: {loss_function}\n")
|
||||
file.write(f"model: {model}")
|
||||
|
||||
train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
|
||||
model,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
optimizer,
|
||||
loss_function,
|
||||
device,
|
||||
start_time)
|
||||
|
|
|
|||
Loading…
Reference in New Issue