import numpy as np
import matplotlib.pyplot as plt


from image_geometry.utils import plot_line_segments


def plot_result(image_dict, out_dict, ax=None):
    """Show an image and overlay the line(s) stored in ``out_dict["endpts"]``.

    Parameters
    ----------
    image_dict : dict
        Must contain ``"image"``; the image is averaged over its last axis
        before display (presumably the colour channels -- confirm with the
        caller). If the key ``"A"`` is present, ``"A"`` and ``"B"`` are read
        as two ``(x, y)`` points and joined by a dashed, semi-transparent
        green line.
    out_dict : dict
        ``out_dict["endpts"]`` is a sequence of ``((x1, y1), (x2, y2))``
        pairs. The first pair is drawn in yellow on a black outline, every
        following pair in green.
    ax : optional
        Axes-like object to draw on; ``plt.gca()`` is used when omitted.
    """
    axes = plt.gca() if ax is None else ax

    picture = image_dict["image"]
    width = picture.shape[1]
    axes.imshow(picture.mean(-1), cmap="gray", interpolation="bilinear")
    # Drawing of out_dict["lines"] via plot_line_segments is currently disabled.

    if "A" in image_dict:
        (ref_x1, ref_y1), (ref_x2, ref_y2) = image_dict["A"], image_dict["B"]
        axes.plot(
            (ref_x1, ref_x2), (ref_y1, ref_y2),
            "--", lw=4, color=(0,1,0), alpha=0.5,
        )

    # First result: a thick black stroke underneath a thinner yellow one.
    (px, py), (qx, qy) = out_dict["endpts"][0]
    xs, ys = (px, qx), (py, qy)
    axes.plot(xs, ys, "k-", lw=5, ms=10)
    axes.plot(xs, ys, "-", lw=3, color=(1,1,0))

    # Remaining results: plain green strokes.
    for (px, py), (qx, qy) in out_dict["endpts"][1:]:
        axes.plot((px, qx), (py, qy), "-", lw=2, color=(0,1,0))

    axes.set(xticks=[], yticks=[], xlim=[0, width])

def figsize_for_image(image_or_shape, max_fig_size=15):
    """Return a ``(width, height)`` figure size with the image's aspect ratio.

    The longer image side is mapped to ``max_fig_size`` and the shorter side
    is scaled by the same factor.

    Parameters
    ----------
    image_or_shape : tuple, list or numpy.ndarray
        Either a shape whose first two entries are ``(height, width)`` or an
        array whose ``.shape`` starts with ``(height, width)``.
    max_fig_size : float, optional
        Size assigned to the longer side. Earlier versions ignored this
        argument and always used 15; the default is therefore 15 so that
        calls which omit it keep returning the same size.

    Returns
    -------
    tuple
        ``(width * scale, height * scale)``.

    Raises
    ------
    TypeError
        If ``image_or_shape`` is neither a tuple/list nor a numpy array.
    """
    if isinstance(image_or_shape, (tuple, list)):
        h, w = image_or_shape[:2]
    elif isinstance(image_or_shape, np.ndarray):
        h, w = image_or_shape.shape[:2]
    else:
        raise TypeError(
            "image_or_shape must be a shape tuple/list or a numpy array, "
            f"got {type(image_or_shape).__name__}"
        )
    scale = max_fig_size / max(h, w)
    return (w*scale, h*scale)


# NOTE: The string literal below holds a disabled plotting snippet. As a bare
# expression statement it is never executed; it refers to names (`image_dict`,
# `lines`) that are not defined at module level in the visible part of this file.
"""
plt.figure(figsize=(15,10))
plt.imshow(image_dict["image"].mean(-1), cmap="gray")
w = lines.get_field("weight") > 0.0
ls = lines[w]
colors = ["r","g","b","c","m","y","k"]
for (x1,y1,x2,y2),c in zip(ls.coordinates(), ls.get_field("group").astype("i")):
    plt.plot([x1,x2],[y1,y2],"-", color=colors[c], lw=2)
plt.axis("off")
plt.tight_layout()
"""

def zenith_intersections(img_size, A, B, pp = None):
    """Intersect the line through ``pp`` perpendicular to ``A - B`` with the
    lines ``y = 0`` and ``y = img_h``.

    Parameters
    ----------
    img_size : sequence
        ``img_size[0]`` is the image height, ``img_size[1]`` the width.
    A, B : numpy.ndarray
        Two points; only their first two components are used. Their
        difference gives the direction that serves as the normal of the
        returned line.
    pp : sequence, optional
        Point the line passes through. Defaults to ``[img_w/2, img_h/2]``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(C, D)``: homogeneous points, divided by their third component,
        where the line meets ``y = 0`` and ``y = img_h`` respectively.
    """
    img_h, img_w = img_size[0], img_size[1]
    anchor = [img_w/2, img_h/2] if pp is None else pp

    # Unit vector along A - B; it is the normal of the line we construct.
    normal = A[0:2] - B[0:2]
    normal = normal/np.linalg.norm(normal)

    # Homogeneous line n . (p - anchor) = 0.
    zenith_line = [normal[0], normal[1], -anchor[0]*normal[0] - anchor[1]*normal[1]]

    top_row = np.array([0, 1, 0])           # the line y = 0
    bottom_row = np.array([0, 1, -img_h])   # the line y = img_h

    C = np.cross(zenith_line, top_row)
    D = np.cross(zenith_line, bottom_row)
    return C / C[2], D / D[2]
##
def draw_line_segments_array(coords, color='red'):
    """Draw every row of ``coords`` as one line segment with ``plt.plot``.

    Parameters
    ----------
    coords : numpy.ndarray
        2-D array split into four equal column blocks that are read, in
        order, as ``x0, y0, x1, y1`` (one segment per row).
    color : optional
        Colour passed to ``plt.plot`` for all segments.
    """
    columns = np.split(coords, 4, axis=1)
    start_x, start_y, end_x, end_y = columns

    # Pair start and end coordinates row by row.
    xs = np.concatenate((start_x, end_x), axis=1)
    ys = np.concatenate((start_y, end_y), axis=1)

    for seg_x, seg_y in zip(xs, ys):
        plt.plot(seg_x, seg_y, "-", c=color, zorder=1, lw=2)
##
def draw_colored_line_segments_array(coords, group):
    """Draw line segments with ``plt.plot``, coloured by their group label.

    Parameters
    ----------
    coords : numpy.ndarray
        2-D array split into four equal column blocks that are read, in
        order, as ``x0, y0, x1, y1`` (one segment per row).
    group : array-like
        One label per row of ``coords``. Labels ``0 .. max(group)`` are
        drawn; rows whose label is outside that range (e.g. negative
        labels) match no group and are not drawn.

    Notes
    -----
    The palette has seven colours and is cycled, so label ``7`` reuses the
    colour of label ``0``. Indexing the palette directly raised
    ``IndexError`` for labels of 7 and above. An empty ``group`` draws
    nothing.
    """
    colors = ['red','green','blue','yellow','orange','pink','violet']
    x0, y0, x1, y1 = np.split(coords, 4, axis=1)

    group = np.asarray(group)
    if group.size == 0:
        # Nothing to draw; max() of an empty array would raise.
        return

    # int() lets float-typed label arrays drive range().
    for i in range(int(group.max()) + 1):
        members = group == i
        X = np.concatenate([x0[members], x1[members]], axis=1)
        Y = np.concatenate([y0[members], y1[members]], axis=1)
        color = colors[i % len(colors)]
        for x, y in zip(X, Y):
            plt.plot(x, y, "-", c=color, zorder=1, lw=1)

##
def draw_line_to_zenith(pp, zenith_point, img_height):
    """Plot the line through ``pp`` and ``zenith_point`` as a dotted blue line.

    The line is clipped to the span between ``y = 0`` and ``y = img_height``.

    Parameters
    ----------
    pp : sequence
        2-D point; a homogeneous coordinate of 1 is appended to it.
    zenith_point : sequence
        Homogeneous 3-vector of the second point on the line.
    img_height : number
        ``y`` value of the lower clipping line.
    """
    line = np.cross(np.hstack([pp, 1]), zenith_point)

    # Intersections with the lines y = 0 and y = img_height.
    upper = np.cross(line, np.array([0, 1, 0]))
    lower = np.cross(line, np.array([0, 1, -img_height]))
    upper, lower = upper / upper[2], lower / lower[2]

    xs = np.array([upper[0], lower[0]])
    ys = np.array([upper[1], lower[1]])
    plt.plot(xs, ys, ':', color='blue')
##
def draw_line_as_horizon(horizon_line, img_width):
    """Plot a homogeneous line as a dotted blue line between ``x = 0`` and
    ``x = img_width``.

    Parameters
    ----------
    horizon_line : sequence
        Homogeneous 3-vector of the line to draw.
    img_width : number
        ``x`` value of the right clipping line.
    """
    # Intersections with the lines x = 0 and x = img_width.
    left = np.cross(horizon_line, np.array([1, 0, 0]))
    right = np.cross(horizon_line, np.array([1, 0, -img_width]))
    left, right = left / left[2], right / right[2]

    xs = np.array([left[0], right[0]])
    ys = np.array([left[1], right[1]])
    plt.plot(xs, ys, ':', color='blue')
##
def draw_cross(img_data, A = None, B = None, horizon_line = None, zenith_point = None, lc = 'red'):
    """Plot two dashed lines: the segment ``A``-``B`` and a second line through
    the point ``img_data['pp']``.

    Parameters
    ----------
    img_data : dict
        ``img_data['shape']`` unpacks to ``(img_h, img_w)`` and
        ``img_data['pp']`` is a 2-D point.
    A, B : numpy.ndarray, optional
        End points of the first line. Overwritten when ``horizon_line`` is
        given.
    horizon_line : sequence, optional
        Homogeneous line; when given, ``A`` and ``B`` become its
        intersections with ``x = 0`` and ``x = img_w``.
    zenith_point : sequence, optional
        Homogeneous point. When given, the second line joins ``pp`` to it;
        otherwise the second line passes through ``pp`` perpendicular to
        ``A - B``.
    lc : optional
        Colour used for both lines.

    Notes
    -----
    The second line is clipped to the span between ``y = 0`` and
    ``y = img_h``. ``A`` and ``B`` must be available (directly or via
    ``horizon_line``), since they are always plotted.
    """
    img_h, img_w = img_data['shape']
    pp = img_data['pp']

    if horizon_line is not None:
        # Replace A and B by the horizon's intersections with x = 0 / x = img_w.
        A = np.cross(horizon_line, np.array([1, 0, 0]))
        B = np.cross(horizon_line, np.array([1, 0, -img_w]))
        A, B = A / A[2], B / B[2]

    if zenith_point is not None:
        zenith_line = np.cross(np.hstack([pp, 1]), zenith_point)
    else:
        # Line through pp whose normal is the unit direction of A - B.
        normal = A[0:2] - B[0:2]
        normal = normal / np.linalg.norm(normal)
        zenith_line = [normal[0], normal[1], -pp[0] * normal[0] - pp[1] * normal[1]]

    # Intersections of the second line with y = 0 and y = img_h.
    C = np.cross(zenith_line, np.array([0, 1, 0]))
    D = np.cross(zenith_line, np.array([0, 1, -img_h]))
    C, D = C / C[2], D / D[2]

    for start, end in ((A, B), (C, D)):
        plt.plot(np.array([start[0], end[0]]), np.array([start[1], end[1]]), '--', color=lc)



def line_segments_from_homogeneous(lines, bbox):
    """Clip homogeneous 2-D lines to an axis-aligned bounding box.

    Parameters
    ----------
    lines : array-like
        ``(N, 3)`` array of homogeneous line coefficients, or a single
        3-vector (treated as one line).
    bbox : sequence
        ``(x, y, w, h)``: corner ``(x, y)`` plus width and height.

    Returns
    -------
    tuple of list
        ``(X, Y)`` with one entry per line. For a line that meets the box
        boundary in at least two candidate points the entries are
        ``(x0, x1)`` and ``(y0, y1)``; otherwise both entries are ``None``.

    Notes
    -----
    A line through a box corner intersects two edges at the same point, so
    it produces three or four in-bounds candidates. Requiring exactly two
    candidates discarded such lines (including the box diagonals); the two
    candidates farthest apart are used instead. With exactly two candidates
    the result and its ordering are the same as before.
    """
    x, y, w, h = bbox
    lines = np.atleast_2d(np.asarray(lines, dtype=float))
    if lines.size == 0:
        return [], []

    # Box corners as homogeneous points, in boundary order.
    corners = [
        np.array(c)
        for c in ((x, y, 1), (x+w, y, 1), (x+w, y+h, 1), (x, y+h, 1))
    ]
    # The cross product of two homogeneous points is the line joining them.
    edges = [np.cross(a, b) for a, b in zip(corners, corners[1:] + corners[:1])]

    # The cross product of two homogeneous lines is their intersection.
    # A line parallel to an edge gives a zero last component; the resulting
    # inf/nan coordinates fail the bounds check below, so the warnings that
    # the division would emit are suppressed.
    candidates = []
    with np.errstate(divide="ignore", invalid="ignore"):
        for e in edges:
            p = np.cross(lines, e)
            candidates.append(p[:, :2] / p[:, -1].reshape(-1, 1))

    X = []
    Y = []
    for per_line in zip(*candidates):
        inside = [
            (u, v) for (u, v) in per_line
            if (x <= u <= x+w) and (y <= v <= y+h)
        ]
        if len(inside) >= 2:
            (x0, y0), (x1, y1) = _farthest_pair(inside)
            X.append((x0, x1))
            Y.append((y0, y1))
        else:
            X.append(None)
            Y.append(None)

    return X, Y


def _farthest_pair(points):
    """Return the two points of ``points`` farthest apart (first such pair wins)."""
    best = (points[0], points[1])
    best_d2 = -1.0
    for i, p in enumerate(points):
        for q in points[i + 1:]:
            d2 = (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2
            if d2 > best_d2:
                best, best_d2 = (p, q), d2
    return best