import logging
from typing import Dict, Tuple, Union

import kornia as K
import numpy as np
import torch
import torch.nn.functional as F
from image_geometry.line_groups_ransac import fit_vanishing_point
from image_geometry.line_segments import LineSegments
from librectify import find_lines
from mit_semseg.models import ModelBuilder
from scipy.ndimage import gaussian_filter1d
from skimage.util import img_as_float32
from torch import nn
from torch.tensor import Tensor
from torchvision.models.mobilenetv3 import mobilenet_v3_small
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

from ..sord import theta_from_soft_label

# Hand-picked class ids from the ADE20K dataset. The literals are written as
# listed and shifted by one to give zero-based channel indices into the
# 150-channel segmentation output (presumably the ids are 1-based in the
# dataset's class list -- confirm).
_top_classes = torch.tensor([3, 6], dtype=torch.long).sub(1)
_bottom_classes = torch.tensor(
    [4, 7, 10, 12, 14, 22, 27, 30, 53, 61, 92, 95, 129],
    dtype=torch.long,
).sub(1)

class HorizonModule(nn.Module):
    """Predict zenith-angle logits and top/bottom semantic maps for an image batch.

    A frozen ADE20K segmentation network (MobileNetV2-dilated encoder with a
    C1 deep-supervision decoder) produces 150 per-class maps. These are
    compressed by a 1x1 convolution and fused with MobileNetV3-small image
    features; the fused features are max-pooled globally and fed to a small
    MLP that outputs ``theta_bins`` values per image.

    Args:
        theta_bins: Number of output bins of the theta predictor.
    """

    def __init__(self, theta_bins=128):
        super().__init__()

        # We use smallest model for segmentation
        self.encoder = ModelBuilder.build_encoder(
            arch='mobilenetv2dilated',
            fc_dim=320,
            weights='ckpt/ade20k-mobilenetv2dilated-c1_deepsup/encoder_epoch_20.pth').eval()

        self.decoder = ModelBuilder.build_decoder(
            arch='c1_deepsup',
            fc_dim=320,
            num_class=150,
            weights='ckpt/ade20k-mobilenetv2dilated-c1_deepsup/decoder_epoch_20.pth',
            use_softmax=True).eval()
        n_classes = 150

        # Compress 150 class maps
        self.semantic_features = nn.Conv2d(n_classes, 32, 1)

        backbone = mobilenet_v3_small(pretrained=True)
        n_ftrs = backbone.features[-1][0].out_channels  # hack to get output size
        self.image_features = nn.Sequential(
            backbone.features,
            nn.ReLU(inplace=True),
            nn.Conv2d(n_ftrs, 32, 1),
            nn.ReLU(inplace=True)
        )

        self.merge = nn.Sequential(
            nn.Conv2d(32+32, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 1),
            nn.ReLU(),
        )

        self.theta_predictor = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, theta_bins),
        )

    def train(self, mode: bool = True):
        """Set the training mode, keeping the segmentation network in eval mode.

        The encoder/decoder are only ever run under ``torch.no_grad()``; without
        this override a call to ``model.train()`` would switch them back to
        training mode after the ``.eval()`` done at construction.
        """
        super().train(mode)
        self.encoder.eval()
        self.decoder.eval()
        return self

    def forward(self, input_tensor: torch.Tensor) -> Dict[str,torch.Tensor]:
        """Run the network on a batch of images.

        Args:
            input_tensor: Image batch; its trailing two dimensions are used as
                the output size of the segmentation maps.

        Returns:
            Dict with ``theta`` (B,theta_bins predictor outputs) and
            ``sem_maps`` (B,2,H,W; channel 0 is the summed "top" classes,
            channel 1 the summed "bottom" classes).
        """
        # The segmentation network is frozen: no gradients flow through it.
        with torch.no_grad():
            output_seg_size = input_tensor.shape[2:]
            class_maps = self.decoder(
                self.encoder(input_tensor, return_feature_maps=True),
                segSize=output_seg_size)
        # class_maps  B,150,H,W

        # These could be obtained by 1x1 conv with fixed weights
        map_top = class_maps[:, _top_classes, ...].sum(1, keepdim=True)  # (B,1,H,W)
        map_bottom = class_maps[:, _bottom_classes, ...].sum(1, keepdim=True)
        sem_maps = torch.cat([map_top, map_bottom], dim=1)
        # sem_maps: B,2,H,W

        # Calculate features from image
        img_ftrs = self.image_features(input_tensor)  # some backbone net
        # B,32,H',W', H' << H

        # Pool the semantic features down to the backbone's spatial size so the
        # two feature maps can be concatenated channel-wise.
        h, w = img_ftrs.shape[2:]
        sem_ftrs = F.adaptive_max_pool2d(self.semantic_features(class_maps), (h, w))
        # B,32,H',W'

        # Merge features
        ftrs = torch.cat([img_ftrs, sem_ftrs], dim=1)
        ftrs = self.merge(ftrs)
        # B,128,H',W'

        # flatten(1) rather than squeeze(): squeeze would also remove the batch
        # axis when B == 1.
        global_ftrs: torch.Tensor = F.adaptive_max_pool2d(ftrs, 1).flatten(1)
        # B,128
        theta = self.theta_predictor(global_ftrs)
        # B,theta_bins

        return dict(
            theta=theta,
            sem_maps=sem_maps
        )


def line_groups(lines:LineSegments):
    """Yield one subset of ``lines`` for each distinct value of its "group" field.

    Subsets are produced in the order returned by ``np.unique`` (sorted
    group values); each subset is selected with a boolean mask.
    """
    group_of_line = lines.get_field("group")
    yield from (
        lines[group_of_line == group_id]
        for group_id in np.unique(group_of_line)
    )


def find_zenith(lines:LineSegments, z_prior:np.ndarray):
    """Estimate the zenith direction from grouped line segments and a prior.

    For every line group a vanishing point is fitted and its normalized 2-D
    direction is compared with the prior.

    NOTE(review): the refinement step is not implemented -- the branch taken
    for groups aligned with the prior is empty, so the function currently
    always returns the prior unchanged.

    Args:
        lines: Line segments carrying a "group" field.
        z_prior: Prior zenith direction; a 1-D array of length 2 becomes a
            (2,1) column.

    Returns:
        The zenith direction as a column array (currently equal to the
        transposed, at-least-2-D ``z_prior``).
    """
    z_prior = np.atleast_2d(z_prior).T
    z = z_prior
    for lg in line_groups(lines):
        # Diagnostic output goes to the logger instead of stdout.
        logging.debug(
            "Line group with %d segments, total length %s",
            len(lg), lg.length().sum())
        # calc vp direction
        vp = fit_vanishing_point(lg.homogeneous())

        d = np.atleast_2d(vp[:2] / np.linalg.norm(vp[:2]))
        # |cos| of the angle between the group direction and the prior.
        if np.abs(d @ z_prior) > 0.9:
            # TODO: use this group to refine ``z``; intentionally a no-op for now.
            pass
    return z



def find_rho(z_dir, sem_maps: torch.Tensor):
    """Locate the visual horizon row in zenith-aligned semantic maps.

    The maps are rotated according to ``z_dir``, each map is reduced to a
    per-row maximum and smoothed, and the row where "top" and "bottom"
    evidence balance is taken as the horizon. Only the first batch element is
    evaluated.

    Args:
        z_dir: Zenith direction as two components ``(x, y)``.
        sem_maps: (B,2,H,W) maps; channel 0 is expected to hold the "top"
            classes and channel 1 the "bottom" classes, which is the layout
            produced by ``HorizonModule.forward``.

    Returns:
        Tuple ``(rho_center, rho_top, rho_bottom)`` of row indices in the
        rotated map: the row with the smallest absolute log-ratio, and the
        first and last rows of the interval where the absolute log-ratio is
        below 1 (the whole range when no row qualifies).
    """
    # Visual horizon
    x, y = z_dir
    angle = (np.arctan2(y, x) / np.pi * 180) + 90

    # K.rotate takes one angle per batch element; match dtype/device of the
    # maps so the rotation matrix is built consistently with the input.
    rotation = torch.as_tensor(-angle, dtype=sem_maps.dtype, device=sem_maps.device)
    rotation = rotation.reshape(-1).expand(sem_maps.shape[0])
    sem_maps_rotated = K.rotate(sem_maps, rotation).detach().cpu().numpy()

    # Per-row maxima, smoothed along the row axis. The original code read the
    # channels in swapped order (bottom from channel 0, top from channel 1).
    t_map_proj = gaussian_filter1d(np.max(sem_maps_rotated[0, 0], 1), 4)
    b_map_proj = gaussian_filter1d(np.max(sem_maps_rotated[0, 1], 1), 4)

    # Bottom evidence may only grow going down the rows (running max);
    # top evidence may only shrink (running min, written as 1 - max(1 - x)).
    b_prob = np.maximum.accumulate(b_map_proj)
    t_prob = 1 - np.maximum.accumulate(1 - t_map_proj)
    # Log-ratio of top vs. bottom; the 0.01 offset avoids log(0) and division by 0.
    confidence = np.log((t_prob + 0.01) / (b_prob + 0.01))

    t = 1
    valid_interval = (confidence < t) & (confidence > -t)
    if not np.any(valid_interval):
        # No ambiguous row at all: fall back to the full range.
        # (np.bool was removed from NumPy; the builtin bool is equivalent.)
        valid_interval = np.ones_like(confidence, dtype=bool)
    valid_interval = np.nonzero(valid_interval)[0]
    rho_bottom = valid_interval.max()
    rho_top = valid_interval.min()
    rho_center = np.argmin(np.abs(confidence))

    return rho_center, rho_top, rho_bottom



def get_horizon_line(image, model:nn.Module):
    """Estimate the horizon line of a single image.

    The image is converted to float32, resized and normalized, and passed
    through ``model``; line segments are extracted from the full-size image.
    The predicted theta distribution gives a zenith prior, which is passed to
    ``find_zenith``; ``find_rho`` then locates the horizon rows in the
    semantic maps.

    Args:
        image: Input image array (height and width are its first two axes).
        model: Network returning a dict with ``"theta"`` and ``"sem_maps"``.

    Returns:
        Dict with keys ``z`` (zenith direction), ``endpts`` (horizon endpoints
        for the center/top/bottom rho), ``lines`` (normalized line segments),
        ``top`` and ``bottom``.
    """
    image = img_as_float32(image)
    h, w = image.shape[:2]

    transform = Compose([
        ToTensor(),
        Resize(256),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The transform yields a single C,H,W image; add the batch axis the
    # network expects.
    pred_dict = model(transform(image).unsqueeze(0))

    scale = max(h, w)
    pp = (w//2, h//2)
    lines = find_lines(image, max_size=1200, smooth=1).normalized(scale=scale, shift=pp)
    logging.info(f"Extracted {len(lines)} line segments")

    # reshape guards against a model that drops the singleton batch axis;
    # softmax needs an explicit dim (here: over the theta bins).
    theta_logits = pred_dict["theta"].reshape(1, -1)
    soft_theta = torch.softmax(theta_logits, dim=-1).detach().cpu().numpy()
    t = soft_theta.max() / 10
    theta = theta_from_soft_label(soft_theta, offset=t)[0] - np.pi/2
    z_prior = np.array([np.cos(theta), np.sin(theta)])

    z_dir = find_zenith(lines, z_prior)
    sem_maps = pred_dict["sem_maps"]
    rho_center, rho_top, rho_bottom = find_rho(z_dir, sem_maps)

    # NOTE(review): get_horizon_endpoints is neither defined nor imported in
    # this file -- it must be provided before this function can run. The rho
    # values are row indices of the rotated semantic maps, not of the
    # full-size image; confirm the helper accounts for that.
    endpts = [
        get_horizon_endpoints(z_dir, r, (h, w))
        for r in (rho_center, rho_top, rho_bottom)
    ]

    # NOTE(review): ``top``/``bottom`` were undefined names in the original;
    # the two semantic maps of the (single) image are assumed -- confirm.
    top, bottom = sem_maps[0].detach().cpu().numpy()

    # construct output
    out_dict = dict(
        z=z_dir,
        endpts=endpts,
        lines=lines,
        top=top,
        bottom=bottom,
    )

    return out_dict




        
