import logging
import random
from collections import deque
from itertools import islice

import imgaug.augmentables as ia
import imgaug.augmenters as iaa
import numpy as np
import torch
from image_geometry.line_segments import LineSegments, lines_from_points
from librectify import detect_line_segments
from torch.tensor import Tensor
from torch.utils.data import IterableDataset


def lines_as_adjacency(lines:LineSegments):
    """
    Split line segments into an endpoint array and per-segment index pairs.

    Returns
    -------
    x : all A endpoints followed by all B endpoints, as returned by
        ``lines.endpoints()`` (presumably (2*n,2) -- confirm against LineSegments)
    a : (n,2) integer array; row i is (i, n+i), i.e. the rows of ``x`` that
        hold the two endpoints of segment i
    """
    count = len(lines)
    starts, ends = lines.endpoints()
    # Endpoint A of segment i lives in row i, endpoint B in row count+i
    first = np.arange(0, count)
    second = first + count
    points = np.concatenate((starts, ends), axis=0)
    pairs = np.concatenate((first[:, None], second[:, None]), axis=1)
    return points, pairs


def ST(X:Tensor) -> Tensor:
    """
    Calculate structure tensor of matrix rows
    (B,N,3) -> (B,N,9)

    For a row (a,b,c) the result is the flattened outer product
    (a*a, a*b, a*c, b*a, b*b, b*c, c*a, c*b, c*c).

    Raises ValueError when the last dimension is not of size 3.
    """
    if X.size(-1) != 3:
        raise ValueError(f"The size of the last dimension must equal to 3 (have {X.size(-1)})")
    # X[..., i:i+1] keeps the last axis so it broadcasts against the whole row
    blocks = [X[..., i:i+1] * X for i in range(3)]
    return torch.cat(blocks, dim=-1)


class ZSNet(torch.nn.Module):
    """
    Zenith scoring network

    Per-line features are max-pooled into one descriptor per batch item; each
    zenith candidate is then scored from its own features concatenated with
    that descriptor.
    """
    def __init__(self, fl=32, fz=32, f_hidden=128):
        """
        fl: size of the line feature vector
        fz: size of the zenith candidate feature vector
        f_hidden: width of the hidden layers in all three sub-networks
        """
        super().__init__()
        # Creation order (hl, hz, scoring) is kept so that parameter
        # initialization under a fixed seed and state-dict keys do not change
        self.hl = self._mlp(3, f_hidden, fl)  # Line feature network
        self.hz = self._mlp(3, f_hidden, fz)  # Zenith feature network
        self.scoring = self._mlp(fl+fz, f_hidden, 1)  # Scoring network

    @staticmethod
    def _mlp(f_in, f_hidden, f_out):
        """Three linear layers with LeakyReLU between them"""
        return torch.nn.Sequential(
            torch.nn.Linear(f_in, f_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(f_hidden, f_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(f_hidden, f_out),
        )

    def forward(self, L, Z):
        """
        L: (B, N, 3) lines
        Z: (B, M, 3) zenith candidates

        Returns (B, M, 1) scores; no activation is applied to them here.
        """
        num_candidates = Z.size(1)
        z_feat = self.hz(Z)  # (B,M,fz)
        l_feat = self.hl(L)  # (B,N,fl)
        # Pool over the line axis -> one descriptor per batch item
        pooled, _ = l_feat.max(dim=1, keepdim=True)  # (B,1,fl)
        pooled = pooled.repeat(1, num_candidates, 1)  # (B,M,fl)
        joint = torch.cat((z_feat, pooled), dim=-1)  # (B,M,fz+fl)
        return self.scoring(joint)


def get_normalized_zenith_vp(v, scale, shift):
    """
    Return a unit-length float32 copy of the homogeneous point ``v``.

    When |v[2]| > 1e-4 the point is first dehomogenized and its x,y are
    moved to normalized image coordinates, (xy - shift) / scale. Otherwise
    only the length normalization is applied. The input is not modified.
    """
    vp = np.array(v, dtype="f")  # always a copy
    w = vp[2]
    if np.abs(w) > 1e-4:
        vp = vp / w
        vp[:2] = (vp[:2] - shift) / scale
    length = np.linalg.norm(vp)
    return vp / length


def image_normalization_params(image):
    """
    Return (scale, pp) for an image: scale is the longer side in pixels,
    pp is the image center as (w/2, h/2).
    """
    height, width = image.shape[:2]
    center = (width/2, height/2)
    return max(height, width), center


def lines_for_image(image) -> LineSegments:
    """
    Detect line segments in ``image`` and wrap them in a LineSegments object.

    Detection runs with max_size=1000 and smooth=1; the extra properties
    returned by the detector are forwarded to LineSegments as keyword
    arguments. The number of detected segments is logged at INFO level.
    """
    lines, fields = detect_line_segments(image, max_size=1000, smooth=1)
    segments = LineSegments(lines, **fields)
    logging.info(f"Extracted {len(lines)} lines")
    return segments


def generate_labeled_intersections(L, z_true, num_intersections=1000):
    """
    Intersect randomly chosen line pairs and score them against ``z_true``.

    ``num_intersections`` index pairs are drawn with replacement; pairs that
    picked the same line twice are dropped, so fewer rows may be returned.

    Returns
    -------
    Z : unit-length cross products of the selected line pairs
    affinity : dot product of every row of Z with z_true, one column

    NOTE(review): the sign of a cross product depends on the pair order, so
    affinity can come out negative for a candidate that represents the same
    point as z_true -- confirm that callers expect signed affinity.
    """
    first, second = np.random.choice(L.shape[0], (2, num_intersections))
    distinct = first != second
    first, second = first[distinct], second[distinct]
    Z = np.cross(L[first,:], L[second,:])
    normalize(Z)
    # Label zenith candidates
    z_column = np.atleast_2d(z_true).T
    return Z, Z @ z_column

class LineSegmentDataset:
    """
    Add line segments to an existing image dataset

    Items of the wrapped dataset are dicts with an "image" entry; a "lines"
    entry is added to them. Detected lines are cached per index for the
    lifetime of this object.
    """
    def __init__(self, image_data):
        self.image_data = image_data  # Indexable image dataset
        self.lines = {}  # cache: index -> detected lines

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx) -> dict:
        """
        Return item ``idx`` of the wrapped dataset with "lines" added.

        The dict obtained from the wrapped dataset is updated in place and
        returned.
        """
        image_dict:dict = self.image_data[idx]
        if idx in self.lines:
            segments = self.lines[idx]
        else:
            logging.debug(f"{idx}: obtaining lines")
            segments = lines_for_image(image_dict["image"])
            self.lines[idx] = segments
        image_dict.update(lines=segments)
        return image_dict

def random_intersections(iterable):
    """
    Generate randomly rotated (lines, zenith) training pairs.

    ``iterable`` yields dicts with "image", "lines" and "zenith_vp" entries.
    Items are collected in a queue of at most 100 entries; for every new
    item one queued entry is picked at random, its line endpoints and zenith
    point are rotated together, and both are converted to normalized image
    coordinates.

    Yields
    ------
    (L, z) : homogeneous normalized lines of the augmented segments and the
        unit-length float32 homogeneous zenith point
    """
    #tform = iaa.Affine(scale=(0.5,2), rotate=(-30,30), shear=(-5,5))
    tform = iaa.Rotate()

    logging.debug("Prefilling dataset")
    data_queue = deque(islice(iterable, 3), maxlen=100)

    for image_dict in iterable:
        data_queue.append(image_dict)
        sample = random.choice(data_queue)

        image = sample["image"]
        shape = image.shape[:2]
        pts, adj = lines_as_adjacency(sample["lines"])

        # Float dtype is forced so the in-place division below also works
        # when the annotation is stored as integers
        z = np.array(sample["zenith_vp"], dtype=float)
        # NOTE(review): a negative last coordinate is clamped to 1e-8 rather
        # than divided out -- confirm annotations always have z[-1] > 0
        z[-1] = max(z[-1], 1e-8)
        z /= z[-1]

        # The zenith is appended to the line endpoints and augmented as a
        # single KeypointsOnImage. The augmenter draws its random parameters
        # per batch item, so passing the two as separate items would rotate
        # the lines and the zenith label by different angles.
        all_pts = np.concatenate([pts, np.atleast_2d(z[:2])], axis=0)
        kps = ia.KeypointsOnImage.from_xy_array(all_pts, shape=shape)
        aug = tform.augment_keypoints(kps).to_xy_array()
        pts_aug, (zx, zy) = aug[:-1], aug[-1]

        scale, shift = image_normalization_params(image)
        lines_aug = lines_from_points(pts_aug, adj).normalized(scale, shift)

        sx, sy = shift
        z = np.array([(zx-sx)/scale, (zy-sy)/scale, 1], "f")
        normalize(z)

        yield lines_aug.homogeneous(normalized=True), z


class ZenithLineDataset(IterableDataset):
    """
    Turn (lines, true zenith) pairs into fixed-size labeled training samples.

    For every source item, random line intersections are generated as zenith
    candidates and labeled by their affinity to the true zenith:
    1 when affinity > cos(true_aff_deg), 0 when affinity < cos(false_aff_deg).
    Candidates in between are discarded.
    """
    def __init__(self, iterable, true_aff_deg=2, false_aff_deg=5, max_samples_per_image=256):
        """
        iterable: source of (L, z_true) pairs
        true_aff_deg: angular threshold (degrees) for positive labels
        false_aff_deg: angular threshold (degrees) for negative labels
        max_samples_per_image: number of lines and of zenith candidates
            drawn (with replacement) for every yielded sample
        """
        self.src = iterable
        self.min_true_aff = np.cos(true_aff_deg/180 * np.pi)
        self.max_false_aff = np.cos(false_aff_deg/180 * np.pi)
        self.n_samples = max_samples_per_image

    def __iter__(self):
        """
        Yield (lines, zenith candidates, labels), each with ``n_samples`` rows.

        Source items for which no candidate received a label are skipped.
        """
        for L, z_true in self.src:
            z_samples, affinity = generate_labeled_intersections(L, z_true, 1000)
            # -1 marks candidates that are neither clearly true nor clearly false
            pz = np.full_like(affinity, -1, dtype=np.float32)
            pz[affinity > self.min_true_aff] = 1
            pz[affinity < self.max_false_aff] = 0
            valid = np.where(pz != -1)[0]
            if valid.size == 0:
                # np.random.choice raises on an empty population
                continue

            # Sampling with replacement gives every sample the same size,
            # which the default DataLoader collation relies on
            z_select = np.random.choice(valid, self.n_samples)
            l_select = np.random.choice(L.shape[0], self.n_samples)

            yield L[l_select], z_samples[z_select], pz[z_select]


def structure_tensor_loss(z_true, z_pred, p_z):
    """
    Norm over dims (0,1) of ST(z_true) minus the mean of the weighted ST(z_pred).

    NOTE(review): the mean is taken over all elements (a single scalar) and
    p_z is broadcast along the last axis of ST(z_pred) -- confirm this is the
    intended weighting before relying on this loss.
    """
    target = ST(z_true)
    weights = p_z.reshape(1,1,-1)
    weighted_mean = torch.mean(ST(z_pred) * weights)
    return torch.norm(target - weighted_mean, dim=(0,1))


def normalize(x, axis=-1):
    """
    Scale ``x`` in place to unit Euclidean length along ``axis``.

    Nothing is returned; the caller's array is modified.
    """
    lengths = np.linalg.norm(x, axis=axis, keepdims=True)
    x /= lengths




def get_zenith_line(image:np.ndarray, model:ZSNet):
    """
    Score random line intersections with ``model`` and average the accepted ones.

    Lines are detected in ``image``, normalized by the longer image side and
    the image center, and 5000 random pairs (minus pairs that picked the same
    line twice) are intersected to form candidates.

    Returns
    -------
    (z, pz) : average of the candidates whose sigmoid score exceeds 0.5,
        weighted by that score, and the scores of all candidates
    """
    height, width = image.shape[:2]
    scale = max(height, width)
    center = width/2, height/2
    # Get lines in the image (labeled with group but we do not use that information here)
    coords, fields = detect_line_segments(image)
    segments = LineSegments(coords, **fields).normalized(scale, center)
    L = segments.homogeneous(normalized=True)
    # Make pairs and get intersections
    first, second = np.random.choice(len(L), (2, 5000))
    distinct = first != second
    first, second = first[distinct], second[distinct]
    Z = np.cross(L[first,:], L[second,:])
    normalize(Z)

    # The model takes batched float32 tensors, hence the added leading axis
    lines_t = torch.from_numpy(L[None,...].astype("f"))
    candidates_t = torch.from_numpy(Z[None,...].astype("f"))
    scores = model(lines_t, candidates_t)
    pz = torch.sigmoid(scores.detach()).numpy().flatten()
    accepted = np.where(pz > 0.5)[0]

    zenith = np.average(Z[accepted], weights=pz[accepted], axis=0)
    return zenith, pz


if __name__ == "__main__":
    # Training script: fits ZSNet on the York Urban dataset and saves the
    # whole model to "zsnet.pt". Requires a CUDA device.
    import datasets
    import numpy as np
    import torch.optim as optim
    from more_itertools import interleave
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.INFO)

    # Configure data pipeline
    yud_images = datasets.YorkUrbanDataset("/mnt/matylda1/jurankovam/datasets/YorkUrbanDB", "train")
    # TODO - any number of datasets with horizons in them

    data = LineSegmentDataset(yud_images)
    src = ZenithLineDataset(random_intersections(datasets.random_iterator(data)))
    loader = DataLoader(src, batch_size=16)

    # Initialize network
    zsnet = ZSNet(fl=32, fz=32, f_hidden=128).cuda()

    # The network outputs raw scores, so the sigmoid is part of the loss
    criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
    optimizer = optim.Adam(zsnet.parameters(), lr=1e-3)

    for epoch in range(1000):
        epoch_losses = []
        # The source is an endless stream; an "epoch" is 100 batches
        for L, Z, pz in islice(loader, 100):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = zsnet(L.float().cuda(), Z.float().cuda())
            loss = criterion(outputs, pz.float().cuda())
            loss.backward()
            optimizer.step()
            # Keep a plain float so the loss tensors are not held for the whole epoch
            epoch_losses.append(loss.item())

        print(torch.tensor(epoch_losses).mean())

    torch.save(zsnet, "zsnet.pt")

    # Kept inside the guard so importing this module prints nothing
    print('Finished Training')
