added possibility to turn off data augmentation
This commit is contained in:
parent
6031c15723
commit
4dc1ab37d0
|
|
@ -37,14 +37,23 @@ class AImagesDataset(torch.utils.data.Dataset):
|
|||
tensor = torch.from_numpy(img_np)
|
||||
return AImagesDataset.augment_tensor(tensor, index)
|
||||
|
||||
def __init__(self, data_set: torch.utils.data.Dataset):
|
||||
def __init__(self, data_set: torch.utils.data.Dataset, augment=True):
|
||||
super().__init__()
|
||||
self.data = data_set
|
||||
self.augment = augment
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
if self.augment:
|
||||
image, class_id, class_name, image_filepath = self.data[index // 7]
|
||||
img, transform = AImagesDataset.augment_tensor(image, index)
|
||||
else:
|
||||
img, class_id, class_name, image_filepath = self.data[index]
|
||||
transform = "None"
|
||||
|
||||
return img, transform, index, class_id, class_name, image_filepath
|
||||
|
||||
def __len__(self):
|
||||
if self.augment:
|
||||
return self.data.__len__() * 7
|
||||
else:
|
||||
return self.data.__len__()
|
||||
Loading…
Reference in New Issue