import sys

from torch import functional

from horizon import sord

sys.path.append("/mnt/data/semantic-segmentation-pytorch")


import logging
from itertools import islice

import numpy as np
import torch
import torch.nn.functional as F
from more_itertools import interleave
from rich.logging import RichHandler
from rich.progress import track
from skimage.util import img_as_float32
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, IterableDataset
from torchvision.transforms import Compose, Normalize, ToTensor

import datasets
from horizon import models
from horizon.data import center_crop, random_crop, sample_generator
from horizon.sord import soft_label_theta

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


# This class just converts generic representation to pytorch tensors
class HorizonSamples(IterableDataset):
    """Iterable dataset that converts horizon sample dicts into tensor dicts.

    Every item of ``sample_sequence`` is read through the keys ``"image"``,
    ``"theta"`` and ``"rho"``. The yielded dict carries the same three keys,
    each mapped to a ``torch`` tensor.
    """

    def __init__(self, sample_sequence):
        self.seq = sample_sequence
        to_tensor = ToTensor()
        # Normalization constants required by the segmentation module.
        normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = Compose([to_tensor, normalize])

    @staticmethod
    def _as_float_tensor(value):
        """Wrap a scalar or array as a float32 tensor with at least one dimension."""
        return torch.tensor(np.atleast_1d(value), dtype=torch.float32)

    def __iter__(self):
        for record in self.seq:
            pixels = img_as_float32(record["image"])
            # Disabled extras kept for reference:
            # soft_theta_tensor = torch.tensor(soft_label_theta(t, n_bins=128, K=4)[0], dtype=torch.float32)
            # z_tensor = torch.tensor(image_dict["z"], dtype=torch.float32)
            yield {
                "image": self.transform(pixels),
                "theta": self._as_float_tensor(record["theta"]),  # value of theta
                "rho": self._as_float_tensor(record["rho"]),  # value of rho
            }


if __name__ == "__main__":
    # Send log records through rich's handler; only the bare message is formatted.
    logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[RichHandler()])

    # Model under training. The alternative estimators are kept for reference.
    # H = models.ThetaSORDEstimator(theta_bins=64).cuda()
    # H = models.ZenithVectorEstimator().cuda()
    H = models.ThetaRegressionEstimator().cuda()

    # Disabled: resume from a checkpoint. This needs H, so it belongs after
    # the model has been built.
    # in_ckpt = "horizon_module_e0009.pt"
    # start_epoch = 10
    # # in_ckpt = None
    # # start_epoch = 1
    # try:
    #     if in_ckpt is not None:
    #         logging.info(f"Loading weights from {in_ckpt}")
    #         H.load_state_dict(torch.load(in_ckpt))
    # except FileNotFoundError:
    #     logging.warning(f"Cannot load {in_ckpt} starting from scratch")
    #     start_epoch = 1
    #     in_ckpt = None

    # Total number of scalar entries over all parameter tensors of the model.
    num_params = 0
    for param in H.parameters():
        num_params += param.numel()
    logging.info(f"Built model with {num_params} parameters")

    def get_mean_losses(losses):
        """Return the mean of ``losses`` taken along the first axis.

        ``losses`` is converted to a NumPy array first, so a flat list of
        per-step values yields a scalar, while a list of equal-length rows
        yields one mean per column.
        """
        return np.asarray(losses).mean(axis=0)

    # Import the dataset path names used below explicitly instead of with a
    # wildcard, so it is visible where each name comes from.
    from datasets.paths import flickr_path, gp3k_path, gsw_path, hlw_path, yud_path

    # Image sources
    logging.info("Initializing datasets")
    # NOTE(review): yud_images is built and counted here but is not part of
    # the interleaved training stream below - confirm this is intended.
    yud_images = datasets.YorkUrbanDataset(yud_path, "train")
    logging.info(f"YUD: {len(yud_images)} images")
    gp3k_images = datasets.GeoPose3KDataset(gp3k_path)
    logging.info(f"GP3K: {len(gp3k_images)} images")
    hlw_images = datasets.HorizonLinesInTheWildDataset(hlw_path, "train")
    logging.info(f"HLW: {len(hlw_images)} images")
    flickr_images = datasets.FlickrDataset(flickr_path)
    logging.info(f"FLICKR: {len(flickr_images)} images")
    gsw_images = datasets.GoogleDataset(gsw_path)
    logging.info(f"GSW: {len(gsw_images)} images")

    hlw_val_images = datasets.HorizonLinesInTheWildDataset(hlw_path, "val")
    logging.info(f"HLW (val): {len(hlw_val_images)} images")

    # This combines multiple sources into one sequence of randomly chosen
    # items in round-robin fashion.
    training_images = interleave(
        datasets.random_iterator(flickr_images, maxlen=None),
        datasets.random_iterator(hlw_images, maxlen=None),
        datasets.random_iterator(gp3k_images, maxlen=None),
        datasets.random_iterator(gsw_images, maxlen=None),
    )

    # This transforms the images from the source (actually dicts with
    # annotations) into training samples - an image plus arrays with the
    # theta/rho representation. sample_generator is used for speed reasons:
    # it remembers a few past images, which makes sampling much faster.
    S = 48  # side length of the square crops, also used for validation crops

    training_samples = sample_generator(
        training_images,
        transform=random_crop((S, S)),
        samples_per_window=64,
        size_range=(S * 1.3, S * 2),
        window=100,
    )

    # Draw and discard the first n_prefill samples while showing progress.
    # Presumably this warms up the generator's image window before training
    # starts - TODO confirm against sample_generator.
    n_prefill = 200
    for _ in track(islice(training_samples, n_prefill), total=n_prefill, description="Filling training samples"):
        pass

    # This provides batched tensors for the training
    loader = DataLoader(
        HorizonSamples(training_samples),
        batch_size=32,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=10,
    )

    logging.info("Loading validation data")
    val_samples = sample_generator(
        hlw_val_images,
        transform=center_crop((S, S)),
        size_range=(S, S),
        samples_per_window=1,
        window=1,
    )
    # A single worker is used for validation: HorizonSamples does not shard
    # its sequence per worker, and with an IterableDataset every DataLoader
    # worker iterates its own replica of the dataset. With several workers
    # the cached validation set would therefore repeat samples.
    # NOTE(review): assumes val_samples yields the validation images in a
    # fixed order - confirm against sample_generator.
    val_loader = DataLoader(
        HorizonSamples(val_samples),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        prefetch_factor=10,
    )
    # One constant drives both the number of cached batches and the progress
    # bar total, so the bar matches what is actually loaded.
    n_val = 200
    val_data = list(track(islice(val_loader, n_val), total=n_val, description="Loading validation samples"))

    logging.info("Starting training loop")

    optimizer = Adam(H.parameters(), lr=1e-3, weight_decay=1e-4)
    # The learning rate is multiplied by gamma on every scheduler.step() call.
    scheduler = ExponentialLR(optimizer, gamma=0.95)

    start_epoch = 1
    epochs = 100
    steps_per_epoch = 2000

    def vector_similarity(input, target):
        """Return the mean of ``1 - |cosine similarity|`` between the inputs.

        The value is 0 when every pair of vectors is parallel or
        anti-parallel and grows towards 1 as the pairs become orthogonal.
        """
        cosine = F.cosine_similarity(input, target)
        return (1 - cosine.abs()).mean()

    def vector_perpendiculrity(input, target):
        """Return the mean ``|cosine similarity|`` between the inputs.

        The value is 0 when every pair of vectors is orthogonal and 1 when
        every pair is parallel or anti-parallel.
        """
        cosine = F.cosine_similarity(input, target)
        return cosine.abs().mean()

    # Main loop: for every epoch run steps_per_epoch optimization steps,
    # evaluate on the cached validation batches, periodically save a
    # checkpoint, and decay the learning rate. Ctrl-C stops training cleanly.
    try:
        for epoch in range(start_epoch, epochs+1):
            logging.info(f"Training epoch {epoch}/{epochs} for {steps_per_epoch} steps")

            # Validation below switches the model to eval mode, so training
            # mode has to be restored at the start of every epoch.
            H.train()
            losses = []
            training_batches = enumerate(islice(loader, steps_per_epoch))
            for k, sample_dict in track(training_batches, total=steps_per_epoch, description=f"Epoch {epoch}/{epochs}"):
                # Forward pass
                image = sample_dict["image"].cuda()
                # Kept on the CPU because the disabled soft-label variant
                # below calls t_true.numpy().
                t_true = sample_dict["theta"]

                # regression
                pred_dict = H(image)
                #print(t_true.shape, pred_dict["theta"].shape)
                loss = F.mse_loss(pred_dict["theta"], t_true.cuda())
                losses.append(loss.item())

                # Soft classification
                # pred_dict = H(image)
                # t_pred = pred_dict["theta"]
                # t_soft_true = torch.tensor(sord.soft_label_theta(t_true.numpy(), n_bins=64, K=3))
                # loss = F.kl_div(t_pred, t_soft_true.cuda(), reduction="batchmean")
                # losses.append(loss.item())

                # Zenith vector prediction
                # z_pred = H(image)
                # z_true = torch.cat([torch.cos(t_true), torch.sin(t_true)], dim=1)
                # print(t_true.shape, z_true.shape, z_pred.shape)
                # loss = F.cosine_similarity(z_pred, z_true).mean()

                loss.backward()
                #clip_grad_norm_(H.parameters(), 10)
                optimizer.step()
                # Clear gradients so they do not accumulate into the next step.
                optimizer.zero_grad()

                # Report the running mean of this epoch's losses every 200
                # steps and once more on the last step.
                if k % 200 == 199 or k == (steps_per_epoch-1):
                    ml = get_mean_losses(losses)
                    logging.info(f"Step {k}: mean: {ml}")

            # Evaluate in inference mode: eval() makes mode-dependent layers
            # (e.g. dropout or batch normalization, if H contains any) behave
            # deterministically and stops them from updating their statistics
            # on validation data.
            H.eval()
            with torch.no_grad():
                val_losses = []
                for sample_dict in track(val_data, description="Running validation"):
                    image = sample_dict["image"].cuda()
                    t_true = sample_dict["theta"]

                    # pred_dict = H(image)
                    # t_pred = pred_dict["theta"]
                    # t_soft_true = torch.tensor(sord.soft_label_theta(t_true.numpy(), n_bins=64, K=3))
                    # loss = F.kl_div(t_pred, t_soft_true.cuda(), reduction="batchmean")

                    # z_pred = H(image)
                    # z_true = torch.cat([torch.cos(t_true), torch.sin(t_true)], dim=1)
                    # loss = F.cosine_similarity(z_pred, z_true.cuda()).mean()

                    pred_dict = H(image)
                    # NOTE(review): the extra leading axis is added only here,
                    # not in training; presumably the prediction for a batch
                    # of one has one dimension fewer than t_true - confirm.
                    loss = F.mse_loss(pred_dict["theta"][None], t_true.cuda())

                    val_losses.append(loss.item())
                logging.info(f"Validation loss: {get_mean_losses(val_losses)}")

            # Checkpoint on epochs 4, 9, 14, ... and always after the final epoch.
            if epoch % 5 == 4 or epoch == epochs:
                output_file = f"theta_module_e{epoch:04d}.pt"
                logging.info(f"Saving model to {output_file}")
                torch.save(H.state_dict(), output_file)

            scheduler.step()
            logging.info("Updating learning rate")

    except KeyboardInterrupt:
        logging.warning("User interrupt")

    logging.info("Done")
