revamped architecture
This commit is contained in:
parent
e9d86f3309
commit
992f84fa5a
122
architecture.py
122
architecture.py
|
|
@ -5,55 +5,97 @@ import torch.nn
|
|||
class MyCNN(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_channels: int,
|
||||
input_size: tuple[int, int],
|
||||
hidden_channels: list[int],
|
||||
output_channels: int,
|
||||
use_batchnorm: bool,
|
||||
kernel_size: list,
|
||||
stride: list[int],
|
||||
activation_function: torch.nn.Module = torch.nn.ReLU()):
|
||||
input_size: tuple[int, int]):
|
||||
super().__init__()
|
||||
|
||||
input_layer = torch.nn.Conv2d(in_channels=input_channels,
|
||||
out_channels=hidden_channels[0],
|
||||
kernel_size=kernel_size[0],
|
||||
padding='same' if (stride[0] == 1 or stride[0] == 0) else 'valid',
|
||||
stride=stride[0])
|
||||
hidden_layers = [torch.nn.Conv2d(hidden_channels[i - 1],
|
||||
hidden_channels[i],
|
||||
kernel_size[i],
|
||||
padding='same' if (stride[i] == 1 or stride[i] == 0) else 'valid',
|
||||
stride=stride[i])
|
||||
for i in range(1, len(hidden_channels))]
|
||||
self.output_layer = torch.nn.Linear(hidden_channels[-1] * input_size[0] * input_size[1], output_channels)
|
||||
# input_layer = torch.nn.Conv2d(in_channels=input_channels,
|
||||
# out_channels=hidden_channels[0],
|
||||
# kernel_size=kernel_size[0],
|
||||
# padding='same' if (stride[0] == 1 or stride[0] == 0) else 'valid',
|
||||
# stride=stride[0],
|
||||
# bias=not use_batchnorm)
|
||||
# hidden_layers = [torch.nn.Conv2d(hidden_channels[i - 1],
|
||||
# hidden_channels[i],
|
||||
# kernel_size[i],
|
||||
# padding='same' if (stride[i] == 1 or stride[i] == 0) else 'valid',
|
||||
# stride=stride[i],
|
||||
# bias=not use_batchnorm)
|
||||
# for i in range(1, len(hidden_channels))]
|
||||
# self.output_layer = torch.nn.Linear(hidden_channels[-1] * input_size[0] * input_size[1], output_channels)
|
||||
|
||||
def activation_function_repeater():
|
||||
while True:
|
||||
yield activation_function
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||
torch.nn.BatchNorm2d(64),
|
||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
layers_except_output = [input_layer,
|
||||
*hidden_layers]
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=False),
|
||||
torch.nn.BatchNorm2d(128),
|
||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
if use_batchnorm:
|
||||
batch_norm_layers = [torch.nn.BatchNorm2d(hidden_channels[i]) for i in range(0, len(hidden_channels))]
|
||||
# Adding an empty layer to not mess up list concatenation
|
||||
batch_norm_layers = [*batch_norm_layers, torch.nn.BatchNorm2d(0)]
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding='same', bias=False),
|
||||
torch.nn.BatchNorm2d(256),
|
||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
layers_except_output = [layer
|
||||
for layer_tuple
|
||||
in zip(layers_except_output, batch_norm_layers, activation_function_repeater())
|
||||
for layer
|
||||
in layer_tuple]
|
||||
else:
|
||||
layers_except_output = [layer
|
||||
for layer_tuple in zip(layers_except_output, activation_function_repeater())
|
||||
for layer in layer_tuple]
|
||||
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding='same', bias=False),
|
||||
torch.nn.BatchNorm2d(512),
|
||||
torch.nn.MaxPool2d(kernel_size=3, stride=2),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
self.layers = torch.nn.Sequential(*layers_except_output)
|
||||
torch.nn.Flatten(),
|
||||
torch.nn.Linear(in_features=12800, out_features=4096, bias=False),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(in_features=4096, out_features=4096, bias=False),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(in_features=4096, out_features=20, bias=False),
|
||||
torch.nn.Softmax(dim=1)
|
||||
)
|
||||
|
||||
# self.layers = torch.nn.Sequential(
|
||||
# torch.nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=1, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(64),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=7, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(64),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(64),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||
# torch.nn.Dropout2d(0.1),
|
||||
# torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(32),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(32),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(32),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||
# torch.nn.Dropout2d(0.1),
|
||||
# torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(16),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1, padding='same', bias=False),
|
||||
# torch.nn.BatchNorm2d(16),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.MaxPool2d(kernel_size=3, padding=1),
|
||||
# torch.nn.Flatten(),
|
||||
# torch.nn.Dropout(0.25),
|
||||
# torch.nn.Linear(in_features=256, out_features=512),
|
||||
# torch.nn.ReLU(),
|
||||
# torch.nn.Linear(in_features=512, out_features=20, bias=False),
|
||||
# # torch.nn.Softmax(dim=1),
|
||||
# )
|
||||
|
||||
def forward(self, input_images: torch.Tensor) -> torch.Tensor:
|
||||
output = self.layers(input_images)
|
||||
return self.layers(input_images)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.layers)
|
||||
|
||||
return self.output_layer(output.view(output.shape[0], -1))
|
||||
|
||||
# model = MyCNN()
|
||||
|
|
|
|||
Loading…
Reference in New Issue