import random
from collections import deque
from queue import deque

import imgaug
import numpy as np
from imgaug.augmentables import Keypoint, KeypointsOnImage
from imgaug.augmenters import *
from numpy.core.fromnumeric import size
from skimage.transform import rescale

from .utils import array_info


def center_crop(output_shape=(224,224)):
    """Build a resize-then-centre-crop augmenter with no random parameters.

    The image's shorter side is resized to ``max(h, w)`` (aspect ratio kept,
    area interpolation), so both sides end up at least as large as the
    requested crop; an ``h`` x ``w`` window is then cut from the centre.

    Parameters
    ----------
    output_shape : (int, int)
        Target ``(height, width)`` of the crop.

    Returns
    -------
    imgaug.augmenters.Sequential
        Resize followed by CropToFixedSize, applied in that order.
    """
    h, w = output_shape
    return Sequential([
        # "area" is imgaug's string alias for cv2.INTER_AREA; using it avoids
        # relying on `cv2` leaking into this module through a star import.
        Resize({"shorter-side": max(h, w), "longer-side": "keep-aspect-ratio"}, interpolation="area"),
        CropToFixedSize(width=w, height=h, position="center"),
    ], random_order=False)


def random_augment():
    """Build the random augmentation pipeline used for training.

    The stages run in this fixed order:

    1. ``Fliplr(0.5)``: horizontal flip with probability 0.5.
    2. ``Rotate((-45, 45), mode="ALL", fit_output=False)``: rotation by an
       angle in [-45, 45] degrees, keeping the canvas size.
    3. A photometric sub-pipeline whose members run in random order:
       per-channel ``GammaContrast`` with gamma in [0.5, 2], and ``Grayscale``
       applied with probability 0.2.

    Returns
    -------
    imgaug.augmenters.Sequential
    """
    flip = Fliplr(0.5)
    rotate = Rotate((-45, 45), mode="ALL", fit_output=False)
    photometric = Sequential(
        [
            GammaContrast((0.5, 2), per_channel=True),
            Sometimes(0.2, Grayscale()),
        ],
        random_order=True,
    )
    return Sequential([flip, rotate, photometric], random_order=False)


def random_crop(output_shape=(224,224), image_dropout=True):
    """Build the training pipeline: random augmentation, then a fixed-size crop.

    Parameters
    ----------
    output_shape : (int, int)
        ``(height, width)`` of the crop. The crop location is sampled using
        imgaug's ``"normal"`` position mode.
    image_dropout : bool
        If true, a final step is appended. With probability 0.1 it applies
        ``Cutout`` with 1-2 iterations, size in (0.1, 0.2) and a per-channel
        gaussian fill.

    Returns
    -------
    imgaug.augmenters.Sequential
        The stages, applied in order.
    """
    out_h, out_w = output_shape
    steps = [
        random_augment(),
        CropToFixedSize(width=out_w, height=out_h, position="normal"),
    ]
    if not image_dropout:
        return Sequential(steps, random_order=False)

    cutout = Cutout(size=(0.1, 0.2), nb_iterations=(1, 2), fill_mode="gaussian", fill_per_channel=True)
    steps.append(Sometimes(0.1, cutout))
    return Sequential(steps, random_order=False)


def line_parameters(kps, center=(0,0)):
    """Return the homogeneous line through the two keypoints of ``kps``.

    Both points are shifted by ``center`` and lifted to ``(x, y, 1)``. Their
    cross product is the (un-normalised) float64 vector ``(a, b, c)``, which
    satisfies ``a*x + b*y + c = 0`` for each shifted point.

    Parameters
    ----------
    kps : object
        Its ``.keypoints`` must unpack into exactly two objects with ``.x``
        and ``.y`` attributes.
    center : (float, float)
        Origin ``(u, v)`` subtracted from each point before lifting.
    """
    first, second = kps.keypoints
    cx, cy = center
    p = np.array([first.x - cx, first.y - cy, 1], np.float64)
    q = np.array([second.x - cx, second.y - cy, 1], np.float64)
    return np.cross(p, q)


def dict_to_sample(image_dict, transform=None):
    """Turn an annotated image into a training sample, optionally augmented.

    Parameters
    ----------
    image_dict : dict
        Must contain ``"image"`` (array whose first two axes are height and
        width) and ``"A"``, ``"B"``: the two ``(x, y)`` points defining the
        line.
    transform : imgaug augmenter, optional
        Applied jointly to the image and the two keypoints.

    Returns
    -------
    dict
        Contains the following keys:

        - ``image``: the (augmented) image.
        - ``z``: unit normal ``(a, b)`` of the line, with a canonical sign.
        - ``theta``: ``arctan2(a, b)``, in ``(-pi/2, pi/2]``.
        - ``rho``: ``-c`` of the normalised line ``a*x + b*y + c = 0``, i.e.
          the line satisfies ``a*x + b*y = rho`` with coordinates measured
          from the centre of the output image.
        - ``A``, ``B``: the two (augmented) keypoints.
        - ``kps``: the KeypointsOnImage holding both.

    Raises
    ------
    ValueError
        If A and B coincide (after augmentation), so no line is defined.
    """
    image = image_dict["image"]
    (x1, y1), (x2, y2) = image_dict["A"], image_dict["B"]
    kps = KeypointsOnImage([Keypoint(x=x1, y=y1), Keypoint(x=x2, y=y2)], shape=image.shape[:2])

    if transform is not None:
        image, kps = transform.augment(image=image, keypoints=kps)

    # Express the line relative to the centre of the (possibly cropped) output.
    height, width = image.shape[:2]
    line = line_parameters(kps, center=(width / 2, height / 2))
    norm = np.linalg.norm(line[:2])
    if norm == 0:
        # Identical points give a zero normal; dividing would yield NaN targets.
        raise ValueError("points A and B coincide; the line is undefined")
    line /= norm

    # (a, b, c) and -(a, b, c) describe the same line. Force b > 0, or a > 0
    # when b == 0 (vertical line), so every line has a single encoding that
    # does not depend on the order of A and B.
    if line[1] < 0 or (line[1] == 0 and line[0] < 0):
        line *= -1
    theta = np.arctan2(line[0], line[1])
    rho = -line[2]

    return dict(image=image, z=np.array((line[0], line[1])), theta=theta, rho=rho, A=kps[0], B=kps[1], kps=kps)


def prescale_image(new_image, size_range):
    """Rescale an annotated image so its shorter side gets a random length.

    Parameters
    ----------
    new_image : dict
        Must contain ``"image"`` (an H x W or H x W x C array) and ``"A"``,
        ``"B"``: ``(x, y)`` points, as arrays or any array-like.
    size_range : (float, float)
        Bounds for the new shorter-side length, drawn with
        ``np.random.uniform``.

    Returns
    -------
    dict
        ``image`` is the rescaled image cast to ``uint8``. ``A`` and ``B`` are
        the points multiplied by the same scale factor, as numpy arrays.
    """
    min_size, max_size = size_range
    img = new_image["image"]
    A = np.asarray(new_image["A"])
    B = np.asarray(new_image["B"])
    shorter_side = min(img.shape[:2])
    scale = np.random.uniform(min_size, max_size) / shorter_side
    # Scale the two spatial axes and leave any trailing (channel) axis at
    # factor 1. This replaces `multichannel=True`, which treated the column
    # axis of a 2-D grayscale image as channels (so only the rows were scaled)
    # and is rejected by recent scikit-image versions. For H x W x C input the
    # output is unchanged.
    factors = (scale, scale) + (1,) * (img.ndim - 2)
    img = rescale(img, factors, anti_aliasing=True, preserve_range=True).astype("u1")
    return dict(image=img, A=A * scale, B=B * scale)


def sample_generator(sequence, transform=None, window=1, samples_per_window=1, size_range=(256,300)):
    """
    Generate training samples from a stream of annotated images.

    Each image read from ``sequence`` is prescaled with ``prescale_image`` and
    pushed into a rolling buffer that keeps the most recent ``window`` images.
    After every push, ``samples_per_window`` samples are yielded. Each one
    comes from a buffer entry picked uniformly at random and passed through
    ``dict_to_sample``, so several samples can be drawn per image read.

    Parameters
    ----------
    sequence : iterable of dict
        Annotated images in the format expected by ``prescale_image``.
    transform : imgaug augmenter, optional
        Forwarded to ``dict_to_sample``.
    window : int or None
        Buffer capacity; ``None`` means unbounded.
    samples_per_window : int
        Number of samples yielded after each new image.
    size_range : (float, float)
        Forwarded to ``prescale_image``.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1 (raised when iteration starts).
    """
    if window is not None and window < 1:
        raise ValueError(f"window must be >= 1 (or None for unbounded), got {window!r}")
    img_queue = deque(maxlen=window)
    for img_dict in sequence:
        img_queue.append(prescale_image(img_dict, size_range=size_range))
        # Draw samples from the buffered images without reading a new one.
        for _ in range(samples_per_window):
            chosen = img_queue[np.random.randint(0, len(img_queue))]
            yield dict_to_sample(chosen, transform=transform)
