# File info (from source viewer): 50 lines, 2.0 KiB, Python
import torch.utils.data
|
|
import numpy as np
|
|
import numpy.random as npr
|
|
import torchvision
|
|
|
|
|
|
class AImagesDataset(torch.utils.data.Dataset):
|
|
@staticmethod
|
|
def augment_tensor(t: torch.Tensor, i: int) -> (torch.Tensor, str):
|
|
match i % 7:
|
|
case 0:
|
|
return t, "Original"
|
|
case 1:
|
|
return torchvision.transforms.GaussianBlur(kernel_size=5).forward(t), "GaussianBlur"
|
|
case 2:
|
|
return torchvision.transforms.RandomRotation(degrees=180).forward(t), "RandomRotation"
|
|
case 3:
|
|
return torchvision.transforms.RandomVerticalFlip().forward(t), "RandomVerticalFlip"
|
|
case 4:
|
|
return torchvision.transforms.RandomHorizontalFlip().forward(t), "RandomHorizontalFlip"
|
|
case 5:
|
|
return torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2,
|
|
hue=0.1).forward(t), "ColorJitter"
|
|
case 6:
|
|
rng = npr.default_rng()
|
|
return AImagesDataset.augment_tensor(
|
|
AImagesDataset.augment_tensor(
|
|
AImagesDataset.augment_tensor(t, rng.integers(1, 6))[0],
|
|
rng.integers(1, 6))[0],
|
|
rng.integers(1, 6))[0], "Compose"
|
|
|
|
@staticmethod
|
|
def augment_image(
|
|
img_np: np.ndarray,
|
|
index: int
|
|
) -> (torch.Tensor, str):
|
|
tensor = torch.from_numpy(img_np)
|
|
return AImagesDataset.augment_tensor(tensor, index)
|
|
|
|
def __init__(self, data_set: torch.utils.data.Dataset):
|
|
super().__init__()
|
|
self.data = data_set
|
|
|
|
def __getitem__(self, index: int):
|
|
image, class_id, class_name, image_filepath = self.data[index // 7]
|
|
img, transform = AImagesDataset.augment_tensor(image, index)
|
|
return img, transform, index, class_id, class_name, image_filepath
|
|
|
|
def __len__(self):
|
|
return self.data.__len__() * 7 |