added async data loading

This commit is contained in:
Patrick 2024-07-06 22:27:09 +02:00
parent 992f84fa5a
commit 8454c73d77
2 changed files with 95 additions and 11 deletions

83
AsyncDataLoader.py Normal file
View File

@ -0,0 +1,83 @@
import queue
import threading
from typing import Optional
import torch.utils.data
class AsyncDataLoader(torch.utils.data.DataLoader):
def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
super().__init__(data, **kwargs)
self.data_len = len(self)
self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len
self.__data_queue = queue.Queue()
self.__dataset_access = threading.Lock()
self.__dataset_access_request = threading.Event()
self.__dataset_access_request_release = threading.Event()
self.__data_extract_event = threading.Event()
self._worker_thread = threading.Thread(target=AsyncDataLoader.__async_loader, args=(self,), daemon=True)
self._worker_thread.start()
def __iter__(self):
for _ in range(self.data_len):
_, data = self.__data_queue.get()
self.__data_extract_event.set()
yield data
@staticmethod
def __async_loader(self):
self.__dataset_access.acquire(blocking=True)
while True:
for enumerated_data in enumerate(super().__iter__()):
self.__data_queue.put(enumerated_data)
while not self.__data_queue.qsize() < self.prefetch_size:
self.__dataset_access.release()
self.__data_extract_event.clear()
self.__data_extract_event.wait(10)
self.__dataset_access.acquire()
if self.__dataset_access_request.is_set():
self.__dataset_access.release()
self.__dataset_access_request_release.wait()
self.__dataset_access.acquire()
self.__dataset_access.release()
if __name__ == '__main__':
from torch.utils.data import TensorDataset, random_split, dataloader
import torch.random
import time
torch.random.manual_seed(0)
num_train_examples = 10
training_data = torch.rand(num_train_examples, 32)
random_function = torch.rand(32)
sums = torch.sum(training_data * random_function, dim=1) ** 2
targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
all_data = TensorDataset(training_data, targets)
train_data, eval_data = random_split(all_data, [0.5, 0.5])
train = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
eval = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)
for e in range(3):
print('e', e)
for i, _ in enumerate(train):
print("train", i)
time.sleep(1)
for i, _ in enumerate(eval):
print('eval', i)
time.sleep(1)

View File

@ -8,6 +8,7 @@ from architecture import MyCNN
from dataset import ImagesDataset from dataset import ImagesDataset
from AImageDataset import AImagesDataset from AImageDataset import AImagesDataset
from AsyncDataLoader import AsyncDataLoader
def train_model(accuracies, def train_model(accuracies,
@ -22,7 +23,6 @@ def train_model(accuracies,
loss_function, loss_function,
device, device,
start_time): start_time):
torch.random.manual_seed(42) torch.random.manual_seed(42)
numpy.random.seed(42) numpy.random.seed(42)
torch.multiprocessing.set_start_method('spawn', force=True) torch.multiprocessing.set_start_method('spawn', force=True)
@ -33,16 +33,16 @@ def train_model(accuracies,
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5]) train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
augmented_train_data = AImagesDataset(train_data, False) augmented_train_data = AImagesDataset(train_data, True)
train_loader = torch.utils.data.DataLoader(augmented_train_data, train_loader = AsyncDataLoader(augmented_train_data,
batch_size=batch_size, batch_size=batch_size,
num_workers=3, num_workers=3,
pin_memory=True, pin_memory=True,
shuffle=True) shuffle=True)
eval_loader = torch.utils.data.DataLoader(eval_data, eval_loader = AsyncDataLoader(eval_data,
batch_size=batch_size, batch_size=batch_size,
num_workers=3, num_workers=3,
pin_memory=True) pin_memory=True)
for epoch in progress_epoch.range(num_epochs): for epoch in progress_epoch.range(num_epochs):
@ -60,6 +60,7 @@ def train_model(accuracies,
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \ for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
in enumerate(progress_train_data.iter(train_loader)): in enumerate(progress_train_data.iter(train_loader)):
image_t = image_t.to(device) image_t = image_t.to(device)
class_ids = class_ids.to(device) class_ids = class_ids.to(device)