"""DataLoader variant that prefetches batches on a background thread."""

import queue
import threading
from typing import Any, Iterator, Optional

import torch.utils.data

# Queue marker telling the consumer that the prefetch thread has died. The
# exception itself is stored on the loader instance before the marker is
# enqueued.
_WORKER_FAILED = object()


class AsyncDataLoader(torch.utils.data.DataLoader):
    """A ``DataLoader`` whose batches are produced by a daemon thread.

    The worker thread is started by the constructor and iterates the parent
    ``DataLoader`` over and over, pushing batches into an internal queue. It
    pauses whenever ``prefetch_size`` batches are waiting. ``__iter__`` pops
    ``data_len`` batches from that queue per epoch.

    Notes:
        * The worker produces one continuous stream across epochs. If the
          consumer stops an epoch early, the unread batches stay queued and
          are the first ones returned by the next epoch.
        * ``data_len`` is captured once, at construction time.
        * An exception raised while producing a batch is re-raised in the
          consuming thread by ``__iter__``.

    Args:
        data: Dataset, forwarded to ``torch.utils.data.DataLoader``.
        prefetch_size: Maximum number of batches kept ready in the queue.
            ``None`` (default) means one full epoch (``len(self)``).
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``prefetch_size`` is given and is smaller than 1; the
            worker could then never refill the queue.
    """

    # Upper bound, in seconds, on a single wait for the consumer to free a
    # queue slot; after it elapses the worker re-checks the queue size.
    _POLL_TIMEOUT_S = 10

    def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
        if prefetch_size is not None and prefetch_size < 1:
            raise ValueError(
                f"prefetch_size must be at least 1, got {prefetch_size}"
            )
        super().__init__(data, **kwargs)
        self.data_len = len(self)
        self.prefetch_size = (
            prefetch_size if prefetch_size is not None else self.data_len
        )
        self.__data_queue = queue.Queue()
        self.__worker_error: Optional[BaseException] = None
        # Held by the worker while it pulls from the dataset and released
        # while it is parked. Nothing in this file besides the worker takes
        # the lock or sets the two "request" events below.
        self.__dataset_access = threading.Lock()
        self.__dataset_access_request = threading.Event()
        self.__dataset_access_request_release = threading.Event()
        # Set by the consumer each time it removes a batch from the queue.
        self.__data_extract_event = threading.Event()
        self._worker_thread = threading.Thread(
            target=self.__async_loader, daemon=True
        )
        self._worker_thread.start()

    def __iter__(self) -> Iterator[Any]:
        """Yield one epoch (``data_len`` batches) from the prefetch queue.

        Raises:
            Exception: Whatever exception stopped the worker thread, once the
                batches queued before the failure have been consumed.
        """
        for _ in range(self.data_len):
            batch = self.__data_queue.get()
            if batch is _WORKER_FAILED:
                # Re-queue the marker so later calls fail too instead of
                # blocking forever on a queue nobody fills any more.
                self.__data_queue.put(_WORKER_FAILED)
                raise self.__worker_error
            self.__data_extract_event.set()
            yield batch

    def __wait_for_queue_space(self) -> None:
        """Block the worker until fewer than ``prefetch_size`` batches wait."""
        while True:
            # Clear *before* sampling the queue size: a consumer get()+set()
            # that lands after this line is then never erased, so the wait
            # below cannot miss it and stall for the whole timeout.
            self.__data_extract_event.clear()
            if self.__data_queue.qsize() < self.prefetch_size:
                return
            self.__dataset_access.release()
            try:
                self.__data_extract_event.wait(self._POLL_TIMEOUT_S)
            finally:
                self.__dataset_access.acquire()

    def __async_loader(self) -> None:
        """Worker-thread body: keep the prefetch queue filled, epoch by epoch."""
        if self.data_len == 0:
            # Nothing to load; looping over an empty iterator would spin.
            return
        self.__dataset_access.acquire()
        try:
            while True:
                for batch in super().__iter__():
                    self.__data_queue.put(batch)
                    self.__wait_for_queue_space()
                    if self.__dataset_access_request.is_set():
                        # Hand the dataset over until the requester signals
                        # that it is done.
                        self.__dataset_access.release()
                        try:
                            self.__dataset_access_request_release.wait()
                        finally:
                            self.__dataset_access.acquire()
        except Exception as exc:  # thread boundary: forward to the consumer
            self.__worker_error = exc
            self.__data_queue.put(_WORKER_FAILED)
        finally:
            self.__dataset_access.release()


def _demo() -> None:
    """Run three epochs over small train/eval loaders, printing progress."""
    import time

    import torch.random
    from torch.utils.data import TensorDataset, random_split

    torch.random.manual_seed(0)
    num_train_examples = 10
    training_data = torch.rand(num_train_examples, 32)
    random_function = torch.rand(32)
    sums = torch.sum(training_data * random_function, dim=1) ** 2
    targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
    all_data = TensorDataset(training_data, targets)
    train_data, eval_data = random_split(all_data, [0.5, 0.5])
    train_loader = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
    eval_loader = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)
    for e in range(3):
        print('e', e)
        for i, _ in enumerate(train_loader):
            print("train", i)
            time.sleep(1)
        for i, _ in enumerate(eval_loader):
            print('eval', i)
            time.sleep(1)


if __name__ == '__main__':
    _demo()