added async data loading
This commit is contained in:
parent
992f84fa5a
commit
8454c73d77
|
|
@ -0,0 +1,83 @@
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDataLoader(torch.utils.data.DataLoader):
|
||||||
|
|
||||||
|
def __init__(self, data, prefetch_size: Optional[int] = None, **kwargs):
|
||||||
|
super().__init__(data, **kwargs)
|
||||||
|
|
||||||
|
self.data_len = len(self)
|
||||||
|
|
||||||
|
self.prefetch_size = prefetch_size if prefetch_size is not None else self.data_len
|
||||||
|
self.__data_queue = queue.Queue()
|
||||||
|
self.__dataset_access = threading.Lock()
|
||||||
|
self.__dataset_access_request = threading.Event()
|
||||||
|
self.__dataset_access_request_release = threading.Event()
|
||||||
|
self.__data_extract_event = threading.Event()
|
||||||
|
|
||||||
|
self._worker_thread = threading.Thread(target=AsyncDataLoader.__async_loader, args=(self,), daemon=True)
|
||||||
|
self._worker_thread.start()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for _ in range(self.data_len):
|
||||||
|
_, data = self.__data_queue.get()
|
||||||
|
self.__data_extract_event.set()
|
||||||
|
yield data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __async_loader(self):
|
||||||
|
self.__dataset_access.acquire(blocking=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
for enumerated_data in enumerate(super().__iter__()):
|
||||||
|
self.__data_queue.put(enumerated_data)
|
||||||
|
|
||||||
|
while not self.__data_queue.qsize() < self.prefetch_size:
|
||||||
|
self.__dataset_access.release()
|
||||||
|
|
||||||
|
self.__data_extract_event.clear()
|
||||||
|
self.__data_extract_event.wait(10)
|
||||||
|
|
||||||
|
self.__dataset_access.acquire()
|
||||||
|
|
||||||
|
if self.__dataset_access_request.is_set():
|
||||||
|
self.__dataset_access.release()
|
||||||
|
self.__dataset_access_request_release.wait()
|
||||||
|
|
||||||
|
self.__dataset_access.acquire()
|
||||||
|
|
||||||
|
|
||||||
|
self.__dataset_access.release()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Removed the unused 'dataloader' import that was previously pulled in.
    from torch.utils.data import TensorDataset, random_split
    import torch.random
    import time

    # Build a small synthetic regression dataset so the async loaders can
    # be exercised without any external data.
    torch.random.manual_seed(0)
    num_train_examples = 10
    training_data = torch.rand(num_train_examples, 32)
    # Random projection weights (was misleadingly named 'random_function').
    weights = torch.rand(32)
    sums = torch.sum(training_data * weights, dim=1) ** 2
    targets = torch.where(sums < 100, sums, torch.zeros_like(sums))
    all_data = TensorDataset(training_data, targets)
    train_data, eval_data = random_split(all_data, [0.5, 0.5])

    train_loader = AsyncDataLoader(train_data, batch_size=1, prefetch_size=4)
    # Renamed from 'eval' so the builtin eval() is not shadowed.
    eval_loader = AsyncDataLoader(eval_data, batch_size=1, prefetch_size=4)

    # Deliberately slow consumer: the 1 s sleeps make the background
    # prefetching observable in the printed output.
    for e in range(3):
        print('e', e)
        for i, _ in enumerate(train_loader):
            print("train", i)
            time.sleep(1)

        for i, _ in enumerate(eval_loader):
            print('eval', i)
            time.sleep(1)
|
||||||
23
cnn_train.py
23
cnn_train.py
|
|
@ -8,6 +8,7 @@ from architecture import MyCNN
|
||||||
from dataset import ImagesDataset
|
from dataset import ImagesDataset
|
||||||
|
|
||||||
from AImageDataset import AImagesDataset
|
from AImageDataset import AImagesDataset
|
||||||
|
from AsyncDataLoader import AsyncDataLoader
|
||||||
|
|
||||||
|
|
||||||
def train_model(accuracies,
|
def train_model(accuracies,
|
||||||
|
|
@ -22,7 +23,6 @@ def train_model(accuracies,
|
||||||
loss_function,
|
loss_function,
|
||||||
device,
|
device,
|
||||||
start_time):
|
start_time):
|
||||||
|
|
||||||
torch.random.manual_seed(42)
|
torch.random.manual_seed(42)
|
||||||
numpy.random.seed(42)
|
numpy.random.seed(42)
|
||||||
torch.multiprocessing.set_start_method('spawn', force=True)
|
torch.multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
|
@ -33,16 +33,16 @@ def train_model(accuracies,
|
||||||
|
|
||||||
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
train_data, eval_data = torch.utils.data.random_split(dataset, [0.5, 0.5])
|
||||||
|
|
||||||
augmented_train_data = AImagesDataset(train_data, False)
|
augmented_train_data = AImagesDataset(train_data, True)
|
||||||
train_loader = torch.utils.data.DataLoader(augmented_train_data,
|
train_loader = AsyncDataLoader(augmented_train_data,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=3,
|
num_workers=3,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
eval_loader = torch.utils.data.DataLoader(eval_data,
|
eval_loader = AsyncDataLoader(eval_data,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=3,
|
num_workers=3,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
for epoch in progress_epoch.range(num_epochs):
|
for epoch in progress_epoch.range(num_epochs):
|
||||||
|
|
||||||
|
|
@ -60,6 +60,7 @@ def train_model(accuracies,
|
||||||
|
|
||||||
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
for batch_nr, (image_t, transforms, img_index, class_ids, labels, paths) \
|
||||||
in enumerate(progress_train_data.iter(train_loader)):
|
in enumerate(progress_train_data.iter(train_loader)):
|
||||||
|
|
||||||
image_t = image_t.to(device)
|
image_t = image_t.to(device)
|
||||||
class_ids = class_ids.to(device)
|
class_ids = class_ids.to(device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue