"""
Dataset of photos from Flickr with auto-annotated horizons
"""

import json
from pathlib import Path

import numpy as np
from skimage.io import imread


class FlickrDataset:
    """Flickr photos with auto-annotated horizon lines.

    Annotations are read from ``flickr_data.json`` inside *data_path*. Each
    entry maps an image filename (relative to *data_path*) to a record with
    the keys ``"horizon"`` (horizon row in full-image coordinates),
    ``"horizon_range"`` (``[ymin, ymax]``) and ``"frame"``
    (``[top, bottom, left, right]`` crop bounds). Only entries whose horizon
    range is small relative to the frame height are kept.

    Parameters
    ----------
    data_path:
        Directory containing ``flickr_data.json`` and the image files.
    load_imgs:
        If true, ``__getitem__`` also loads the image and crops it to the frame.
    max_range_ratio:
        Keep an entry only when ``(ymax - ymin) / frame_height`` is strictly
        below this value. Entries with a non-positive frame height are skipped.
    """

    def __init__(self, data_path: str, load_imgs=True, max_range_ratio=0.1):
        self.data_path = Path(data_path)
        self.load_images = load_imgs
        # Context manager so the annotation file is closed promptly;
        # JSON text is UTF-8 by specification.
        with open(self.data_path / "flickr_data.json", "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.data = {
            key: record
            for key, record in self.data.items()
            if self._is_precise(record, max_range_ratio)
        }
        self.file_list = sorted(self.data)

    @staticmethod
    def _is_precise(record, max_range_ratio):
        """Return True if the horizon spread is small relative to the frame height."""
        ymin, ymax = record["horizon_range"]
        top, bottom, _, _ = record["frame"]
        height = bottom - top
        if height <= 0:
            # A zero height makes the ratio undefined. A negative height is an
            # inverted frame that would yield an empty crop. Drop both.
            return False
        return (ymax - ymin) / height < max_range_ratio

    def __len__(self):
        return len(self.file_list)

    def get_image(self, filename):
        """Load *filename* relative to ``data_path``.

        Single-channel (2-D) images are replicated into 3 channels.
        """
        image = imread(self.data_path / filename)
        if image.ndim == 2:
            image = np.dstack([image] * 3)
        # NOTE(review): images with an alpha channel keep 4 channels here;
        # confirm whether callers expect RGB only.
        return image

    def __getitem__(self, idx: int):
        """Return the annotation (and optionally the cropped image) for item *idx*.

        The returned dict contains:

        - ``filename``: image filename relative to ``data_path``.
        - ``A``, ``B``: float32 ``(x, y)`` arrays giving the left and right
          endpoints of the horizon in cropped-frame coordinates
          (``x = 0`` and ``x = frame width`` respectively).
        - ``shape``: ``(height, width)`` of the frame.
        - ``horizon_range``: the scalar spread ``ymax - ymin``.
        - ``image``: the frame crop, present only when ``load_imgs`` was true.
        """
        filename: str = self.file_list[idx]
        record = self.data[filename]
        top, bottom, left, right = record["frame"]
        width = right - left
        # Both endpoints share the same row: the horizon row shifted into
        # crop coordinates.
        y_crop = record["horizon"] - top
        ymin, ymax = record["horizon_range"]
        img_dict = dict(
            filename=filename,
            A=np.array((0, y_crop), "f"),
            B=np.array((width, y_crop), "f"),
            shape=(bottom - top, width),
            horizon_range=(ymax - ymin),
        )
        if self.load_images:
            img = self.get_image(filename)
            img_dict.update(image=img[top:bottom, left:right, :])
        return img_dict
