import torch
from torch import nn
from torch.nn import functional as F

from .blocks import vgg_backbone


class ThetaRegressionEstimator(nn.Module):
    """Regress a single scalar ``theta`` per image.

    Pipeline: ``vgg_backbone`` -> ``Conv2d(256, 128, kernel_size=3)`` ->
    ReLU -> MLP (128 -> 128 -> SiLU -> 1).

    Args:
        in_channels: number of channels of the input image, forwarded to
            ``vgg_backbone``.
    """

    def __init__(self, in_channels=3):
        super(ThetaRegressionEstimator, self).__init__()
        self.backbone = vgg_backbone(in_channels)
        # NOTE(review): kernel_size=3 only collapses a 3x3 backbone map to 1x1;
        # ThetaSORDEstimator uses kernel_size=8 for the same backbone. Confirm
        # the backbone output size for the inputs this head is used with.
        self.aggregate = nn.Conv2d(256, 128, kernel_size=3)
        self.proj = nn.Sequential(
            nn.Linear(128, 128),
            nn.SiLU(inplace=True),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return ``{"theta": tensor of shape (B, 1)}`` for a batch ``x``.

        Assumes the backbone output is (B, 256, h, w) with h == w == 3 so the
        aggregate conv yields (B, 128, 1, 1) — TODO confirm against vgg_backbone.
        """
        aggregated = self.aggregate(self.backbone(x))
        # flatten(1) keeps the batch axis; a bare squeeze() would also drop it
        # when B == 1 and make theta (1,) instead of (1, 1).
        features = F.relu(aggregated.flatten(1))
        theta = self.proj(features)
        return dict(theta=theta)


class ThetaSORDEstimator(nn.Module):
    """Classify ``theta`` into ``theta_bins`` discrete bins.

    Pipeline: ``vgg_backbone`` -> ``Conv2d(256, 128, kernel_size=8)`` ->
    ReLU -> MLP (Dropout 0.2 -> 128 -> SiLU -> ``theta_bins``) -> log-softmax.

    Args:
        in_channels: number of input image channels, forwarded to
            ``vgg_backbone``.
        theta_bins: number of output bins (size of the log-probability vector).
    """

    def __init__(self, in_channels=3, theta_bins=100):
        super().__init__()
        self.backbone = vgg_backbone(in_channels)
        self.aggregate = nn.Conv2d(256, 128, kernel_size=8)
        head_layers = [
            nn.Dropout(0.2),
            nn.Linear(128, 128),
            nn.SiLU(inplace=True),
            nn.Linear(128, theta_bins),
        ]
        self.proj = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``{"theta": (B, theta_bins) log-probabilities over bins}``."""
        batch_size = x.size(0)
        aggregated = self.aggregate(self.backbone(x))
        # Collapse everything after the batch axis into one feature vector.
        flat = F.relu(aggregated.view(batch_size, -1))
        logits = self.proj(flat)
        return {"theta": F.log_softmax(logits, dim=1)}


class ZenithVectorEstimator(nn.Module):
    """Predict a unit-length 2-D vector per image.

    Pipeline: ``vgg_backbone`` -> 1x1 conv (256 -> 32) -> ReLU ->
    1x1 conv (32 -> 2) -> spatial mean -> L2 normalisation along dim 1.

    Unlike the theta estimators, ``forward`` returns a bare tensor, not a dict.

    Args:
        in_channels: number of input image channels, forwarded to
            ``vgg_backbone``.
    """

    def __init__(self, in_channels: int = 3):
        super(ZenithVectorEstimator, self).__init__()
        # Backbone must emit 256 channels; original note says (B, 256, 8, 8)
        # — spatial size presumably depends on input size, TODO confirm.
        self.backbone = vgg_backbone(in_channels=in_channels)
        self.reduce = nn.Conv2d(256, 32, 1)
        self.zenith = nn.Conv2d(32, 2, 1)

    def forward(self, x):
        """Return a (B, 2) tensor whose rows have (approximately) unit L2 norm."""
        ftrs = F.relu(self.reduce(self.backbone(x)))
        z = self.zenith(ftrs)
        z = torch.mean(z, dim=[2, 3])
        # eps goes in the denominator so an all-zero prediction yields 0
        # rather than NaN (0 / 0); adding it to the quotient did neither.
        return z / (z.norm(dim=1, p=2, keepdim=True) + 1e-8)
