from typing import Optional

import torch
import torch.nn


def _pair(value) -> tuple[int, int]:
    """Return *value* as a (height, width) pair; a plain int is used for both axes."""
    if isinstance(value, int):
        return value, value
    height, width = value
    return int(height), int(width)


def _valid_conv_output_size(size: tuple[int, int], kernel, stride) -> tuple[int, int]:
    """Return the (height, width) produced by an unpadded, undilated convolution.

    Raises:
        ValueError: if the kernel does not fit into the input at least once.
    """
    (height, width), (kernel_h, kernel_w), (stride_h, stride_w) = size, _pair(kernel), _pair(stride)
    out_size = ((height - kernel_h) // stride_h + 1, (width - kernel_w) // stride_w + 1)
    if min(out_size) < 1:
        raise ValueError(
            f"spatial size {size} is too small for kernel {kernel!r} with stride {stride!r}"
        )
    return out_size


class MyCNN(torch.nn.Module):
    """Conv2d stack (optionally with BatchNorm2d) with activations, followed by one Linear layer.

    Each convolution uses ``padding='same'`` when its stride is the int ``1``, which keeps the
    spatial size. Otherwise it uses ``padding='valid'``, which shrinks the spatial size. The
    flattened feature size fed to the final Linear layer is derived from these per-layer sizes.
    """

    def __init__(
        self,
        input_channels: int,
        input_size: tuple[int, int],
        hidden_channels: list[int],
        output_channels: int,
        use_batchnorm: bool,
        kernel_size: list,
        stride: list[int],
        activation_function: Optional[torch.nn.Module] = None,
    ):
        """Build the convolutional stack and the output layer.

        Args:
            input_channels: Number of channels of the input images.
            input_size: (height, width) of the input images; used to size the Linear layer.
            hidden_channels: Output channels of each convolution, one entry per conv layer.
            output_channels: Number of output features of the final Linear layer.
            use_batchnorm: If True, a BatchNorm2d follows every convolution.
            kernel_size: Kernel size per conv layer (int or (h, w) pair).
            stride: Stride per conv layer; ``1`` selects 'same' padding, anything else 'valid'.
            activation_function: Module inserted after every conv (and batch norm). The same
                instance is reused for all layers, so a parametrized activation shares its
                parameters across layers. Defaults to a new ``torch.nn.ReLU()`` per model.

        Raises:
            ValueError: if ``hidden_channels`` is empty, ``kernel_size``/``stride`` have fewer
                entries than ``hidden_channels``, a stride is not positive, or ``input_size`` is
                too small for the requested 'valid' convolutions.
        """
        super().__init__()
        num_conv_layers = len(hidden_channels)
        if num_conv_layers == 0:
            raise ValueError("hidden_channels must contain at least one entry")
        if len(kernel_size) < num_conv_layers or len(stride) < num_conv_layers:
            raise ValueError(
                f"kernel_size and stride need one entry per hidden layer ({num_conv_layers}), "
                f"got {len(kernel_size)} and {len(stride)}"
            )
        if activation_function is None:
            # Created per instance: a default module object in the signature would be shared
            # by every MyCNN constructed without an explicit activation.
            activation_function = torch.nn.ReLU()

        layers = []
        in_channels = input_channels
        height, width = input_size
        for out_channels, kernel, step in zip(hidden_channels, kernel_size, stride):
            if any(s < 1 for s in _pair(step)):
                raise ValueError(f"stride must be positive, got {step!r}")
            unit_stride = step == 1
            layers.append(
                torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    padding="same" if unit_stride else "valid",
                    stride=step,
                )
            )
            if use_batchnorm:
                layers.append(torch.nn.BatchNorm2d(out_channels))
            layers.append(activation_function)
            if not unit_stride:
                # 'valid' padding shrinks the feature map; track it so the Linear layer matches.
                height, width = _valid_conv_output_size((height, width), kernel, step)
            in_channels = out_channels

        # Registered before `layers` (as originally) to keep parameter order stable, e.g. for
        # previously saved optimizer state.
        self.output_layer = torch.nn.Linear(in_channels * height * width, output_channels)
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, input_images: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack, flatten everything but dim 0, and apply the Linear layer.

        ``input_images`` is expected to be (batch, input_channels, height, width) with
        (height, width) equal to the ``input_size`` given at construction.
        """
        features = self.layers(input_images)
        # flatten (unlike .view) also handles non-contiguous features, e.g. channels_last.
        return self.output_layer(torch.flatten(features, start_dim=1))