import torch.utils.data import numpy as np import numpy.random as npr import torchvision class AImagesDataset(torch.utils.data.Dataset): @staticmethod def augment_tensor(t: torch.Tensor, i: int) -> (torch.Tensor, str): match i % 7: case 0: return t, "Original" case 1: return torchvision.transforms.GaussianBlur(kernel_size=5).forward(t), "GaussianBlur" case 2: return torchvision.transforms.RandomRotation(degrees=180).forward(t), "RandomRotation" case 3: return torchvision.transforms.RandomVerticalFlip().forward(t), "RandomVerticalFlip" case 4: return torchvision.transforms.RandomHorizontalFlip().forward(t), "RandomHorizontalFlip" case 5: return torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1).forward(t), "ColorJitter" case 6: rng = npr.default_rng() return AImagesDataset.augment_tensor( AImagesDataset.augment_tensor( AImagesDataset.augment_tensor(t, rng.integers(1, 6))[0], rng.integers(1, 6))[0], rng.integers(1, 6))[0], "Compose" @staticmethod def augment_image( img_np: np.ndarray, index: int ) -> (torch.Tensor, str): tensor = torch.from_numpy(img_np) return AImagesDataset.augment_tensor(tensor, index) def __init__(self, data_set: torch.utils.data.Dataset, augment=True): super().__init__() self.data = data_set self.augment = augment def __getitem__(self, index: int): if self.augment: image, class_id, class_name, image_filepath = self.data[index // 7] img, transform = AImagesDataset.augment_tensor(image, index) else: img, class_id, class_name, image_filepath = self.data[index] transform = "None" return img, transform, index, class_id, class_name, image_filepath def __len__(self): if self.augment: return self.data.__len__() * 7 else: return self.data.__len__()