# AI-Project/architecture.py
#
# 60 lines
# 2.6 KiB
# Python

import torch
import torch.nn
class MyCNN(torch.nn.Module):
    """Stack of 2D convolutions followed by a single fully connected layer.

    Each hidden stage is ``Conv2d -> [BatchNorm2d] -> activation_function``.
    A stage with stride 1 uses ``padding='same'`` (spatial size preserved);
    any other stride uses ``padding='valid'`` (no padding, the feature map
    shrinks). The final feature map is flattened and mapped to
    ``output_channels`` by a ``Linear`` layer whose input size is derived from
    ``input_size`` and the actual geometry of every convolution.

    Args:
        input_channels: Number of channels of the input images.
        input_size: ``(height, width)`` of the input images. ``forward`` must
            be called with images of exactly this spatial size, because the
            ``Linear`` layer's input size is fixed from it.
        hidden_channels: Output channel count of each conv layer, one entry
            per layer; must be non-empty.
        output_channels: Output size of the final ``Linear`` layer.
        use_batchnorm: If true, insert a ``BatchNorm2d`` after every conv.
        kernel_size: Kernel size per conv layer (int or pair). Needs at least
            ``len(hidden_channels)`` entries; extra entries are ignored.
        stride: Stride per conv layer. Needs at least ``len(hidden_channels)``
            entries; extra entries are ignored.
        activation_function: Module applied after every conv (or batch-norm)
            layer. The *same instance* is inserted at every position, so an
            activation with learnable parameters (e.g. ``PReLU``) shares them
            across layers. The default ``ReLU`` instance is created once when
            the class is defined and is shared by all models that use the
            default (harmless for the stateless ``ReLU``).

    Raises:
        ValueError: If ``hidden_channels`` is empty, ``kernel_size`` or
            ``stride`` has too few entries, a stride is not positive, or
            ``input_size`` is too small for the requested convolutions.
    """

    def __init__(self,
                 input_channels: int,
                 input_size: tuple[int, int],
                 hidden_channels: list[int],
                 output_channels: int,
                 use_batchnorm: bool,
                 kernel_size: list,
                 stride: list[int],
                 activation_function: torch.nn.Module = torch.nn.ReLU()):
        super().__init__()
        num_layers = len(hidden_channels)
        if num_layers == 0:
            raise ValueError("hidden_channels must contain at least one entry")
        if len(kernel_size) < num_layers or len(stride) < num_layers:
            raise ValueError(
                f"kernel_size and stride need at least {num_layers} entries "
                f"(one per hidden layer), got {len(kernel_size)} and {len(stride)}")

        # Layer i maps in_channels[i] -> hidden_channels[i].
        in_channels = [input_channels, *hidden_channels[:-1]]
        conv_layers = [
            torch.nn.Conv2d(in_channels=in_channels[i],
                            out_channels=hidden_channels[i],
                            kernel_size=kernel_size[i],
                            # PyTorch rejects padding='same' for strided convs,
                            # so only stride-1 layers preserve the spatial size.
                            padding='same' if stride[i] == 1 else 'valid',
                            stride=stride[i])
            for i in range(num_layers)
        ]

        # Track the spatial size through every conv so the Linear layer gets
        # the real flattened size (strided 'valid' convs shrink the map).
        height, width = input_size
        for index, conv in enumerate(conv_layers):
            if any(step < 1 for step in conv.stride):
                raise ValueError(
                    f"stride of layer {index} must be positive, got {stride[index]!r}")
            height, width = self._conv_output_size(conv, (height, width))
            if height < 1 or width < 1:
                raise ValueError(
                    f"input_size {tuple(input_size)} is too small: layer {index} "
                    f"would produce a {height}x{width} feature map")

        # Keep this construction order (convs, then Linear, then batch norms)
        # and register output_layer before self.layers. This fixes which RNG
        # draws initialise which weights, and the order of parameters().
        self.output_layer = torch.nn.Linear(hidden_channels[-1] * height * width,
                                            output_channels)

        layers = []
        for conv, channels in zip(conv_layers, hidden_channels):
            layers.append(conv)
            if use_batchnorm:
                layers.append(torch.nn.BatchNorm2d(channels))
            # The same activation instance is reused after every stage.
            layers.append(activation_function)
        self.layers = torch.nn.Sequential(*layers)

    @staticmethod
    def _conv_output_size(conv: torch.nn.Conv2d,
                          size: tuple[int, int]) -> tuple[int, int]:
        """Return the (height, width) that ``conv`` produces for a (height, width) input."""
        if conv.padding == 'same':
            return size
        # padding='valid' means no padding; numeric padding is stored as a pair.
        padding = (0, 0) if conv.padding == 'valid' else conv.padding
        return tuple(
            (length + 2 * pad - dilation * (kernel - 1) - 1) // step + 1
            for length, pad, dilation, kernel, step
            in zip(size, padding, conv.dilation, conv.kernel_size, conv.stride))

    def forward(self, input_images: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to a (batch, output_channels) tensor.

        ``input_images`` is expected to be (batch, input_channels, *input_size).
        """
        features = self.layers(input_images)
        # torch.flatten (unlike .view) also handles non-contiguous feature
        # maps, e.g. ones produced from channels_last inputs.
        return self.output_layer(torch.flatten(features, start_dim=1))
# model = MyCNN()