import queue
|
|
import threading
|
|
from typing import Optional
|
|
|
|
import torch.utils.data
|
|
|
|
|
|
class AsyncDataLoader(torch.utils.data.DataLoader):
    """A DataLoader whose batches are produced ahead of time by a background thread.

    A daemon worker thread iterates the underlying ``DataLoader`` in an endless
    loop and pushes batches into a bounded queue; ``__iter__`` pops exactly
    ``len(self)`` batches per epoch. At most ``prefetch_size`` batches are
    buffered ahead of the consumer — the bounded queue's blocking ``put``
    provides the back-pressure, replacing the previous hand-rolled
    ``qsize()``/``Event`` polling (``qsize`` is documented as approximate, the
    10 s wait timeout delayed wakeups, and the lock acquire/release pairing was
    unbalanced across epochs: the second pass released a lock it no longer held,
    killing the worker thread).
    """

    def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
        """
        Args:
            data: dataset, forwarded to ``torch.utils.data.DataLoader``.
            prefetch_size: maximum number of batches buffered ahead of the
                consumer. Defaults to a full epoch (``len(self)``).
            **kwargs: forwarded verbatim to ``torch.utils.data.DataLoader``.
        """
        super().__init__(data, **kwargs)
        self.data_len = len(self)  # number of batches per epoch
        self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len
        # Bounded queue: the worker's put() blocks once `prefetch_size`
        # batches are waiting, which is exactly the buffering limit the old
        # qsize()/Event dance tried (racily) to enforce.
        self.__data_queue = queue.Queue(maxsize=self.prefetch_size)
        # Daemon thread: dies with the interpreter, so a partially consumed
        # epoch never blocks shutdown.
        self._worker_thread = threading.Thread(target=self.__async_loader, daemon=True)
        self._worker_thread.start()

    def __iter__(self):
        """Yield exactly one epoch's worth of prefetched batches."""
        for _ in range(self.data_len):
            # The worker enqueues (batch_index, batch) pairs; the index is
            # bookkeeping only and is dropped here.
            _, data = self.__data_queue.get()
            yield data

    def __async_loader(self):
        """Worker loop: iterate the parent DataLoader forever, enqueueing batches."""
        if self.data_len == 0:
            return  # empty dataset: nothing to produce; avoid a busy loop
        while True:
            for indexed_batch in enumerate(super().__iter__()):
                # Blocks while the queue already holds `prefetch_size` batches.
                self.__data_queue.put(indexed_batch)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a tiny synthetic regression dataset, wrap its train
    # and eval splits in AsyncDataLoaders, and iterate a few epochs with an
    # artificial per-batch delay so the prefetching is observable.
    from torch.utils.data import TensorDataset, random_split
    import torch.random
    import time

    torch.random.manual_seed(0)
    num_train_examples = 10
    training_data = torch.rand(num_train_examples, 32)
    random_function = torch.rand(32)
    # Synthetic target: squared weighted sum, zeroed where it exceeds 100.
    sums = torch.sum(training_data * random_function, dim=1) ** 2
    targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
    all_data = TensorDataset(training_data, targets)
    train_data, eval_data = random_split(all_data, [0.5, 0.5])

    # Renamed from `train`/`eval` so the `eval` builtin is not shadowed.
    train_loader = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
    eval_loader = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)

    for e in range(3):
        print('e', e)
        for i, _ in enumerate(train_loader):
            print("train", i)
            time.sleep(1)  # simulate per-batch training work

        for i, _ in enumerate(eval_loader):
            print('eval', i)
            time.sleep(1)  # simulate per-batch evaluation work
|