""" This module contains the definition of the CNN model. """ import torch.nn as nn import torch from typing import Optional, List, Tuple, Dict, Any class CNN(nn.Module): def __init__( self, num_classes: int, out_channels: List[int] = [32, 64, 128, 256], conv_layers_dropout: Optional[List[float]] = None, fc_layers: List[int] = [256, 128], dropout_rate: float = 0.0, input_size: Tuple[int, int] = (80, 80), name: Optional[str] = None ): """ Args: num_classes: The number of classes in the classification task. out_channels: A list of the number of filters in each convolutional layer. conv_layers_dropout: An optional list of dropout rates to apply to the convolutional layers. fc_layers: A list of the number of neurons in each fully connected layer. dropout_rate: The dropout rate to apply to the fully connected layers. input_size: The size of the input image. Initializes the CNN model with the specified parameters. """ super(CNN, self).__init__() if name is None: name = f"CNN {out_channels} + FC {fc_layers}" self.name = name self.num_classes = num_classes self.out_channels = out_channels self.conv_layers_dropout = conv_layers_dropout self.fc_layers = fc_layers self.dropout_rate = dropout_rate self.input_size = input_size if conv_layers_dropout is None: conv_layers_dropout = [0.0] * len(out_channels) assert len(out_channels) == len(conv_layers_dropout), "Mismatch in layer count" self.conv_blocks = nn.ModuleList() in_channels = 3 for out_ch, drop in zip(out_channels, conv_layers_dropout): block = [ nn.Conv2d(in_channels, out_ch, kernel_size=3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ] if drop > 0: block.append(nn.Dropout2d(p=drop)) block.append(nn.MaxPool2d(2)) self.conv_blocks.append(nn.Sequential(*block)) in_channels = out_ch # Compute flatten_dim dynamically with torch.no_grad(): dummy_input = torch.zeros(1, 3, *input_size) for block in self.conv_blocks: dummy_input = block(dummy_input) self.flatten_dim = dummy_input.view(1, -1).shape[1] # Build FC layers dynamically fc_dims = [self.flatten_dim] + fc_layers fc_blocks = [] for in_dim, out_dim in zip(fc_dims[:-1], fc_dims[1:]): fc_blocks.append(nn.Linear(in_dim, out_dim)) fc_blocks.append(nn.ReLU(inplace=True)) fc_blocks.append(nn.Dropout(p=dropout_rate)) fc_blocks.append(nn.Linear(fc_layers[-1], num_classes)) self.fcs = nn.Sequential(*fc_blocks) def forward(self, x): for block in self.conv_blocks: x = block(x) x = x.view(x.size(0), -1) return self.fcs(x) # def classify(self, class_to_idx: Dict[str, int], x: torch.Tensor) -> int: # idx_to_class = {v:k for k, v in class_to_idx.items()} # return idx_to_class[self.forward(x).argmax().item()] def predict(self, idx_to_class: Dict[int, str], x: torch.Tensor) -> Tuple[int, float]: return idx_to_class[self.forward(x).argmax().item()], self.forward(x).max().item() @property def model_config(self) -> Dict[str, Any]: return { 'num_classes': self.num_classes, 'out_channels': self.out_channels, 'conv_layers_dropout': self.conv_layers_dropout, 'fc_layers': self.fc_layers, 'dropout_rate': self.dropout_rate, 'input_size': self.input_size, 'name': self.name } def save(self, path: str): torch.save({ 'model_state_dict': self.state_dict(), 'model_config': self.model_config }, path) @staticmethod def load(path: str) -> "CNN": checkpoint = torch.load(path) model = CNN(**checkpoint['model_config']) model.load_state_dict(checkpoint['model_state_dict']) return model