Add async data loading
This commit is contained in:
parent
992f84fa5a
commit
8454c73d77
|
|
@ -0,0 +1,83 @@
|
|||
import queue
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import torch.utils.data
|
||||
|
||||
|
||||
class AsyncDataLoader(torch.utils.data.DataLoader):
    """DataLoader that prefetches batches on a background daemon thread.

    A worker thread iterates the wrapped ``DataLoader`` forever and pushes
    ``(batch_index, batch)`` pairs into an internal queue, keeping at most
    ``prefetch_size`` batches buffered.  ``__iter__`` pops exactly one
    epoch's worth of batches, so data loading overlaps with the consumer's
    work (e.g. the training step).
    """

    def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
        """
        :param data: dataset, forwarded to ``torch.utils.data.DataLoader``.
        :param prefetch_size: maximum number of batches buffered ahead of
            the consumer; defaults to a full epoch (``len(self)``).
        :param kwargs: forwarded to ``torch.utils.data.DataLoader``.
        """
        super().__init__(data, **kwargs)

        # Batches per epoch -- also how many items one __iter__ pass yields.
        self.data_len = len(self)
        self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len

        # (batch_index, batch) pairs produced by the worker thread.
        self.__data_queue = queue.Queue()
        # Held by the worker while it is actively reading the dataset.
        self.__dataset_access = threading.Lock()
        # NOTE(review): these two events are never set inside this class --
        # presumably a hook for callers needing exclusive dataset access;
        # verify against callers before removing.
        self.__dataset_access_request = threading.Event()
        self.__dataset_access_request_release = threading.Event()
        # Signalled by the consumer whenever it takes a batch, so the
        # worker can top the buffer back up.
        self.__data_extract_event = threading.Event()

        # Plain bound-method thread target instead of the original fragile
        # staticmethod-with-explicit-self construction.
        self._worker_thread = threading.Thread(target=self.__async_loader, daemon=True)
        self._worker_thread.start()

    def __iter__(self):
        """Yield exactly one epoch of batches from the prefetch queue."""
        for _ in range(self.data_len):
            # The worker's batch index is dropped: the single worker
            # produces epochs strictly in order, so queue position alone
            # identifies the batch.
            _, batch = self.__data_queue.get()
            self.__data_extract_event.set()
            yield batch

    def __async_loader(self):
        """Worker loop: iterate the wrapped DataLoader forever, queueing batches."""
        self.__dataset_access.acquire(blocking=True)
        while True:
            for enumerated_batch in enumerate(super().__iter__()):
                self.__data_queue.put(enumerated_batch)

                # Throttle once the buffer is full; release the dataset
                # lock while parked so an external requester could take it.
                while not self.__data_queue.qsize() < self.prefetch_size:
                    self.__dataset_access.release()
                    self.__data_extract_event.clear()
                    # BUGFIX: re-check the queue after clear() -- the
                    # consumer may have extracted a batch and set() between
                    # our put() and clear(), which previously lost the
                    # wakeup and stalled the worker for the full timeout.
                    if not self.__data_queue.qsize() < self.prefetch_size:
                        # Timeout is a safety net against any remaining
                        # missed-wakeup race.
                        self.__data_extract_event.wait(10)
                    self.__dataset_access.acquire()

                # Cooperatively hand the dataset to an external requester.
                if self.__dataset_access_request.is_set():
                    self.__dataset_access.release()
                    self.__dataset_access_request_release.wait()
                    self.__dataset_access.acquire()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke-test / demo: push a tiny synthetic regression dataset through
    # two AsyncDataLoaders with a deliberately slow consumer, showing that
    # prefetching runs ahead of consumption.
    from torch.utils.data import TensorDataset, random_split
    import torch.random
    import time

    torch.random.manual_seed(0)

    # Synthetic data: 10 random 32-d vectors; the target is the squared
    # weighted sum of each vector, zeroed where it exceeds 100.
    num_train_examples = 10
    training_data = torch.rand(num_train_examples, 32)
    random_function = torch.rand(32)
    sums = torch.sum(training_data * random_function, dim=1) ** 2
    targets = torch.where(sums < 100, sums, torch.zeros_like(sums))

    all_data = TensorDataset(training_data, targets)
    train_data, eval_data = random_split(all_data, [0.5, 0.5])

    # Named *_loader so the builtin `eval` is not shadowed (original bound
    # a loader to the name `eval`).
    train_loader = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
    eval_loader = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)

    for epoch in range(3):
        print('e', epoch)
        for i, _ in enumerate(train_loader):
            print("train", i)
            time.sleep(1)  # simulate a slow training step

        for i, _ in enumerate(eval_loader):
            print('eval', i)
            time.sleep(1)  # simulate a slow evaluation step
|
||||
23
cnn_train.py
23
cnn_train.py
|
|
@ -8,6 +8,7 @@ from architecture import MyCNN
|
|||
from dataset import ImagesDataset
|
||||
|
||||
from AImageDataset import AImagesDataset
|
||||
from AsyncDataLoader import AsyncDataLoader
|
||||
|
||||
|
||||
def train_model(accuracies,
|
||||
|
|
@ -22,7 +23,6 @@ def train_model(accuracies,
|
|||
loss_function,
|
||||
device,
|
||||
start_time):
|
||||
|
||||
torch.random.manual_seed(42)
|
||||
numpy.random.seed(42)
|
||||
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||
|
|
@ -33,16 +33,16 @@ def train_model(accuracies,
|
|||
|
||||
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||
|
||||
augmented_train_data = AImagesDataset(train_data, False)
|
||||
train_loader = torch.utils.data.DataLoader(augmented_train_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True,
|
||||
shuffle=True)
|
||||
eval_loader = torch.utils.data.DataLoader(eval_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True)
|
||||
augmented_train_data = AImagesDataset(train_data, True)
|
||||
train_loader = AsyncDataLoader(augmented_train_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True,
|
||||
shuffle=True)
|
||||
eval_loader = AsyncDataLoader(eval_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=3,
|
||||
pin_memory=True)
|
||||
|
||||
for epoch in progress_epoch.range(num_epochs):
|
||||
|
||||
|
|
@ -60,6 +60,7 @@ def train_model(accuracies,
|
|||
|
||||
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
||||
in enumerate(progress_train_data.iter(train_loader)):
|
||||
|
||||
image_t = image_t.to(device)
|
||||
class_ids = class_ids.to(device)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue