import torch
from torch import nn
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d

class UNetElement(nn.Module):
    """One U-Net node: fuse up to three feature maps, then conv -> BN -> ReLU.

    A node may receive:

    * ``x``      -- an input already at this node's resolution,
    * ``x_up``   -- an input at twice this node's resolution, halved with a
      2x2 average pool,
    * ``x_down`` -- an input at half this node's resolution, doubled with a
      stride-2 transposed convolution.

    The available inputs are concatenated along dim 1 (channels). The result
    goes through a 3x3 convolution with padding 1 (spatial size preserved),
    batch norm and an in-place ReLU.

    Args:
        f_in: Channel count of ``x``, or ``None`` if the node has no such input.
        f_in_up: Channel count of ``x_up``, or ``None``.
        f_in_down: Channel count of ``x_down``, or ``None``.
        f_out: Number of output channels.

    Raises:
        ValueError: If ``f_in``, ``f_in_up`` and ``f_in_down`` are all ``None``.
    """

    def __init__(self, f_in, f_in_up, f_in_down, f_out):
        super().__init__()
        channel_counts = [c for c in (f_in, f_in_up, f_in_down) if c is not None]
        if not channel_counts:
            # Without any input the conv below would have 0 input channels.
            raise ValueError(
                "UNetElement needs at least one of f_in, f_in_up, f_in_down"
            )

        if f_in_up is not None:
            self.downscale = nn.AvgPool2d(kernel_size=2)
        if f_in_down is not None:
            # stride=2 gives an output of exactly 2x the input size
            # ((H - 1) * 2 + 2 = 2H). With the default stride=1 a 2x2 kernel
            # only grows H to H + 1, and the result cannot be concatenated
            # with a same-resolution skip input. A parameter-free alternative
            # is nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True).
            self.upscale = nn.ConvTranspose2d(
                f_in_down, f_in_down, kernel_size=2, stride=2
            )

        self.block = nn.Sequential(
            nn.Conv2d(sum(channel_counts), f_out, kernel_size=3, padding=(1, 1)),
            nn.BatchNorm2d(f_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, x_up, x_down):
        """Rescale and concatenate the given inputs, then apply the conv block.

        Inputs are presumably NCHW (channels on dim 1, which is the dim used
        by ``torch.cat`` below) -- confirm against callers. After rescaling,
        all inputs must share the same spatial size. Odd sizes are
        floor-halved by the pool and then doubled, so they may not line up.

        Args:
            x: Same-resolution input with ``f_in`` channels, or ``None``.
            x_up: Input to be average-pooled by 2, or ``None``.
            x_down: Input to be upsampled by 2, or ``None``.

        Returns:
            The conv-BN-ReLU output with ``f_out`` channels.

        Raises:
            ValueError: If ``x``, ``x_up`` and ``x_down`` are all ``None``.
        """
        if x is None and x_up is None and x_down is None:
            raise ValueError(
                "UNetElement.forward needs at least one of x, x_up, x_down"
            )
        parts = []
        if x is not None:
            parts.append(x)
        if x_up is not None:
            parts.append(self.downscale(x_up))
        if x_down is not None:
            parts.append(self.upscale(x_down))
        return self.block(torch.cat(parts, dim=1))


class UNet(nn.Module):
    """Small two-level U-Net assembled from ``UNetElement`` nodes.

    Encoder: full resolution (64 ch) -> pooled once (128 ch) -> pooled twice
    (256 ch, the bottom). Decoder: the bottom is upsampled and fused with the
    128-channel skip, and that result is upsampled and fused with the
    64-channel skip.

    The first node is built with 32 input channels and the last node emits
    32 channels. Because the input is pooled twice, H and W presumably need
    to be divisible by 4 for the skip concatenations to line up (assumption;
    verify).
    """

    def __init__(self):
        super().__init__()
        # Channel widths: network input/output, then level 0, 1 and bottom.
        io_ch, lvl0_ch, lvl1_ch, bottom_ch = 32, 64, 128, 256

        # Encoder path (registration order kept for state_dict/parameter order).
        self.m_in_0 = UNetElement(io_ch, None, None, lvl0_ch)
        self.m_in_1 = UNetElement(None, lvl0_ch, None, lvl1_ch)
        self.m_bottom = UNetElement(None, lvl1_ch, None, bottom_ch)
        # Decoder path: each node takes a skip input plus the coarser map.
        self.m_out_1 = UNetElement(lvl1_ch, None, bottom_ch, lvl1_ch)
        self.m_out_0 = UNetElement(lvl0_ch, None, lvl1_ch, io_ch)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        skip0 = self.m_in_0(x, None, None)
        skip1 = self.m_in_1(None, skip0, None)
        deepest = self.m_bottom(None, skip1, None)
        return self.m_out_0(skip0, None, self.m_out_1(skip1, None, deepest))