import logging
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.tensor import Tensor
from image_geometry.line_groups_ransac import find_line_groups_ransac_only, fit_vanishing_point
from image_geometry.line_segments import LineSegments, find_line_segments_ff
from skimage.transform import rescale, resize
from skimage.measure import block_reduce
from skimage.util import img_as_float32
from torch import nn
from torchvision.transforms import Compose, Normalize, ToTensor

from ..sord import rho_from_soft_label, theta_from_soft_label


class HorizonModule(nn.Module):
    """Predict a soft distribution over horizon-orientation bins (theta).

    The network fuses low-level RGB features with the class scores of a frozen
    semantic-segmentation network (``encoder`` + ``decoder``):

    * ``image_features`` - two padded 3x3 convs on the RGB input (32 channels,
      spatial size preserved);
    * ``sem_map`` - a 1x1 conv compressing the 150 class-score channels to 32;
    * ``features`` - three stages of (two unpadded 3x3 conv-bn-relu + 2x2
      max-pool) on the concatenated 64-channel tensor;
    * ``theta_predictor`` - an MLP on the globally max-pooled 128-d feature,
      producing ``theta_bins`` logits.

    Args:
        encoder: segmentation encoder, called as ``encoder(img, return_feature_maps=True)``.
        decoder: segmentation decoder, called as ``decoder(feats, segSize=(H, W))``;
            expected to return (B, 150, H, W) class scores.
        theta_bins: number of theta bins produced by the theta head.
        rho_bins: accepted but currently unused (there is no rho head).
    """

    def __init__(self, encoder, decoder, theta_bins=100, rho_bins=100):
        super().__init__()

        # Frozen segmentation network; only evaluated under no_grad in forward().
        self.encoder = encoder
        self.decoder = decoder

        def conv_bn(in_ch, out_ch, **conv_kwargs):
            # 3x3 convolution followed by batch norm (constructed in this order).
            return [nn.Conv2d(in_ch, out_ch, kernel_size=3, **conv_kwargs), nn.BatchNorm2d(out_ch)]

        # Low-level RGB features. No activation after the last BN: the ReLU is
        # applied in forward() after concatenation with the semantic features.
        self.image_features = nn.Sequential(
            *conv_bn(3, 32, padding=(1, 1)),
            nn.ReLU(inplace=True),
            *conv_bn(32, 32, padding=(1, 1)),
        )

        # Three stages of (conv-bn-relu) x 2 + max-pool: 64 -> 32 -> 64 -> 128 channels.
        stages = []
        for in_ch, out_ch in ((64, 32), (32, 64), (64, 128)):
            stages += conv_bn(in_ch, out_ch) + [nn.ReLU(inplace=True)]
            stages += conv_bn(out_ch, out_ch) + [nn.ReLU(inplace=True)]
            stages.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*stages)

        self.theta_predictor = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(inplace=True), nn.Dropout(0.1),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, theta_bins),  # logits
        )
        # NOTE: rho_bins is accepted for interface compatibility; no rho head is built.

        # 1x1 conv compressing the 150 class-score channels to 32.
        self.sem_map = nn.Conv2d(150, 32, 1)

        # Hand-picked semantic class groups given as 1-based ids (presumably
        # ADE20K's 150 classes - confirm), shifted to 0-based channel indices.
        # NOTE(review): plain tensors, not buffers, so module.to(device) leaves them on the CPU.
        self.top_classes = torch.tensor([3, 6], dtype=torch.long) - 1
        self.bottom_classes = torch.tensor(
            [4, 7, 10, 12, 14, 22, 27, 30, 53, 61, 92, 95, 129], dtype=torch.long) - 1

    def forward(self, img_dict: dict):
        """Run the model on ``img_dict["image"]``, a (B, 3, H, W) tensor.

        Returns:
            dict with
              ``theta`` - (B, theta_bins) log-probabilities (log_softmax of the logits);
              ``maps``  - (B, 2, H, W) decoder scores summed over the top classes
                          (channel 0) and over the bottom classes (channel 1).
        """
        image = img_dict["image"]

        # The segmentation network is frozen: run it without tracking gradients.
        with torch.no_grad():
            feature_maps = self.encoder(image, return_feature_maps=True)
            class_scores = self.decoder(feature_maps, segSize=image.shape[2:])

        top_map = class_scores[:, self.top_classes, ...].sum(1, keepdim=True)
        bottom_map = class_scores[:, self.bottom_classes, ...].sum(1, keepdim=True)
        sem_maps = torch.cat([top_map, bottom_map], dim=1)

        # (B, 32 + 32, H, W): RGB features next to the compressed class scores.
        fused = F.relu(torch.cat([self.image_features(image), self.sem_map(class_scores)], dim=1))

        # Convolutional trunk, then a global max pool over H and W -> (B, 128).
        pooled = self.features(fused).amax(dim=(2, 3))

        log_probs = torch.log_softmax(self.theta_predictor(pooled), dim=1)
        return dict(theta=log_probs, maps=sem_maps)


from torchvision.models import mobilenet_v3_small

class HorizonModuleV2(nn.Module):
    """Second variant of the horizon-orientation network.

    NOTE(review): the architecture and forward pass are currently identical to
    ``HorizonModule``; ``mobilenet_v3_small`` is imported next to this class but
    is not used yet.

    Args:
        encoder: segmentation encoder, called as ``encoder(img, return_feature_maps=True)``.
        decoder: segmentation decoder, called as ``decoder(feats, segSize=(H, W))``;
            expected to return (B, 150, H, W) class scores.
        theta_bins: number of theta bins produced by the theta head.
        rho_bins: accepted but currently unused (there is no rho head).
    """

    def __init__(self, encoder, decoder, theta_bins=100, rho_bins=100):
        # Zero-argument super(): HorizonModuleV2 does not derive from
        # HorizonModule, so super(HorizonModule, self) raises TypeError.
        super().__init__()

        # Frozen segmentation network; only evaluated under no_grad in forward().
        self.encoder = encoder
        self.decoder = decoder

        # Low-level RGB features (spatial size preserved). No activation after
        # the last BN: ReLU is applied after concatenation in forward().
        self.image_features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=(1, 1)),
            nn.BatchNorm2d(32),
        )

        # Three stages of (conv-bn-relu) x 2 + max-pool: 64 -> 32 -> 64 -> 128 channels.
        self.features = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        self.theta_predictor = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, theta_bins),   # Logits
        )
        # NOTE: rho_bins is accepted for interface compatibility; no rho head is built.

        self.sem_map = nn.Conv2d(150, 32, 1)  # Compression of the 150 class-score maps

        # Hand-picked semantic class groups given as 1-based ids (presumably
        # ADE20K - confirm), shifted to 0-based channel indices.
        self.top_classes = torch.tensor([3, 6], dtype=torch.long) - 1
        self.bottom_classes = torch.tensor([4, 7, 10, 12, 14, 22, 27, 30, 53, 61, 92, 95, 129], dtype=torch.long) - 1

    def forward(self, img_dict: dict):
        """Run the model on ``img_dict["image"]``, a (B, 3, H, W) tensor.

        Returns:
            dict with ``theta`` - (B, theta_bins) log-probabilities - and
            ``maps`` - (B, 2, H, W) decoder scores summed over the top classes
            (channel 0) and the bottom classes (channel 1).
        """
        rgb_image = img_dict["image"]

        # Semantic segmentation; the segmentation net is not trained here.
        output_seg_size = rgb_image.shape[2:]
        with torch.no_grad():
            class_scores = self.decoder(self.encoder(rgb_image, return_feature_maps=True), segSize=output_seg_size)

        # Semantic maps
        map_top = class_scores[:, self.top_classes, ...].sum(1, keepdim=True)  # (B,1,H,W)
        map_bottom = class_scores[:, self.bottom_classes, ...].sum(1, keepdim=True)
        sem_maps = torch.cat([map_top, map_bottom], dim=1)

        image_ftrs = self.image_features(rgb_image)

        # Compose input for global features
        X = torch.cat([image_ftrs, self.sem_map(class_scores)], dim=1)  # (B,32+32,H,W)
        X = F.relu(X)

        # Convolutional trunk, then global max pool over H and W -> (B,128)
        ftrs = self.features(X)
        ftrs = ftrs.amax(dim=(2, 3))

        theta_logits = self.theta_predictor(ftrs)
        theta = torch.log_softmax(theta_logits, dim=1)

        return dict(
            theta=theta,
            maps=sem_maps)



# Input preprocessing for the CNN: HxWxC array -> CxHxW tensor, then
# per-channel normalization with the usual ImageNet mean/std values.
transform = Compose(
    transforms=[
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def prescale_image(image, max_size):
    """Downscale *image* by an integer block factor so its longer side is about ``max_size``.

    The factor is ``b = ceil(max(h, w) / max_size)`` (at least 1). When ``b > 1``
    the image is mean-pooled over ``b x b`` blocks with ``block_reduce``; the
    channel axis of an (H, W, C) image is not reduced. The result is cast back to
    the input dtype.

    NOTE: when h or w is not a multiple of b, ``block_reduce`` pads the last
    partial blocks with its default ``cval`` (0), which darkens that border row/column.

    Args:
        image: (H, W) or (H, W, C) array.
        max_size: target upper bound for the longer side.

    Returns:
        ``(new_image, scale)`` where ``scale`` is ``1/b`` (``1`` when the image is
        returned unchanged), i.e. the factor mapping original to reduced coordinates.
    """
    h, w = image.shape[:2]
    logging.info(f"Prescale {w}x{h} to max_size={max_size}")
    # max(1, ...) guards the empty-image case, where ceil(0) == 0 would yield a
    # zero block size.
    b = max(1, int(np.ceil(max(h, w) / max_size)))
    if b == 1:
        logging.info("Keeping original size")
        return image, 1
    # Never pool across the channel axis.
    block_size = (b, b, 1) if len(image.shape) == 3 else (b, b)
    new_image = block_reduce(image, block_size=block_size, func=np.mean).astype(image.dtype)
    h, w = new_image.shape[:2]
    logging.info(f"block={b}, scale={1/b}, new shape {w}x{h}")
    return new_image, 1 / float(b)


def get_zenith_line(lines: "LineSegments", zenith_prior: np.ndarray = None) -> tuple:
    """Pick the vanishing direction of a line group that agrees with a zenith prior.

    A vanishing point is fitted (``fit_vanishing_point``) to each of the first
    (up to) four line groups, using the ``"group"`` field of *lines*. Each
    vanishing point's first two homogeneous coordinates are normalized into a
    direction, flipped to lie in the prior's half-plane, and compared with the
    prior. The first group whose direction has |cos| > 0.9 with the prior and
    whose first-two-coordinate norm exceeds 1 is selected.

    NOTE(review): the "> 1" test presumably means "vanishing point outside the
    normalized image"; that depends on how fit_vanishing_point scales its
    homogeneous output - confirm.

    Args:
        lines: line segments with a ``"group"`` field (presumably in
            image-normalized coordinates, as passed by the caller).
        zenith_prior: (2,) or (1, 2) prior direction; defaults to ``[0, 1]``.

    Returns:
        ``(direction, group_index)``: the selected (2,) direction and its group
        index, or ``(prior, None)`` when there are no lines, no groups, or no
        group passes the test. The direction is always 1-D.
    """
    if zenith_prior is None:
        zenith_prior = np.array([0, 1])
    # Keep the prior as a (1, 2) row for the matrix products below; every
    # return path hands back the 1-D (2,) vector.
    zenith_prior = np.atleast_2d(zenith_prior)

    if len(lines) == 0:
        return zenith_prior[0], None

    groups = lines.get_field("group")
    num_g = min(4, groups.max() + 1)
    if num_g <= 0:
        # No non-negative group labels: nothing to fit a vanishing point to.
        return zenith_prior[0], None

    vps = np.array([fit_vanishing_point(lines[groups == i].homogeneous())
                    for i in range(num_g)])  # (N, 3) vanishing points
    vp_dist = np.linalg.norm(vps[:, :2], axis=1, keepdims=True)  # (N, 1)
    dirs = vps[:, :2] / vp_dist  # unit directions towards the vanishing points

    # Orient every direction into the same half-plane as the prior.
    dirs *= np.where(dirs @ zenith_prior.T < 0, -1, 1)

    # |cos| of the angle between each direction and the prior, (N, 1).
    cos_sim = np.abs(dirs @ zenith_prior.T)
    valid = np.logical_and(cos_sim > 0.90, vp_dist > 1)

    if not np.any(valid):
        return zenith_prior[0], None
    k = np.nonzero(valid)[0][0]
    return dirs[k], k

from skimage.filters import gaussian

def extract_lines(image, mask):
    """Detect line segments in *image* with ``librectify``.

    Args:
        image: single-channel image passed straight to
            ``librectify.detect_line_segments``.
        mask: accepted for interface compatibility but currently ignored.

    Returns:
        ``LineSegments`` built from the segments and the extra keyword fields
        returned by ``detect_line_segments(image, max_size=1200, smooth=1)``.
    """
    # Local import: librectify is only required when this function is called.
    from librectify import detect_line_segments

    segments, fields = detect_line_segments(image, max_size=1200, smooth=1)
    return LineSegments(segments, **fields)
    

def group_pairs(groups, n=1000):
    """Sample index pairs of lines that share a group label.

    For every group id ``k`` in ``0 .. groups.max()``:

    * a group with exactly two members contributes that single pair;
    * otherwise ``size * size // 4`` pairs are drawn uniformly *with
      replacement* from the members and self-pairs ``(i, i)`` are discarded,
      so the number of pairs per group is random (possibly zero).

    If more than ``n`` pairs result, ``n`` distinct rows are kept (sampled
    without replacement).

    Args:
        groups: 1-D integer array of group labels, one per line.
        n: maximum number of pairs to return.

    Returns:
        (M, 2) integer index array with ``M <= n``, or an empty list when
        *groups* is empty or contains no non-negative label.
    """
    groups = np.asarray(groups)
    if groups.size == 0:
        return []
    n_groups = groups.max() + 1

    chunks = []
    for k in range(n_groups):
        members = np.nonzero(groups == k)[0]
        if members.size == 2:
            chunks.append(np.array(members, ndmin=2))
        else:
            choice = np.random.choice(members, (members.size * members.size // 4, 2))
            chunks.append(choice[choice[:, 0] != choice[:, 1], :])

    if not chunks:
        return []

    pairs = np.concatenate(chunks)

    # Cap the number of pairs (rows, not array elements) and subsample without
    # replacement so no pair is duplicated.
    if pairs.shape[0] > n:
        pairs = pairs[np.random.choice(pairs.shape[0], n, replace=False)]

    return pairs


def gaussian_function(x, mu, sig):
    """Unnormalized Gaussian ``exp(-(x - mu)^2 / (2 sig^2))``; equals 1 at ``x == mu``.

    Works element-wise on NumPy arrays (standard broadcasting).
    """
    squared_offset = np.power(x - mu, 2.)
    two_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_offset / two_variance)

from scipy.ndimage import gaussian_filter1d


def find_optimal_rho(
    zenith_direction,
    lines: LineSegments,
    pairs: np.ndarray,
    prior):
    """Vote for the horizon offset along the zenith direction using line-pair intersections.

    For each index pair ``(a, b)`` the two lines are intersected via the cross
    product of their homogeneous coordinates. Each intersection is weighted by
    ``max(len_a, len_b) * (1 - |dir_a . dir_b|)``, which down-weights nearly
    parallel pairs (a 1 - |cos| factor if ``lines.direction()`` returns unit
    vectors - confirm). Intersections with ``|w| <= 1e-6`` (at or near infinity)
    are dropped. The remaining points are projected onto ``zenith_direction``,
    the weights are multiplied by a Gaussian prior on that projection, and a
    weighted histogram over [-4, 4] (511 bins) is smoothed with a Gaussian of
    sigma 4 bins.

    Args:
        zenith_direction: (2,) direction onto which intersections are projected.
        lines: line segments providing ``homogeneous()``, ``length()`` and ``direction()``.
        pairs: (M, 2) integer indices into *lines*.
        prior: ``(mu, sigma)`` of the Gaussian prior on the projected position.

    Returns:
        ``(rho, positions, weights, hist)`` - the centre of the highest smoothed
        bin, the projected positions, their prior-weighted weights, and the
        smoothed histogram.
    """
    homog = lines.homogeneous()
    lengths = lines.length().flatten()
    directions = lines.direction()

    first, second = pairs.T

    # Candidate horizon points: homogeneous intersection of each line pair.
    points = np.cross(homog[first], homog[second])
    parallelism = np.abs((directions[first] * directions[second]).sum(1))
    weights = np.maximum(lengths[first], lengths[second]) * (1 - parallelism)

    # Keep only finite points and dehomogenize them.
    finite = np.abs(points[:, 2]) > 1e-6
    points = points[finite, :2] / points[finite, 2].reshape(-1, 1)
    weights = weights[finite]

    # Signed position of each point along the zenith line.
    positions = (points @ np.atleast_2d(zenith_direction).T).flatten()

    mu, sigma = prior
    weighted = weights * gaussian_function(positions, mu, sigma)

    hist, edges = np.histogram(positions, np.linspace(-4, 4, 512), weights=weighted)
    hist = gaussian_filter1d(hist, 4)
    peak = np.argmax(hist)
    rho = 0.5 * (edges[peak] + edges[peak + 1])

    return rho, positions, weighted, hist



def get_horizon_endpoints(z_dir, rho, shape):
    """Return the two endpoints of a horizon segment for an image of ``shape``.

    The segment lies on the line perpendicular to ``z_dir`` whose foot point is
    ``(w/2, h/2) + rho * (-z_dir)``. It extends ``0.4 * w`` to each side of the
    foot point (total length ``0.8 * w`` if ``z_dir`` is unit length).

    Args:
        z_dir: (2,) NumPy array, the zenith direction.
        rho: signed offset of the horizon from the image centre, measured along ``-z_dir``.
        shape: ``(h, w)`` of the image.

    Returns:
        ``((x1, y1), (x2, y2))``.
    """
    cos_t, sin_t = -z_dir
    normal = np.array([cos_t, sin_t])
    tangent = np.array([-sin_t, cos_t])

    height, width = shape
    center = [width / 2, height / 2]
    anchor = center + rho * normal
    half_len = 0.4 * width

    start = anchor + half_len * tangent
    end = anchor - half_len * tangent
    return (start[0], start[1]), (end[0], end[1])

from skimage.transform import rotate

def get_prior_from_cnn(image, model):
    """Run the horizon CNN on a downscaled copy of *image* and return its priors.

    The image is prescaled to a longer side of about 800 px, preprocessed with
    ``transform`` and passed to *model* as ``{"image": tensor[None]}``. The
    model is expected to return ``"theta"`` log-probabilities and a two-channel
    ``"maps"`` tensor (top / bottom semantic maps).

    Args:
        image: (H, W, C) image array.
        model: horizon network (e.g. ``HorizonModule``).

    Returns:
        ``(z_prior, scale, top, bottom, soft_theta)``:
        ``z_prior`` - (2,) ``[cos(theta), sin(theta)]`` with
        ``theta = theta_from_soft_label(soft_theta)[0] - pi/2``;
        ``scale`` - prescale factor mapping original to CNN-input coordinates;
        ``top`` / ``bottom`` - the two maps of the first batch item at CNN-input resolution;
        ``soft_theta`` - ``exp`` of the theta output (probabilities; presumably
        (1, theta_bins)).
    """
    image_scaled, scale = prescale_image(image, 800)
    logging.info("Extracting semantic information from image")
    image_tensor = transform(image_scaled)[None]  # add batch dimension

    # Feed the input on the model's device (stays on CPU for parameter-less models).
    if isinstance(model, torch.nn.Module):
        param = next(model.parameters(), None)
        if param is not None:
            image_tensor = image_tensor.to(param.device)

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        cnn_dict = model(dict(image=image_tensor))

    # Theta prior from the soft label.
    soft_theta = np.exp(cnn_dict["theta"].cpu().detach().numpy())
    theta = theta_from_soft_label(soft_theta)[0] - np.pi / 2
    z_prior = np.array([np.cos(theta), np.sin(theta)])
    logging.info(f"Prior: theta={theta} (z={z_prior})")

    maps = cnn_dict["maps"].detach().cpu().numpy()
    top, bottom = np.split(maps[0], 2, axis=0)

    return z_prior, scale, top[0], bottom[0], soft_theta


def get_horizon_line(image, model):
    """Estimate the horizon of *image* from CNN priors, line segments and semantic maps.

    Steps:

    1. ``get_prior_from_cnn`` yields a zenith-direction prior, the CNN input
       scale relative to *image*, and the "top"/"bottom" semantic maps.
    2. Line segments are detected on the channel-mean image, and
       ``get_zenith_line`` selects a zenith direction from their vanishing
       points (it returns the CNN prior when none qualifies).
    3. Both maps are rotated by an angle derived from the zenith direction and
       max-projected per row. Row ``i`` gets ``log((t + 0.01) / (b + 0.01))``,
       where ``b`` is the strongest smoothed bottom-map value in rows ``0..i``
       and ``t`` the strongest smoothed top-map value in rows ``i..end``. Rows
       with ``|value| < 1`` form the uncertainty band (all rows if none qualify).

    Args:
        image: input image, converted with ``img_as_float32``. A trailing channel
            axis is assumed (``image.mean(-1)`` is used for line detection).
        model: horizon network accepted by ``get_prior_from_cnn``.

    Returns:
        dict with ``z`` (zenith direction), ``endpts`` (three endpoint pairs from
        ``get_horizon_endpoints``: the row with the smallest |value|, then the top
        and the bottom of the band), ``lines`` (detected segments), and the
        ``top`` / ``bottom`` maps.
    """
    image = img_as_float32(image)
    h, w = image.shape[:2]

    z_prior, cnn_scale, top, bottom, _ = get_prior_from_cnn(image, model)

    # Line segments on the grey-level image (no mask is applied).
    logging.info("Extracting lines")
    lines = extract_lines(image.mean(-1), None)
    logging.info(f"Extracted {len(lines)} line segments")

    # Zenith direction from the lines, which are passed through
    # lines.normalized(scale=max(h, w), shift=(w//2, h//2)) - presumably
    # centred, scale-normalized coordinates.
    scale = max(h, w)
    pp = (w // 2, h // 2)
    zenith_direction, _ = get_zenith_line(lines.normalized(scale=scale, shift=pp), z_prior)
    logging.info(f"Detected zenit direction: {zenith_direction}")

    # NOTE: refining rho from line-pair intersections (group_pairs +
    # find_optimal_rho) is currently disabled; rho comes only from the
    # semantic maps below.

    # Visual horizon: rotate the maps by an angle derived from the zenith direction.
    x, y = zenith_direction
    angle = (np.arctan2(y, x) / np.pi * 180) + 90
    tmap_rot = rotate(top, angle)
    bmap_rot = rotate(bottom, angle)

    # Per-row maximum of each rotated map, smoothed along the rows.
    b_map_proj = gaussian_filter1d(np.max(bmap_rot, 1), 4)
    t_map_proj = gaussian_filter1d(np.max(tmap_rot, 1), 4)

    # Running max from the first row (bottom) and from the last row (top).
    b_prob = np.maximum.accumulate(b_map_proj)
    t_prob = np.maximum.accumulate(t_map_proj[::-1])[::-1]
    confidence = np.log((t_prob + 0.01) / (b_prob + 0.01))

    def from_cnn_to_image_space(x, shape):
        # Row index in the rotated map -> signed offset from the map centre,
        # rescaled from CNN-input pixels to original-image pixels.
        return (x - (shape[0] / 2)) / cnn_scale

    t = 1
    valid_interval = (confidence < t) & (confidence > -t)
    if not np.any(valid_interval):
        # No row inside the band: fall back to every row. Builtin `bool` is
        # used because the `np.bool` alias was removed in NumPy 1.24.
        valid_interval = np.ones_like(confidence, dtype=bool)
    valid_rows = np.nonzero(valid_interval)[0]
    rho_bottom = from_cnn_to_image_space(valid_rows.max(), tmap_rot.shape)
    rho_top = from_cnn_to_image_space(valid_rows.min(), tmap_rot.shape)
    rho_mean = from_cnn_to_image_space(np.argmin(np.abs(confidence)), tmap_rot.shape)

    endpts = [get_horizon_endpoints(zenith_direction, r, (h, w)) for r in [rho_mean, rho_top, rho_bottom]]

    return dict(
        z=zenith_direction,
        endpts=endpts,
        lines=lines,
        top=top,
        bottom=bottom,
    )



import kornia as K


def visual_horizon_location(tb_maps:Tensor, theta:Union[Tensor,float]) -> Tensor:
    """
    Rotate stacked semantic maps and take their maximum along the last axis
    (tensor counterpart of the projection done in get_horizon_line).

    tb_maps : Tensor
        (B,2,H,W) - presumably the top/bottom semantic maps (confirm).
    theta : Tensor or float
        Rotation angle forwarded to ``K.rotate`` (kornia's rotate expects
        degrees - confirm for the kornia version in use).

    NOTE(review): in the code visible here there is no return statement yet,
    despite the ``-> Tensor`` annotation.
    """
    tb_maps_rot = K.rotate(tb_maps, theta)   # (B,2,H,W)
    # NOTE(review): torch.max with dim= returns a (values, indices) named tuple;
    # the (B,2,H) shape below refers to its `.values`.
    tb_proj = torch.max(tb_maps_rot, dim=-1)  #(B,2,H)

    