added possibility to turn off data augmentation
This commit is contained in:
parent
6031c15723
commit
4dc1ab37d0
|
|
@ -37,14 +37,23 @@ class AImagesDataset(torch.utils.data.Dataset):
|
||||||
tensor = torch.from_numpy(img_np)
|
tensor = torch.from_numpy(img_np)
|
||||||
return AImagesDataset.augment_tensor(tensor, index)
|
return AImagesDataset.augment_tensor(tensor, index)
|
||||||
|
|
||||||
def __init__(self, data_set: torch.utils.data.Dataset):
|
def __init__(self, data_set: torch.utils.data.Dataset, augment=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = data_set
|
self.data = data_set
|
||||||
|
self.augment = augment
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
image, class_id, class_name, image_filepath = self.data[index // 7]
|
if self.augment:
|
||||||
img, transform = AImagesDataset.augment_tensor(image, index)
|
image, class_id, class_name, image_filepath = self.data[index // 7]
|
||||||
|
img, transform = AImagesDataset.augment_tensor(image, index)
|
||||||
|
else:
|
||||||
|
img, class_id, class_name, image_filepath = self.data[index]
|
||||||
|
transform = "None"
|
||||||
|
|
||||||
return img, transform, index, class_id, class_name, image_filepath
|
return img, transform, index, class_id, class_name, image_filepath
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.data.__len__() * 7
|
if self.augment:
|
||||||
|
return self.data.__len__() * 7
|
||||||
|
else:
|
||||||
|
return self.data.__len__()
|
||||||
Loading…
Reference in New Issue