From 4dc1ab37d00183e9523c155f68ed590e339f240c Mon Sep 17 00:00:00 2001 From: Patrick Date: Sat, 6 Jul 2024 18:16:46 +0200 Subject: [PATCH] added possibility to turn off data augmentation --- AImageDataset.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/AImageDataset.py b/AImageDataset.py index e3f2ae6..53d7872 100644 --- a/AImageDataset.py +++ b/AImageDataset.py @@ -37,14 +37,23 @@ class AImagesDataset(torch.utils.data.Dataset): tensor = torch.from_numpy(img_np) return AImagesDataset.augment_tensor(tensor, index) - def __init__(self, data_set: torch.utils.data.Dataset): + def __init__(self, data_set: torch.utils.data.Dataset, augment=True): super().__init__() self.data = data_set + self.augment = augment def __getitem__(self, index: int): - image, class_id, class_name, image_filepath = self.data[index // 7] - img, transform = AImagesDataset.augment_tensor(image, index) + if self.augment: + image, class_id, class_name, image_filepath = self.data[index // 7] + img, transform = AImagesDataset.augment_tensor(image, index) + else: + img, class_id, class_name, image_filepath = self.data[index] + transform = "None" + return img, transform, index, class_id, class_name, image_filepath def __len__(self): - return self.data.__len__() * 7 \ No newline at end of file + if self.augment: + return self.data.__len__() * 7 + else: + return self.data.__len__() \ No newline at end of file