Compare commits
10 Commits
992f84fa5a
...
ffe355e47c
| Author | SHA1 | Date |
|---|---|---|
|
|
ffe355e47c | |
|
|
c4a93e9fae | |
|
|
39aa685e28 | |
|
|
990d9a84b4 | |
|
|
7466017933 | |
|
|
48abee5ccd | |
|
|
57bd0c798b | |
|
|
b80b24e07e | |
|
|
78f08419ba | |
|
|
8454c73d77 |
|
|
@ -0,0 +1,2 @@
|
||||||
|
models/*
|
||||||
|
training_data/*
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDataLoader(torch.utils.data.DataLoader):
|
||||||
|
|
||||||
|
def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
|
||||||
|
super().__init__(data, **kwargs)
|
||||||
|
|
||||||
|
self.data_len = len(self)
|
||||||
|
|
||||||
|
self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len
|
||||||
|
self.__data_queue = queue.Queue()
|
||||||
|
self.__dataset_access = threading.Lock()
|
||||||
|
self.__dataset_access_request = threading.Event()
|
||||||
|
self.__dataset_access_request_release = threading.Event()
|
||||||
|
self.__data_extract_event = threading.Event()
|
||||||
|
|
||||||
|
self._worker_thread = threading.Thread(target=AsyncDataLoader.__async_loader, args=(self,), daemon=True)
|
||||||
|
self._worker_thread.start()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for _ in range(self.data_len):
|
||||||
|
_, data = self.__data_queue.get()
|
||||||
|
self.__data_extract_event.set()
|
||||||
|
yield data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __async_loader(self):
|
||||||
|
self.__dataset_access.acquire(blocking=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
for enumerated_data in enumerate(super().__iter__()):
|
||||||
|
self.__data_queue.put(enumerated_data)
|
||||||
|
|
||||||
|
while not self.__data_queue.qsize() < self.prefetch_size:
|
||||||
|
self.__dataset_access.release()
|
||||||
|
|
||||||
|
self.__data_extract_event.clear()
|
||||||
|
self.__data_extract_event.wait(10)
|
||||||
|
|
||||||
|
self.__dataset_access.acquire()
|
||||||
|
|
||||||
|
if self.__dataset_access_request.is_set():
|
||||||
|
self.__dataset_access.release()
|
||||||
|
self.__dataset_access_request_release.wait()
|
||||||
|
|
||||||
|
self.__dataset_access.acquire()
|
||||||
|
|
||||||
|
self.__dataset_access.release()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from torch.utils.data import TensorDataset, random_split, dataloader
|
||||||
|
import torch.random
|
||||||
|
import time
|
||||||
|
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
num_train_examples = 10
|
||||||
|
training_data = torch.rand(num_train_examples, 32)
|
||||||
|
random_function = torch.rand(32)
|
||||||
|
sums = torch.sum(training_data * random_function, dim=1) ** 2
|
||||||
|
targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
|
||||||
|
all_data = TensorDataset(training_data, targets)
|
||||||
|
train_data, eval_data = random_split(all_data, [0.5, 0.5])
|
||||||
|
|
||||||
|
train = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
|
||||||
|
eval = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)
|
||||||
|
|
||||||
|
for e in range(3):
|
||||||
|
print('e', e)
|
||||||
|
for i, _ in enumerate(train):
|
||||||
|
print("train", i)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
for i, _ in enumerate(eval):
|
||||||
|
print('eval', i)
|
||||||
|
time.sleep(1)
|
||||||
|
|
@ -47,7 +47,7 @@ class AsyncProgress(ttk.Frame):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
ips = 1 / self.last_elapsed_time if self.last_elapsed_time != 0 else 0
|
ips = 1 / self.last_elapsed_time if self.last_elapsed_time != 0 else 0
|
||||||
return (f"{'{v:{p}d}'.format(v=self.current, p=len(str(self.total)))}/{self.total} "
|
return (f"{'{v:{p}d}'.format(v=self.current, p=len(str(self.total)))}/{self.total} "
|
||||||
f"| {self.last_elapsed_time}s/it" if ips < 1 else f"{ips:.2f}it/s "
|
f"| {f'{self.last_elapsed_time:.2f}s/it' if ips < 1 else f'{ips:.2f}it/s'} "
|
||||||
f"| elapsed: {format_time(time.time() - self.start_time)} "
|
f"| elapsed: {format_time(time.time() - self.start_time)} "
|
||||||
f"| remaining: {format_time((self.total - self.current) * self.last_elapsed_time):} ")
|
f"| remaining: {format_time((self.total - self.current) * self.last_elapsed_time):} ")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,9 @@ class Plotter(tk.Frame):
|
||||||
def append(self, name: str, value):
|
def append(self, name: str, value):
|
||||||
self.__queue.put((name, value))
|
self.__queue.put((name, value))
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
self.figure.savefig(filename)
|
||||||
|
|
||||||
def __update(self):
|
def __update(self):
|
||||||
to_update = False
|
to_update = False
|
||||||
while not self.__queue.empty():
|
while not self.__queue.empty():
|
||||||
|
|
|
||||||
|
|
@ -4,93 +4,45 @@ import torch.nn
|
||||||
|
|
||||||
class MyCNN(torch.nn.Module):
|
class MyCNN(torch.nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
input_channels: int,
|
input_channels: int):
|
||||||
input_size: tuple[int, int]):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# input_layer = torch.nn.Conv2d(in_channels=input_channels,
|
|
||||||
# out_channels=hidden_channels[0],
|
|
||||||
# kernel_size=kernel_size[0],
|
|
||||||
# padding='same' if (stride[0] == 1 or stride[0] == 0) else 'valid',
|
|
||||||
# stride=stride[0],
|
|
||||||
# bias=not use_batchnorm)
|
|
||||||
# hidden_layers = [torch.nn.Conv2d(hidden_channels[i - 1],
|
|
||||||
# hidden_channels[i],
|
|
||||||
# kernel_size[i],
|
|
||||||
# padding='same' if (stride[i] == 1 or stride[i] == 0) else 'valid',
|
|
||||||
# stride=stride[i],
|
|
||||||
# bias=not use_batchnorm)
|
|
||||||
# for i in range(1, len(hidden_channels))]
|
|
||||||
# self.output_layer = torch.nn.Linear(hidden_channels[-1] * input_size[0] * input_size[1], output_channels)
|
|
||||||
|
|
||||||
self.layers = torch.nn.Sequential(
|
self.layers = torch.nn.Sequential(
|
||||||
torch.nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, padding='same', bias=False),
|
torch.nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=3, padding='same', bias=False),
|
||||||
torch.nn.BatchNorm2d(64),
|
torch.nn.BatchNorm2d(32),
|
||||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(32),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
|
torch.nn.Dropout2d(0.1),
|
||||||
|
torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(64),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(64),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
|
torch.nn.Dropout2d(0.1),
|
||||||
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=False),
|
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=False),
|
||||||
torch.nn.BatchNorm2d(128),
|
torch.nn.BatchNorm2d(128),
|
||||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=False),
|
||||||
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding='same', bias=False),
|
torch.nn.BatchNorm2d(128),
|
||||||
torch.nn.BatchNorm2d(256),
|
|
||||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
|
|
||||||
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding='same', bias=False),
|
|
||||||
torch.nn.BatchNorm2d(512),
|
|
||||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
torch.nn.Flatten(),
|
torch.nn.Flatten(),
|
||||||
torch.nn.Linear(in_features=12800, out_features=4096, bias=False),
|
torch.nn.Dropout(0.25),
|
||||||
|
torch.nn.Linear(in_features=2048, out_features=1024),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
torch.nn.Linear(in_features=4096, out_features=4096, bias=False),
|
torch.nn.Linear(in_features=1024, out_features=512),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
torch.nn.Linear(in_features=4096, out_features=20, bias=False),
|
torch.nn.Linear(in_features=512, out_features=20, bias=False)
|
||||||
torch.nn.Softmax(dim=1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.layers = torch.nn.Sequential(
|
|
||||||
# torch.nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=1, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(64),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=7, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(64),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(64),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
|
||||||
# torch.nn.Dropout2d(0.1),
|
|
||||||
# torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(32),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(32),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(32),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
|
||||||
# torch.nn.Dropout2d(0.1),
|
|
||||||
# torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(16),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1, padding='same', bias=False),
|
|
||||||
# torch.nn.BatchNorm2d(16),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
|
||||||
# torch.nn.Flatten(),
|
|
||||||
# torch.nn.Dropout(0.25),
|
|
||||||
# torch.nn.Linear(in_features=256, out_features=512),
|
|
||||||
# torch.nn.ReLU(),
|
|
||||||
# torch.nn.Linear(in_features=512, out_features=20, bias=False),
|
|
||||||
# # torch.nn.Softmax(dim=1),
|
|
||||||
# )
|
|
||||||
|
|
||||||
def forward(self, input_images: torch.Tensor) -> torch.Tensor:
|
def forward(self, input_images: torch.Tensor) -> torch.Tensor:
|
||||||
return self.layers(input_images)
|
return self.layers(input_images)
|
||||||
|
|
||||||
|
|
@ -98,4 +50,4 @@ class MyCNN(torch.nn.Module):
|
||||||
return str(self.layers)
|
return str(self.layers)
|
||||||
|
|
||||||
|
|
||||||
# model = MyCNN()
|
model = MyCNN(input_channels=1)
|
||||||
|
|
|
||||||
75
cnn_train.py
|
|
@ -4,10 +4,30 @@ import numpy.random
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
from architecture import MyCNN
|
from architecture import model
|
||||||
from dataset import ImagesDataset
|
from dataset import ImagesDataset
|
||||||
|
|
||||||
from AImageDataset import AImagesDataset
|
from AImageDataset import AImagesDataset
|
||||||
|
from AsyncDataLoader import AsyncDataLoader
|
||||||
|
|
||||||
|
|
||||||
|
def split_data(data: ImagesDataset):
|
||||||
|
class_nums = (torch.bincount(
|
||||||
|
torch.Tensor([data.classnames_to_ids[name] for _, name in data.filenames_classnames]).long())
|
||||||
|
.tolist())
|
||||||
|
|
||||||
|
indices = ([], [])
|
||||||
|
for class_id in range(len(class_nums)):
|
||||||
|
class_num = class_nums[class_id]
|
||||||
|
index_perm = torch.randperm(class_num).tolist()
|
||||||
|
|
||||||
|
class_indices = [i for i, e in enumerate(data.filenames_classnames) if data.classnames_to_ids[e[1]] == class_id]
|
||||||
|
|
||||||
|
indices[0].extend([class_indices[i] for i in index_perm[0::2]])
|
||||||
|
indices[1].extend([class_indices[i] for i in index_perm[1::2]])
|
||||||
|
|
||||||
|
return (torch.utils.data.Subset(data, indices[0]),
|
||||||
|
torch.utils.data.Subset(data, indices[1]))
|
||||||
|
|
||||||
|
|
||||||
def train_model(accuracies,
|
def train_model(accuracies,
|
||||||
|
|
@ -20,29 +40,27 @@ def train_model(accuracies,
|
||||||
batch_size,
|
batch_size,
|
||||||
optimizer,
|
optimizer,
|
||||||
loss_function,
|
loss_function,
|
||||||
|
augment_data,
|
||||||
device,
|
device,
|
||||||
start_time):
|
start_time):
|
||||||
|
|
||||||
torch.random.manual_seed(42)
|
torch.random.manual_seed(42)
|
||||||
numpy.random.seed(42)
|
numpy.random.seed(42)
|
||||||
torch.multiprocessing.set_start_method('spawn', force=True)
|
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
|
||||||
dataset = ImagesDataset("training_data")
|
dataset = ImagesDataset("training_data")
|
||||||
|
|
||||||
# dataset = torch.utils.data.Subset(dataset, range(0, 1024))
|
train_data, eval_data = split_data(dataset)
|
||||||
|
|
||||||
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
augmented_train_data = AImagesDataset(train_data, augment_data)
|
||||||
|
train_loader = AsyncDataLoader(augmented_train_data,
|
||||||
augmented_train_data = AImagesDataset(train_data, False)
|
batch_size=batch_size,
|
||||||
train_loader = torch.utils.data.DataLoader(augmented_train_data,
|
num_workers=3,
|
||||||
batch_size=batch_size,
|
pin_memory=True,
|
||||||
num_workers=3,
|
shuffle=True)
|
||||||
pin_memory=True,
|
eval_loader = AsyncDataLoader(eval_data,
|
||||||
shuffle=True)
|
batch_size=batch_size,
|
||||||
eval_loader = torch.utils.data.DataLoader(eval_data,
|
num_workers=3,
|
||||||
batch_size=batch_size,
|
pin_memory=True)
|
||||||
num_workers=3,
|
|
||||||
pin_memory=True)
|
|
||||||
|
|
||||||
for epoch in progress_epoch.range(num_epochs):
|
for epoch in progress_epoch.range(num_epochs):
|
||||||
|
|
||||||
|
|
@ -60,6 +78,7 @@ def train_model(accuracies,
|
||||||
|
|
||||||
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
||||||
in enumerate(progress_train_data.iter(train_loader)):
|
in enumerate(progress_train_data.iter(train_loader)):
|
||||||
|
|
||||||
image_t = image_t.to(device)
|
image_t = image_t.to(device)
|
||||||
class_ids = class_ids.to(device)
|
class_ids = class_ids.to(device)
|
||||||
|
|
||||||
|
|
@ -72,8 +91,7 @@ def train_model(accuracies,
|
||||||
|
|
||||||
train_loss += loss
|
train_loss += loss
|
||||||
|
|
||||||
outputs.flatten()
|
classes = outputs.argmax(dim=1)
|
||||||
classes = outputs.argmax()
|
|
||||||
train_positives += torch.sum(torch.eq(classes, class_ids))
|
train_positives += torch.sum(torch.eq(classes, class_ids))
|
||||||
|
|
||||||
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
accuracies.append('train_acc', train_positives.item() / len(augmented_train_data))
|
||||||
|
|
@ -93,8 +111,7 @@ def train_model(accuracies,
|
||||||
class_ids = class_ids.to(device)
|
class_ids = class_ids.to(device)
|
||||||
|
|
||||||
outputs = model(image_t)
|
outputs = model(image_t)
|
||||||
outputs.flatten()
|
classes = outputs.argmax(dim=1)
|
||||||
classes = outputs.argmax()
|
|
||||||
|
|
||||||
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
eval_positives += torch.sum(torch.eq(classes, class_ids))
|
||||||
eval_loss += loss_function(outputs, class_ids)
|
eval_loss += loss_function(outputs, class_ids)
|
||||||
|
|
@ -103,36 +120,30 @@ def train_model(accuracies,
|
||||||
losses.append('eval_loss', eval_loss.item() / len(eval_data))
|
losses.append('eval_loss', eval_loss.item() / len(eval_data))
|
||||||
print("Eval: ", eval_positives.item(), "/ ", len(eval_data), " = ", eval_positives.item() / len(eval_data))
|
print("Eval: ", eval_positives.item(), "/ ", len(eval_data), " = ", eval_positives.item() / len(eval_data))
|
||||||
|
|
||||||
# print epoch summary
|
if eval_positives.item() / len(eval_data) > 0.5:
|
||||||
|
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pth')
|
||||||
# print(f"Epoch: {epoch} --- Train loss: {train_loss:7.4f} --- Eval loss: {eval_loss:7.4f}")
|
|
||||||
|
|
||||||
if eval_positives.item() / len(eval_data) > 0.3:
|
|
||||||
torch.save(model.state_dict(), f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-epoch-{epoch}.pt')
|
|
||||||
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
with open(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv', 'a') as file:
|
||||||
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
file.write(f'{epoch};{len(augmented_train_data)};{len(eval_data)};{train_loss.item()};{eval_loss.item()};'
|
||||||
f'{train_positives};{eval_positives}\n')
|
f'{train_positives};{eval_positives}\n')
|
||||||
|
|
||||||
|
|
||||||
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss, start_time):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
raise RuntimeError("GPU not available")
|
raise RuntimeError("GPU not available")
|
||||||
|
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
|
|
||||||
model = MyCNN(input_channels=1,
|
model.to(device)
|
||||||
input_size=(100, 100)).to(device)
|
|
||||||
|
|
||||||
num_epochs = 1000000
|
num_epochs = 1000000
|
||||||
|
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
optimizer = torch.optim.Adam(model.parameters(),
|
optimizer = torch.optim.Adam(model.parameters(),
|
||||||
lr=0.00005,
|
lr=0.0001,
|
||||||
# weight_decay=0.1,
|
|
||||||
fused=True)
|
fused=True)
|
||||||
loss_function = torch.nn.CrossEntropyLoss()
|
loss_function = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
start_time = datetime.now()
|
augment_data = True
|
||||||
|
|
||||||
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
file_name = f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}.csv'
|
||||||
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
with open(file_name.replace(".csv", ".txt"), 'a') as file:
|
||||||
|
|
@ -140,6 +151,7 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
||||||
file.write(f"batch_size: {batch_size}\n")
|
file.write(f"batch_size: {batch_size}\n")
|
||||||
file.write(f"optimizer: {optimizer}\n")
|
file.write(f"optimizer: {optimizer}\n")
|
||||||
file.write(f"loss_function: {loss_function}\n")
|
file.write(f"loss_function: {loss_function}\n")
|
||||||
|
file.write(f"augment_data: {augment_data}\n")
|
||||||
file.write(f"model: {model}")
|
file.write(f"model: {model}")
|
||||||
|
|
||||||
train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
|
train_model(plotter_accuracies, plotter_loss, p_epoch, p_train, p_eval,
|
||||||
|
|
@ -148,5 +160,6 @@ def train_worker(p_epoch, p_train, p_eval, plotter_accuracies, plotter_loss):
|
||||||
batch_size,
|
batch_size,
|
||||||
optimizer,
|
optimizer,
|
||||||
loss_function,
|
loss_function,
|
||||||
|
augment_data,
|
||||||
device,
|
device,
|
||||||
start_time)
|
start_time)
|
||||||
|
|
|
||||||
17
main.py
|
|
@ -2,6 +2,7 @@ import math
|
||||||
import threading
|
import threading
|
||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
import tkinter.ttk as ttk
|
import tkinter.ttk as ttk
|
||||||
|
import datetime
|
||||||
|
|
||||||
from AsyncProgress import AsyncProgress
|
from AsyncProgress import AsyncProgress
|
||||||
from Plotter import Plotter
|
from Plotter import Plotter
|
||||||
|
|
@ -21,6 +22,16 @@ def cancel_train(plot):
|
||||||
return cancel_func
|
return cancel_func
|
||||||
|
|
||||||
|
|
||||||
|
def on_exit(root, plotter_acc, plotter_loss, start_time):
|
||||||
|
def func():
|
||||||
|
plotter_acc.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-acc.jpg')
|
||||||
|
plotter_loss.save(f'models/model-{start_time.strftime("%Y%m%d-%H%M%S")}-loss.jpg')
|
||||||
|
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
root = tk.Tk()
|
root = tk.Tk()
|
||||||
|
|
@ -41,9 +52,13 @@ if __name__ == '__main__':
|
||||||
plotter_loss = Plotter(root)
|
plotter_loss = Plotter(root)
|
||||||
plotter_loss.grid(row=3, column=1)
|
plotter_loss.grid(row=3, column=1)
|
||||||
|
|
||||||
|
start_time = datetime.datetime.now()
|
||||||
|
|
||||||
root.after(0, lambda: threading.Thread(target=cnn_train.train_worker,
|
root.after(0, lambda: threading.Thread(target=cnn_train.train_worker,
|
||||||
args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss),
|
args=(pbar_epoch, pbar_train, pbar_eval, plotter_acc, plotter_loss, start_time),
|
||||||
daemon=True).start())
|
daemon=True).start())
|
||||||
|
|
||||||
|
root.protocol("WM_DELETE_WINDOW", on_exit(root, plotter_acc, plotter_loss, start_time))
|
||||||
|
|
||||||
root.focus_set()
|
root.focus_set()
|
||||||
root.mainloop()
|
root.mainloop()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn
|
||||||
|
|
||||||
|
|
||||||
|
class MyCNN(torch.nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
input_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(32),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(32),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
|
torch.nn.Dropout2d(0.1),
|
||||||
|
torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(64),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(64),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
|
torch.nn.Dropout2d(0.1),
|
||||||
|
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=False),
|
||||||
|
torch.nn.BatchNorm2d(128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||||
|
|
||||||
|
torch.nn.Flatten(),
|
||||||
|
torch.nn.Dropout(0.25),
|
||||||
|
torch.nn.Linear(in_features=2048, out_features=1024),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=1024, out_features=512),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=512, out_features=20, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_images: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.layers(input_images)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.layers)
|
||||||
|
|
||||||
|
|
||||||
|
model = MyCNN(input_channels=1)
|
||||||
|
After Width: | Height: | Size: 27 KiB |
|
After Width: | Height: | Size: 33 KiB |
|
|
@ -0,0 +1,277 @@
|
||||||
|
0;43722;6237;1514.7474365234375;175.6603546142578;12901;2606
|
||||||
|
1;43722;6237;1135.5784912109375;150.25299072265625;20202;3139
|
||||||
|
2;43722;6237;960.5227661132812;137.34268188476562;23839;3421
|
||||||
|
3;43722;6237;830.269287109375;125.98084259033203;26503;3719
|
||||||
|
4;43722;6237;731.7471313476562;121.24964141845703;28628;3849
|
||||||
|
5;43722;6237;645.378662109375;120.33052062988281;30565;3896
|
||||||
|
6;43722;6237;582.5308837890625;120.5954360961914;31863;3938
|
||||||
|
7;43722;6237;523.0032958984375;115.88058471679688;33160;4000
|
||||||
|
8;43722;6237;482.0741882324219;118.56775665283203;34019;4070
|
||||||
|
9;43722;6237;445.5543212890625;120.63733673095703;34848;4090
|
||||||
|
10;43722;6237;414.6208190917969;119.50141906738281;35441;4151
|
||||||
|
11;43722;6237;386.66143798828125;124.22693634033203;35971;4111
|
||||||
|
12;43722;6237;363.51251220703125;120.64466094970703;36512;4108
|
||||||
|
13;43722;6237;346.09075927734375;116.75762176513672;36813;4234
|
||||||
|
14;43722;6237;329.0418395996094;120.47087097167969;37123;4178
|
||||||
|
15;43722;6237;312.50323486328125;117.59949493408203;37509;4235
|
||||||
|
16;43722;6237;298.0853576660156;121.49691772460938;37694;4211
|
||||||
|
17;43722;6237;285.36566162109375;120.2123794555664;38003;4254
|
||||||
|
18;43722;6237;279.4658508300781;122.16336059570312;38094;4243
|
||||||
|
19;43722;6237;266.1520080566406;118.454345703125;38375;4298
|
||||||
|
20;43722;6237;253.64671325683594;122.3826675415039;38628;4263
|
||||||
|
21;43722;6237;240.9693145751953;128.91136169433594;38880;4230
|
||||||
|
22;43722;6237;241.7603302001953;126.29266357421875;38913;4278
|
||||||
|
23;43722;6237;232.66427612304688;117.6049575805664;38959;4328
|
||||||
|
24;43722;6237;221.97898864746094;119.4714584350586;39250;4275
|
||||||
|
25;43722;6237;210.0189971923828;120.75641632080078;39432;4328
|
||||||
|
26;43722;6237;207.76263427734375;134.53555297851562;39562;4182
|
||||||
|
27;43722;6237;202.59486389160156;131.0316925048828;39611;4238
|
||||||
|
28;43722;6237;200.13861083984375;135.6504669189453;39686;4254
|
||||||
|
29;43722;6237;197.95249938964844;126.32669067382812;39693;4305
|
||||||
|
30;43722;6237;191.49545288085938;125.14106750488281;39804;4309
|
||||||
|
31;43722;6237;186.21414184570312;127.87908935546875;39975;4293
|
||||||
|
32;43722;6237;181.6565704345703;131.55189514160156;40026;4333
|
||||||
|
33;43722;6237;172.87303161621094;131.88626098632812;40152;4353
|
||||||
|
34;43722;6237;172.97467041015625;124.51634216308594;40188;4344
|
||||||
|
35;43722;6237;169.8540802001953;126.38551330566406;40317;4372
|
||||||
|
36;43722;6237;164.03897094726562;133.2968292236328;40394;4313
|
||||||
|
37;43722;6237;157.93531799316406;131.04075622558594;40475;4372
|
||||||
|
38;43722;6237;158.68841552734375;133.65902709960938;40459;4322
|
||||||
|
39;43722;6237;154.88677978515625;133.8780975341797;40551;4339
|
||||||
|
40;43722;6237;151.96273803710938;129.21278381347656;40606;4403
|
||||||
|
41;43722;6237;147.31663513183594;140.08470153808594;40687;4350
|
||||||
|
42;43722;6237;142.7169647216797;132.55511474609375;40775;4370
|
||||||
|
43;43722;6237;140.73133850097656;135.10562133789062;40815;4354
|
||||||
|
44;43722;6237;140.28749084472656;132.77880859375;40849;4404
|
||||||
|
45;43722;6237;136.1907196044922;134.705322265625;40913;4361
|
||||||
|
46;43722;6237;138.6764678955078;134.1383056640625;40865;4377
|
||||||
|
47;43722;6237;131.3759307861328;143.128173828125;41000;4352
|
||||||
|
48;43722;6237;128.32232666015625;143.2043914794922;41048;4361
|
||||||
|
49;43722;6237;126.78913116455078;142.99790954589844;41080;4359
|
||||||
|
50;43722;6237;122.46489715576172;138.79641723632812;41203;4370
|
||||||
|
51;43722;6237;123.42253875732422;142.66725158691406;41162;4354
|
||||||
|
52;43722;6237;124.3193359375;138.46434020996094;41149;4393
|
||||||
|
53;43722;6237;115.27859497070312;138.3941650390625;41320;4438
|
||||||
|
54;43722;6237;118.68241882324219;138.24319458007812;41285;4444
|
||||||
|
55;43722;6237;116.2236557006836;137.2938995361328;41290;4418
|
||||||
|
56;43722;6237;114.83257293701172;140.25350952148438;41309;4411
|
||||||
|
57;43722;6237;111.2252426147461;138.39280700683594;41380;4425
|
||||||
|
58;43722;6237;111.19110870361328;136.1656951904297;41407;4422
|
||||||
|
59;43722;6237;111.5517578125;140.3943328857422;41370;4425
|
||||||
|
60;43722;6237;106.5470962524414;138.77578735351562;41507;4445
|
||||||
|
61;43722;6237;103.24518585205078;144.9277801513672;41564;4413
|
||||||
|
62;43722;6237;103.84107971191406;148.6267547607422;41519;4418
|
||||||
|
63;43722;6237;102.4736328125;140.71527099609375;41555;4423
|
||||||
|
64;43722;6237;98.89949035644531;143.99327087402344;41618;4447
|
||||||
|
65;43722;6237;98.72052001953125;149.64785766601562;41664;4363
|
||||||
|
66;43722;6237;98.64189910888672;145.81039428710938;41648;4407
|
||||||
|
67;43722;6237;95.4326400756836;150.38888549804688;41735;4371
|
||||||
|
68;43722;6237;96.24220275878906;147.1780548095703;41702;4424
|
||||||
|
69;43722;6237;91.39864349365234;151.0804443359375;41819;4422
|
||||||
|
70;43722;6237;93.90299224853516;147.62901306152344;41742;4424
|
||||||
|
71;43722;6237;90.43379974365234;152.78158569335938;41809;4418
|
||||||
|
72;43722;6237;90.1348648071289;146.18040466308594;41817;4413
|
||||||
|
73;43722;6237;87.73262023925781;150.17283630371094;41877;4426
|
||||||
|
74;43722;6237;87.60153198242188;151.4879913330078;41858;4414
|
||||||
|
75;43722;6237;85.25221252441406;148.32968139648438;41908;4485
|
||||||
|
76;43722;6237;80.42049407958984;150.0015106201172;42057;4423
|
||||||
|
77;43722;6237;82.6187973022461;149.63787841796875;41972;4433
|
||||||
|
78;43722;6237;80.78215789794922;151.9192657470703;41974;4393
|
||||||
|
79;43722;6237;78.09911346435547;155.3000030517578;42050;4448
|
||||||
|
80;43722;6237;79.99258422851562;152.6612548828125;42050;4442
|
||||||
|
81;43722;6237;77.55635833740234;154.47512817382812;42093;4428
|
||||||
|
82;43722;6237;77.18683624267578;151.2737579345703;42081;4411
|
||||||
|
83;43722;6237;77.42937469482422;150.166748046875;42088;4464
|
||||||
|
84;43722;6237;76.4849853515625;152.59466552734375;42064;4437
|
||||||
|
85;43722;6237;73.77530670166016;167.58355712890625;42135;4358
|
||||||
|
86;43722;6237;73.78494262695312;154.68739318847656;42147;4454
|
||||||
|
87;43722;6237;72.63799285888672;152.5028533935547;42225;4454
|
||||||
|
88;43722;6237;72.04568481445312;151.28076171875;42140;4445
|
||||||
|
89;43722;6237;69.72296142578125;157.5628662109375;42222;4429
|
||||||
|
90;43722;6237;69.09281158447266;159.32858276367188;42241;4453
|
||||||
|
91;43722;6237;69.55367279052734;164.70516967773438;42218;4431
|
||||||
|
92;43722;6237;68.90850067138672;158.1874542236328;42239;4463
|
||||||
|
93;43722;6237;67.16905975341797;155.82281494140625;42272;4455
|
||||||
|
94;43722;6237;62.867496490478516;160.17471313476562;42366;4442
|
||||||
|
95;43722;6237;66.03118896484375;164.81056213378906;42309;4418
|
||||||
|
96;43722;6237;66.00711059570312;151.60165405273438;42337;4431
|
||||||
|
97;43722;6237;61.56562805175781;164.44223022460938;42436;4412
|
||||||
|
98;43722;6237;64.58049011230469;160.0134735107422;42339;4427
|
||||||
|
99;43722;6237;61.30560302734375;158.8366241455078;42419;4443
|
||||||
|
100;43722;6237;63.23543167114258;156.88323974609375;42373;4421
|
||||||
|
101;43722;6237;62.51991271972656;159.0310516357422;42389;4438
|
||||||
|
102;43722;6237;60.01646423339844;158.1140899658203;42421;4417
|
||||||
|
103;43722;6237;60.08512496948242;161.7741241455078;42438;4431
|
||||||
|
104;43722;6237;57.94487762451172;160.0200653076172;42489;4446
|
||||||
|
105;43722;6237;58.03913879394531;165.98825073242188;42441;4427
|
||||||
|
106;43722;6237;56.5332145690918;173.01539611816406;42518;4398
|
||||||
|
107;43722;6237;57.06502914428711;167.72816467285156;42453;4399
|
||||||
|
108;43722;6237;54.27724838256836;162.75970458984375;42560;4494
|
||||||
|
109;43722;6237;55.83955764770508;159.5300750732422;42503;4462
|
||||||
|
110;43722;6237;52.22243881225586;172.03147888183594;42589;4406
|
||||||
|
111;43722;6237;55.676963806152344;161.65074157714844;42531;4451
|
||||||
|
112;43722;6237;54.30720520019531;161.99415588378906;42543;4426
|
||||||
|
113;43722;6237;52.3731575012207;167.92483520507812;42589;4406
|
||||||
|
114;43722;6237;53.243736267089844;162.59381103515625;42554;4427
|
||||||
|
115;43722;6237;53.11757278442383;159.04811096191406;42577;4455
|
||||||
|
116;43722;6237;49.79232406616211;164.5228271484375;42623;4434
|
||||||
|
117;43722;6237;48.79463195800781;178.80307006835938;42639;4441
|
||||||
|
118;43722;6237;48.74504470825195;169.4583282470703;42663;4458
|
||||||
|
119;43722;6237;48.75444412231445;169.40577697753906;42665;4392
|
||||||
|
120;43722;6237;49.253173828125;167.4823455810547;42667;4379
|
||||||
|
121;43722;6237;48.43288040161133;166.68701171875;42691;4427
|
||||||
|
122;43722;6237;47.13943862915039;168.96035766601562;42718;4449
|
||||||
|
123;43722;6237;47.78165817260742;170.65972900390625;42712;4432
|
||||||
|
124;43722;6237;46.490360260009766;171.7513427734375;42705;4441
|
||||||
|
125;43722;6237;45.12886428833008;176.67503356933594;42768;4435
|
||||||
|
126;43722;6237;45.8767204284668;175.19534301757812;42709;4444
|
||||||
|
127;43722;6237;46.022666931152344;172.56321716308594;42726;4463
|
||||||
|
128;43722;6237;43.341373443603516;174.56529235839844;42789;4383
|
||||||
|
129;43722;6237;43.37924575805664;175.2935791015625;42744;4422
|
||||||
|
130;43722;6237;43.61152648925781;176.03799438476562;42819;4455
|
||||||
|
131;43722;6237;45.63074493408203;171.19691467285156;42724;4444
|
||||||
|
132;43722;6237;45.26616668701172;172.88404846191406;42713;4480
|
||||||
|
133;43722;6237;44.142173767089844;177.1351318359375;42761;4435
|
||||||
|
134;43722;6237;42.29781723022461;173.97828674316406;42837;4416
|
||||||
|
135;43722;6237;43.62605667114258;171.77389526367188;42752;4395
|
||||||
|
136;43722;6237;42.02931213378906;180.46517944335938;42778;4447
|
||||||
|
137;43722;6237;41.774253845214844;182.2983856201172;42818;4343
|
||||||
|
138;43722;6237;41.40755081176758;176.11892700195312;42811;4470
|
||||||
|
139;43722;6237;41.578887939453125;177.47410583496094;42799;4424
|
||||||
|
140;43722;6237;41.19217300415039;180.14012145996094;42820;4430
|
||||||
|
141;43722;6237;40.049095153808594;179.26475524902344;42845;4401
|
||||||
|
142;43722;6237;40.21591567993164;186.4405975341797;42866;4399
|
||||||
|
143;43722;6237;39.01491165161133;184.1543731689453;42883;4455
|
||||||
|
144;43722;6237;38.47568893432617;181.78121948242188;42909;4394
|
||||||
|
145;43722;6237;37.514320373535156;181.0046844482422;42900;4432
|
||||||
|
146;43722;6237;38.65336608886719;176.87356567382812;42894;4441
|
||||||
|
147;43722;6237;38.86013412475586;180.0382843017578;42871;4430
|
||||||
|
148;43722;6237;37.66798400878906;181.6970672607422;42880;4415
|
||||||
|
149;43722;6237;37.83094787597656;190.0052947998047;42918;4385
|
||||||
|
150;43722;6237;34.7658805847168;180.7398223876953;42984;4404
|
||||||
|
151;43722;6237;36.68027877807617;179.19505310058594;42926;4414
|
||||||
|
152;43722;6237;38.391807556152344;178.177734375;42881;4434
|
||||||
|
153;43722;6237;35.51554489135742;178.7106170654297;42973;4430
|
||||||
|
154;43722;6237;35.32855987548828;187.07196044921875;42948;4448
|
||||||
|
155;43722;6237;36.418392181396484;189.13511657714844;42894;4384
|
||||||
|
156;43722;6237;33.90447235107422;182.01612854003906;43016;4450
|
||||||
|
157;43722;6237;35.53436279296875;181.25128173828125;42967;4448
|
||||||
|
158;43722;6237;33.19279479980469;180.28590393066406;43002;4476
|
||||||
|
159;43722;6237;33.298118591308594;187.7205047607422;42973;4440
|
||||||
|
160;43722;6237;34.43337631225586;184.99697875976562;42955;4437
|
||||||
|
161;43722;6237;34.259647369384766;192.1217803955078;42954;4408
|
||||||
|
162;43722;6237;34.26215362548828;185.66143798828125;42980;4444
|
||||||
|
163;43722;6237;33.019798278808594;190.13560485839844;42994;4438
|
||||||
|
164;43722;6237;32.04604721069336;181.56907653808594;43022;4431
|
||||||
|
165;43722;6237;30.804288864135742;195.40943908691406;43056;4410
|
||||||
|
166;43722;6237;32.44491195678711;182.85813903808594;43054;4449
|
||||||
|
167;43722;6237;30.949054718017578;186.40872192382812;43037;4424
|
||||||
|
168;43722;6237;34.90159225463867;190.98541259765625;42945;4398
|
||||||
|
169;43722;6237;33.85884475708008;190.09579467773438;42990;4376
|
||||||
|
170;43722;6237;31.5644474029541;188.6020965576172;42997;4433
|
||||||
|
171;43722;6237;31.076412200927734;192.26309204101562;43044;4458
|
||||||
|
172;43722;6237;30.08343505859375;183.1046905517578;43077;4460
|
||||||
|
173;43722;6237;30.289939880371094;185.90887451171875;43061;4470
|
||||||
|
174;43722;6237;31.39510154724121;186.18966674804688;43041;4456
|
||||||
|
175;43722;6237;31.1684513092041;189.1432647705078;43040;4449
|
||||||
|
176;43722;6237;28.690387725830078;192.99476623535156;43079;4428
|
||||||
|
177;43722;6237;30.27741813659668;185.2362823486328;43069;4435
|
||||||
|
178;43722;6237;28.554563522338867;187.9214630126953;43096;4412
|
||||||
|
179;43722;6237;30.413394927978516;192.8180389404297;43067;4418
|
||||||
|
180;43722;6237;30.49958038330078;188.61312866210938;43055;4424
|
||||||
|
181;43722;6237;28.986902236938477;198.4716339111328;43091;4414
|
||||||
|
182;43722;6237;28.59244155883789;192.9297637939453;43096;4413
|
||||||
|
183;43722;6237;30.493051528930664;196.74461364746094;43060;4408
|
||||||
|
184;43722;6237;30.890644073486328;189.09024047851562;43074;4424
|
||||||
|
185;43722;6237;29.030553817749023;193.50816345214844;43099;4384
|
||||||
|
186;43722;6237;26.469297409057617;195.7086181640625;43159;4391
|
||||||
|
187;43722;6237;29.260929107666016;195.16818237304688;43069;4409
|
||||||
|
188;43722;6237;29.081787109375;192.8885498046875;43119;4452
|
||||||
|
189;43722;6237;26.148517608642578;193.26358032226562;43183;4390
|
||||||
|
190;43722;6237;29.156747817993164;191.98316955566406;43119;4460
|
||||||
|
191;43722;6237;26.234939575195312;200.08349609375;43179;4433
|
||||||
|
192;43722;6237;28.03680419921875;195.2965545654297;43119;4387
|
||||||
|
193;43722;6237;28.13884162902832;189.22703552246094;43150;4443
|
||||||
|
194;43722;6237;27.855710983276367;186.77207946777344;43138;4461
|
||||||
|
195;43722;6237;25.09186363220215;199.45249938964844;43177;4448
|
||||||
|
196;43722;6237;26.17472267150879;196.16748046875;43196;4426
|
||||||
|
197;43722;6237;28.630319595336914;189.7792205810547;43087;4438
|
||||||
|
198;43722;6237;24.919069290161133;196.20501708984375;43167;4458
|
||||||
|
199;43722;6237;25.823429107666016;193.51722717285156;43184;4447
|
||||||
|
200;43722;6237;26.948225021362305;190.61703491210938;43156;4430
|
||||||
|
201;43722;6237;24.1967830657959;195.03314208984375;43207;4459
|
||||||
|
202;43722;6237;23.94419288635254;196.92050170898438;43209;4439
|
||||||
|
203;43722;6237;25.682336807250977;196.66305541992188;43183;4424
|
||||||
|
204;43722;6237;25.062114715576172;197.73385620117188;43195;4453
|
||||||
|
205;43722;6237;25.268651962280273;196.42738342285156;43173;4435
|
||||||
|
206;43722;6237;24.5706844329834;196.53646850585938;43193;4450
|
||||||
|
207;43722;6237;25.190696716308594;197.8087921142578;43199;4438
|
||||||
|
208;43722;6237;25.013042449951172;199.36846923828125;43175;4435
|
||||||
|
209;43722;6237;22.459022521972656;202.0741424560547;43234;4438
|
||||||
|
210;43722;6237;24.379507064819336;194.94219970703125;43209;4427
|
||||||
|
211;43722;6237;25.14781379699707;194.5927276611328;43198;4450
|
||||||
|
212;43722;6237;22.717470169067383;199.97332763671875;43238;4416
|
||||||
|
213;43722;6237;23.583417892456055;191.59429931640625;43217;4443
|
||||||
|
214;43722;6237;23.64990234375;200.4409942626953;43197;4426
|
||||||
|
215;43722;6237;24.00716781616211;207.8960418701172;43202;4381
|
||||||
|
216;43722;6237;23.25735855102539;194.6672821044922;43201;4433
|
||||||
|
217;43722;6237;23.340778350830078;201.62063598632812;43218;4414
|
||||||
|
218;43722;6237;23.24645233154297;203.1522674560547;43202;4410
|
||||||
|
219;43722;6237;22.242794036865234;196.96304321289062;43240;4456
|
||||||
|
220;43722;6237;22.977066040039062;203.4141387939453;43228;4403
|
||||||
|
221;43722;6237;24.17426872253418;203.6529998779297;43203;4363
|
||||||
|
222;43722;6237;21.43939971923828;198.41204833984375;43258;4440
|
||||||
|
223;43722;6237;23.04403305053711;208.3946075439453;43240;4414
|
||||||
|
224;43722;6237;23.194822311401367;202.41075134277344;43212;4383
|
||||||
|
225;43722;6237;22.339262008666992;205.32154846191406;43259;4430
|
||||||
|
226;43722;6237;22.946557998657227;198.68875122070312;43230;4447
|
||||||
|
227;43722;6237;22.933635711669922;198.5775604248047;43214;4431
|
||||||
|
228;43722;6237;21.506454467773438;198.9333038330078;43283;4447
|
||||||
|
229;43722;6237;21.730619430541992;208.0950164794922;43244;4412
|
||||||
|
230;43722;6237;22.364229202270508;207.3225555419922;43251;4393
|
||||||
|
231;43722;6237;22.753942489624023;196.91549682617188;43231;4447
|
||||||
|
232;43722;6237;21.35562515258789;201.66734313964844;43257;4426
|
||||||
|
233;43722;6237;20.41690444946289;206.2218475341797;43276;4453
|
||||||
|
234;43722;6237;20.87632179260254;211.685302734375;43265;4390
|
||||||
|
235;43722;6237;22.52677345275879;206.89291381835938;43230;4415
|
||||||
|
236;43722;6237;21.175365447998047;212.48947143554688;43280;4402
|
||||||
|
237;43722;6237;20.57710075378418;207.79122924804688;43294;4408
|
||||||
|
238;43722;6237;19.610031127929688;210.72604370117188;43298;4424
|
||||||
|
239;43722;6237;21.520944595336914;205.5826873779297;43265;4415
|
||||||
|
240;43722;6237;20.306123733520508;206.9119873046875;43294;4392
|
||||||
|
241;43722;6237;21.047883987426758;209.90646362304688;43277;4427
|
||||||
|
242;43722;6237;21.254474639892578;213.11070251464844;43262;4387
|
||||||
|
243;43722;6237;21.05403709411621;206.6425018310547;43269;4431
|
||||||
|
244;43722;6237;19.449607849121094;208.35243225097656;43305;4427
|
||||||
|
245;43722;6237;19.96088981628418;213.4718780517578;43312;4406
|
||||||
|
246;43722;6237;19.126388549804688;218.3990936279297;43325;4400
|
||||||
|
247;43722;6237;19.717483520507812;210.0780029296875;43305;4406
|
||||||
|
248;43722;6237;20.718196868896484;209.91795349121094;43271;4401
|
||||||
|
249;43722;6237;17.500123977661133;208.0767364501953;43359;4416
|
||||||
|
250;43722;6237;20.81154441833496;208.70692443847656;43284;4407
|
||||||
|
251;43722;6237;20.464780807495117;213.62667846679688;43270;4413
|
||||||
|
252;43722;6237;19.197351455688477;212.01211547851562;43289;4413
|
||||||
|
253;43722;6237;19.951648712158203;211.954833984375;43317;4390
|
||||||
|
254;43722;6237;18.864404678344727;219.70535278320312;43283;4395
|
||||||
|
255;43722;6237;20.94219398498535;202.2901153564453;43281;4431
|
||||||
|
256;43722;6237;19.4694766998291;212.18911743164062;43320;4397
|
||||||
|
257;43722;6237;19.134057998657227;212.84266662597656;43301;4371
|
||||||
|
258;43722;6237;17.84954261779785;221.81382751464844;43342;4394
|
||||||
|
259;43722;6237;18.62063217163086;215.65443420410156;43302;4397
|
||||||
|
260;43722;6237;19.49367332458496;215.1182403564453;43293;4376
|
||||||
|
261;43722;6237;19.67314910888672;215.09274291992188;43289;4354
|
||||||
|
262;43722;6237;18.555265426635742;220.30230712890625;43321;4381
|
||||||
|
263;43722;6237;18.139114379882812;213.9473419189453;43349;4389
|
||||||
|
264;43722;6237;17.78927993774414;213.45562744140625;43334;4346
|
||||||
|
265;43722;6237;17.6031551361084;215.17929077148438;43347;4394
|
||||||
|
266;43722;6237;17.88275718688965;220.81520080566406;43330;4393
|
||||||
|
267;43722;6237;17.69892692565918;215.7049102783203;43348;4431
|
||||||
|
268;43722;6237;18.19742202758789;220.33604431152344;43330;4386
|
||||||
|
269;43722;6237;18.801044464111328;208.64773559570312;43325;4416
|
||||||
|
270;43722;6237;18.08159828186035;217.66757202148438;43328;4417
|
||||||
|
271;43722;6237;17.587453842163086;214.5685577392578;43327;4363
|
||||||
|
272;43722;6237;17.779569625854492;219.0640411376953;43345;4394
|
||||||
|
273;43722;6237;17.884660720825195;216.69700622558594;43343;4425
|
||||||
|
274;43722;6237;16.699108123779297;217.4732666015625;43362;4425
|
||||||
|
275;43722;6237;17.471420288085938;218.8936004638672;43341;4400
|
||||||
|
276;43722;6237;17.408723831176758;215.26779174804688;43335;4420
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
device: cuda
|
||||||
|
batch_size: 64
|
||||||
|
optimizer: Adam (
|
||||||
|
Parameter Group 0
|
||||||
|
amsgrad: False
|
||||||
|
betas: (0.9, 0.999)
|
||||||
|
capturable: False
|
||||||
|
differentiable: False
|
||||||
|
eps: 1e-08
|
||||||
|
foreach: None
|
||||||
|
fused: True
|
||||||
|
lr: 0.0001
|
||||||
|
maximize: False
|
||||||
|
weight_decay: 0
|
||||||
|
)
|
||||||
|
loss_function: CrossEntropyLoss()
|
||||||
|
augment_data: True
|
||||||
|
model: Sequential(
|
||||||
|
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(2): ReLU()
|
||||||
|
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(5): ReLU()
|
||||||
|
(6): MaxPool2d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
|
||||||
|
(7): Dropout2d(p=0.1, inplace=False)
|
||||||
|
(8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(10): ReLU()
|
||||||
|
(11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(13): ReLU()
|
||||||
|
(14): MaxPool2d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
|
||||||
|
(15): Dropout2d(p=0.1, inplace=False)
|
||||||
|
(16): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(17): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(18): ReLU()
|
||||||
|
(19): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
|
||||||
|
(20): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
|
||||||
|
(21): ReLU()
|
||||||
|
(22): MaxPool2d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
|
||||||
|
(23): Flatten(start_dim=1, end_dim=-1)
|
||||||
|
(24): Dropout(p=0.25, inplace=False)
|
||||||
|
(25): Linear(in_features=2048, out_features=1024, bias=True)
|
||||||
|
(26): ReLU()
|
||||||
|
(27): Linear(in_features=1024, out_features=512, bias=True)
|
||||||
|
(28): ReLU()
|
||||||
|
(29): Linear(in_features=512, out_features=20, bias=False)
|
||||||
|
)
|
||||||
|
Before Width: | Height: | Size: 8.7 KiB |
|
Before Width: | Height: | Size: 6.2 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 2.7 KiB |
|
Before Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 4.2 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 1.9 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 1.5 KiB |
|
Before Width: | Height: | Size: 1.5 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 6.1 KiB |
|
Before Width: | Height: | Size: 5.3 KiB |
|
Before Width: | Height: | Size: 5.6 KiB |
|
Before Width: | Height: | Size: 5.5 KiB |
|
Before Width: | Height: | Size: 5.3 KiB |
|
Before Width: | Height: | Size: 4.0 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 2.3 KiB |
|
Before Width: | Height: | Size: 2.5 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 5.5 KiB |
|
Before Width: | Height: | Size: 5.2 KiB |
|
Before Width: | Height: | Size: 5.4 KiB |
|
Before Width: | Height: | Size: 5.0 KiB |
|
Before Width: | Height: | Size: 2.4 KiB |
|
Before Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 2.7 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 5.1 KiB |
|
Before Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 3.4 KiB |
|
Before Width: | Height: | Size: 3.6 KiB |
|
Before Width: | Height: | Size: 4.0 KiB |
|
Before Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 3.3 KiB |
|
Before Width: | Height: | Size: 6.3 KiB |
|
Before Width: | Height: | Size: 4.4 KiB |
|
Before Width: | Height: | Size: 6.0 KiB |
|
Before Width: | Height: | Size: 4.6 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 5.1 KiB |
|
Before Width: | Height: | Size: 6.4 KiB |
|
Before Width: | Height: | Size: 5.4 KiB |
|
Before Width: | Height: | Size: 3.3 KiB |
|
Before Width: | Height: | Size: 2.5 KiB |
|
Before Width: | Height: | Size: 6.2 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 2.5 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 2.4 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 1.5 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 1.4 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.5 KiB |
|
Before Width: | Height: | Size: 4.1 KiB |
|
Before Width: | Height: | Size: 5.2 KiB |
|
Before Width: | Height: | Size: 5.5 KiB |
|
Before Width: | Height: | Size: 4.6 KiB |
|
Before Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 3.3 KiB |
|
Before Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 3.1 KiB |
|
Before Width: | Height: | Size: 3.0 KiB |