from torch import nn


def sequential_conv_block(in_channels: int, out_channels: int, n_layers: int = 2) -> nn.Sequential:
    """Build a VGG-style stage: ``n_layers`` x (Conv3x3 -> BatchNorm -> SiLU), then MaxPool2d(2) and ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; every
    following convolution keeps ``out_channels``. A 3x3 kernel with padding 1
    (stride 1) preserves height/width, so only the ``MaxPool2d(2)`` changes the
    spatial size: it halves it, rounding down.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by every convolution in the stage.
        n_layers: Number of Conv-BN-SiLU triplets; must be at least 1.

    Returns:
        An ``nn.Sequential`` whose children are, in order, the conv/BN/SiLU
        triplets followed by ``MaxPool2d(2)`` and ``ReLU``.

    Raises:
        ValueError: If ``n_layers`` < 1. With no convolution the stage would
            silently output ``in_channels`` channels instead of ``out_channels``.
    """
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")

    layers = []
    ch_in = in_channels
    for _ in range(n_layers):
        layers += [
            nn.Conv2d(ch_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU(inplace=True),
        ]
        ch_in = out_channels  # only the first conv changes the channel count
    layers.append(nn.MaxPool2d(2))
    # NOTE(review): SiLU can emit small negative values, so this ReLU does clip
    # them. It is kept to preserve existing behavior; confirm it is intended.
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def vgg_backbone(in_channels:int=3) -> nn.Sequential:
    """Stack four two-layer conv stages with widths 32 -> 64 -> 128 -> 256.

    Each stage ends in ``MaxPool2d(2)``, so height and width are halved
    (rounded down) four times. For example, a 48x48 input would come out
    as 3x3; the exact intended input size is not visible here.

    Args:
        in_channels: Channel count of the input images (default 3).

    Returns:
        An ``nn.Sequential`` of four stages, indexed 0..3.
    """
    widths = (in_channels, 32, 64, 128, 256)
    # Pair each width with the next to get (input, output) channels per stage.
    stages = [
        sequential_conv_block(c_in, c_out, n_layers=2)
        for c_in, c_out in zip(widths[:-1], widths[1:])
    ]
    return nn.Sequential(*stages)