added possibility to turn off data augmentation

This commit is contained in:
Patrick 2024-07-06 18:16:46 +02:00
parent 6031c15723
commit 4dc1ab37d0
1 changed files with 13 additions and 4 deletions

View File

@ -37,14 +37,23 @@ class AImagesDataset(torch.utils.data.Dataset):
tensor = torch.from_numpy(img_np)
return AImagesDataset.augment_tensor(tensor, index)
def __init__(self, data_set: torch.utils.data.Dataset):
def __init__(self, data_set: torch.utils.data.Dataset, augment=True):
super().__init__()
self.data = data_set
self.augment = augment
def __getitem__(self, index: int):
image, class_id, class_name, image_filepath = self.data[index // 7]
img, transform = AImagesDataset.augment_tensor(image, index)
if self.augment:
image, class_id, class_name, image_filepath = self.data[index // 7]
img, transform = AImagesDataset.augment_tensor(image, index)
else:
img, class_id, class_name, image_filepath = self.data[index]
transform = "None"
return img, transform, index, class_id, class_name, image_filepath
def __len__(self):
return self.data.__len__() * 7
if self.augment:
return self.data.__len__() * 7
else:
return self.data.__len__()