# AI-Project/AImageDataset.py
# (Header pasted from a file viewer: 59 lines, 2.3 KiB, Python.)

import torch.utils.data
import numpy as np
import numpy.random as npr
import torchvision
class AImagesDataset(torch.utils.data.Dataset):
@staticmethod
def augment_tensor(t: torch.Tensor, i: int) -> (torch.Tensor, str):
match i % 7:
case 0:
return t, "Original"
case 1:
return torchvision.transforms.GaussianBlur(kernel_size=5).forward(t), "GaussianBlur"
case 2:
return torchvision.transforms.RandomRotation(degrees=180).forward(t), "RandomRotation"
case 3:
return torchvision.transforms.RandomVerticalFlip().forward(t), "RandomVerticalFlip"
case 4:
return torchvision.transforms.RandomHorizontalFlip().forward(t), "RandomHorizontalFlip"
case 5:
return torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2,
hue=0.1).forward(t), "ColorJitter"
case 6:
rng = npr.default_rng()
return AImagesDataset.augment_tensor(
AImagesDataset.augment_tensor(
AImagesDataset.augment_tensor(t, rng.integers(1, 6))[0],
rng.integers(1, 6))[0],
rng.integers(1, 6))[0], "Compose"
@staticmethod
def augment_image(
img_np: np.ndarray,
index: int
) -> (torch.Tensor, str):
tensor = torch.from_numpy(img_np)
return AImagesDataset.augment_tensor(tensor, index)
def __init__(self, data_set: torch.utils.data.Dataset, augment=True):
super().__init__()
self.data = data_set
self.augment = augment
def __getitem__(self, index: int):
if self.augment:
image, class_id, class_name, image_filepath = self.data[index // 7]
img, transform = AImagesDataset.augment_tensor(image, index)
else:
img, class_id, class_name, image_filepath = self.data[index]
transform = "None"
return img, transform, index, class_id, class_name, image_filepath
def __len__(self):
if self.augment:
return self.data.__len__() * 7
else:
return self.data.__len__()