diff --git a/AsyncDataLoader.py b/AsyncDataLoader.py
new file mode 100644
index 0000000..3bdb07f
--- /dev/null
+++ b/AsyncDataLoader.py
@@ -0,0 +1,91 @@
+import queue
+import threading
+from typing import Optional
+
+import torch.utils.data
+
+
+class AsyncDataLoader(torch.utils.data.DataLoader):
+    """DataLoader that prefetches batches on a background daemon thread."""
+
+    def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
+        super().__init__(data, **kwargs)
+
+        # len() of a DataLoader is the number of batches per epoch.
+        self.data_len = len(self)
+
+        # Maximum number of batches buffered ahead of the consumer.
+        self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len
+        self.__data_queue = queue.Queue()
+        self.__dataset_access = threading.Lock()
+        self.__dataset_access_request = threading.Event()
+        self.__dataset_access_request_release = threading.Event()
+        self.__data_extract_event = threading.Event()
+
+        self._worker_thread = threading.Thread(target=self.__async_loader, daemon=True)
+        self._worker_thread.start()
+
+    def __iter__(self):
+        # Pop exactly one epoch's worth of batches; signal the worker on
+        # every extraction so it can refill the buffer.
+        for _ in range(self.data_len):
+            _, data = self.__data_queue.get()
+            self.__data_extract_event.set()
+            yield data
+
+    def __async_loader(self):
+        # Worker loop: one pass over the base DataLoader per epoch, buffering
+        # batches and throttling once prefetch_size batches are queued.
+        # (Plain instance method so the zero-argument super() below is
+        # well-defined; inside a @staticmethod it only works by accident.)
+        self.__dataset_access.acquire(blocking=True)
+
+        while True:
+
+            for enumerated_data in enumerate(super().__iter__()):
+                self.__data_queue.put(enumerated_data)
+
+                while self.__data_queue.qsize() >= self.prefetch_size:
+                    self.__dataset_access.release()
+
+                    # Sleep until the consumer extracts a batch; the 10 s
+                    # timeout guards against a lost wake-up.
+                    self.__data_extract_event.clear()
+                    self.__data_extract_event.wait(10)
+
+                    self.__dataset_access.acquire()
+
+                    # Hand the dataset lock over to a requester, if any.
+                    if self.__dataset_access_request.is_set():
+                        self.__dataset_access.release()
+                        self.__dataset_access_request_release.wait()
+
+                        self.__dataset_access.acquire()
+
+
+if __name__ == '__main__':
+    from torch.utils.data import TensorDataset, random_split
+    import torch.random
+    import time
+
+    torch.random.manual_seed(0)
+    num_train_examples = 10
+    training_data = torch.rand(num_train_examples, 32)
+    random_function = torch.rand(32)
+    sums = torch.sum(training_data * random_function, dim=1) ** 2
+    targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
+    all_data = TensorDataset(training_data, targets)
+    train_data, eval_data = random_split(all_data, [0.5, 0.5])
+
+    train_loader = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
+    eval_loader = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)
+
+    for e in range(3):
+        print('e', e)
+        for i, _ in enumerate(train_loader):
+            print("train", i)
+            time.sleep(1)
+
+        for i, _ in enumerate(eval_loader):
+            print('eval', i)
+            time.sleep(1)
diff --git a/cnn_train.py b/cnn_train.py
index 49268cc..ed95677 100644
--- a/cnn_train.py
+++ b/cnn_train.py
@@ -8,6 +8,7 @@
 from architecture import MyCNN
 from dataset import ImagesDataset
 from AImageDataset import AImagesDataset
+from AsyncDataLoader import AsyncDataLoader
 
 
 def train_model(accuracies,
@@ -22,7 +23,6 @@ def train_model(accuracies,
                 loss_function,
                 device,
                 start_time):
-
     torch.random.manual_seed(42)
     numpy.random.seed(42)
     torch.multiprocessing.set_start_method('spawn', force=True)
@@ -33,16 +33,16 @@ def train_model(accuracies,
 
     train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
 
-    augmented_train_data = AImagesDataset(train_data, False)
-    train_loader = torch.utils.data.DataLoader(augmented_train_data,
-                                               batch_size=batch_size,
-                                               num_workers=3,
-                                               pin_memory=True,
-                                               shuffle=True)
-    eval_loader = torch.utils.data.DataLoader(eval_data,
-                                              batch_size=batch_size,
-                                              num_workers=3,
-                                              pin_memory=True)
+    augmented_train_data = AImagesDataset(train_data, True)
+    train_loader = AsyncDataLoader(augmented_train_data,
+                                   batch_size=batch_size,
+                                   num_workers=3,
+                                   pin_memory=True,
+                                   shuffle=True)
+    eval_loader = AsyncDataLoader(eval_data,
+                                  batch_size=batch_size,
+                                  num_workers=3,
+                                  pin_memory=True)
 
 
     for epoch in progress_epoch.range(num_epochs):
@@ -60,6 +60,7 @@ def train_model(accuracies,
 
         for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
                 in enumerate(progress_train_data.iter(train_loader)):
+            image_t = image_t.to(device)
             class_ids = class_ids.to(device)