# AI-Project/AsyncDataLoader.py
#
# 84 lines
# 2.6 KiB
# Python

import queue
import threading
from typing import Optional
import torch.utils.data
class AsyncDataLoader(torch.utils.data.DataLoader):
    """DataLoader that prefetches batches on a background daemon thread.

    A worker thread started in ``__init__`` iterates the parent
    ``torch.utils.data.DataLoader`` over and over and pushes every batch into
    an internal queue.  ``__iter__`` pops ``len(self)`` batches (the length is
    captured once at construction) from that queue per pass, so loading
    overlaps with whatever the consumer does between batches.

    Args:
        data: Dataset forwarded unchanged to ``torch.utils.data.DataLoader``.
        prefetch_size: Queue size at which the worker pauses until the
            consumer takes a batch.  ``None`` (the default) uses the number
            of batches in one pass.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``prefetch_size`` is given and is smaller than 1; such
            a value would stall the worker forever after its first batch.
    """

    def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
        # Reject sizes for which the worker's "queue is full" condition could
        # never become false (the consumer would block forever).
        if prefetch_size is not None and prefetch_size < 1:
            raise ValueError(
                f"prefetch_size must be at least 1, got {prefetch_size}"
            )
        super().__init__(data, **kwargs)
        self.data_len = len(self)
        self.prefetch_size = (
            self.data_len if prefetch_size is None else prefetch_size
        )
        self.__data_queue = queue.Queue()
        # Held by the worker whenever it is pulling from the parent loader.
        self.__dataset_access = threading.Lock()
        # NOTE(review): nothing in the visible code sets these two events;
        # presumably they let another thread ask the worker to yield the
        # lock -- confirm against the rest of the project.
        self.__dataset_access_request = threading.Event()
        self.__dataset_access_request_release = threading.Event()
        # Set by the consumer each time it removes a batch from the queue.
        self.__data_extract_event = threading.Event()
        self._worker_thread = threading.Thread(
            target=AsyncDataLoader.__async_loader, args=(self,), daemon=True
        )
        self._worker_thread.start()

    def __iter__(self):
        """Yield ``self.data_len`` prefetched batches taken from the queue."""
        # NOTE(review): ``get()`` has no timeout, so if the worker thread dies
        # from an exception this blocks forever.
        for _ in range(self.data_len):
            _, data = self.__data_queue.get()
            # Tell the worker that a queue slot has been freed.
            self.__data_extract_event.set()
            yield data

    def __wait_unlocked(self, event, timeout=None):
        """Wait on *event* with the dataset lock released, then re-acquire it."""
        self.__dataset_access.release()
        try:
            event.wait(timeout)
        finally:
            # Always re-acquire so the caller's ``with`` block can release.
            self.__dataset_access.acquire()

    @staticmethod
    def __async_loader(self):
        """Worker loop: fill the queue from the parent loader indefinitely.

        Runs on the daemon thread created in ``__init__`` and receives the
        loader instance explicitly as ``self``.
        """
        if self.data_len == 0:
            # An empty loader would make the endless loop below spin at full
            # speed without ever producing a batch.
            return
        with self.__dataset_access:
            while True:
                # Explicit two-argument super(): the zero-argument form only
                # works in a staticmethod by accident of the parameter name.
                batches = super(AsyncDataLoader, self).__iter__()
                for enumerated_data in enumerate(batches):
                    self.__data_queue.put(enumerated_data)
                    while True:
                        # Clear *before* checking the size: a consumer
                        # ``set()`` that lands after the check then wakes the
                        # wait below instead of being erased (lost wakeup).
                        self.__data_extract_event.clear()
                        if self.__data_queue.qsize() < self.prefetch_size:
                            break
                        # Queue is full: wait (at most 10 s per round) for
                        # the consumer to take a batch.
                        self.__wait_unlocked(self.__data_extract_event, 10)
                    if self.__dataset_access_request.is_set():
                        # Hand the lock over until the requester signals that
                        # it is done.
                        self.__wait_unlocked(
                            self.__dataset_access_request_release
                        )
if __name__ == '__main__':
    # Demo: build a tiny synthetic regression dataset, split it in half, and
    # run three passes over an AsyncDataLoader for each half, sleeping one
    # second per batch to stand in for consumer-side work.
    import time

    import torch.random
    # ``dataloader`` is imported but not used anywhere in this script.
    from torch.utils.data import TensorDataset, random_split, dataloader

    torch.random.manual_seed(0)

    sample_count = 10
    features = torch.rand(sample_count, 32)
    weights = torch.rand(32)

    # Target = squared weighted row sum, zeroed wherever it is 100 or more.
    squared_projection = (features * weights).sum(dim=1).pow(2)
    labels = torch.where(
        squared_projection < 100,
        squared_projection,
        torch.zeros_like(squared_projection),
    )

    dataset = TensorDataset(features, labels)
    train_split, eval_split = random_split(dataset, [0.5, 0.5])

    # Both loaders begin prefetching as soon as they are constructed.
    labelled_loaders = (
        ("train", AsyncDataLoader(train_split, batch_size=1, prefetch_size=4)),
        ('eval', AsyncDataLoader(eval_split, batch_size=1, prefetch_size=4)),
    )

    for epoch in range(3):
        print('e', epoch)
        for label, loader in labelled_loaders:
            for step, _batch in enumerate(loader):
                print(label, step)
                time.sleep(1)