#include <algorithm>
#include <cmath>
#include <vector>

#include <Eigen/Dense>
#include "util/exceptions.h"
#include "edge_dir_detector.h"
#include "util/umfdebug.h"
#include "util/draw.h"

namespace umf {


// Fallback NaN test, defined only when no isnan macro is visible at this point (C++ <cmath>
// typically provides std::isnan as a function rather than a macro, so this fallback is
// usually the one in effect): NaN is the only value that compares unequal to itself.
// NOTE(review): a function-like macro is not scoped by namespace umf; once defined it also
// rewrites any later std::isnan(...) in this translation unit. Under -ffast-math the compiler
// may fold (x) != (x) to false.
#ifndef isnan
#define isnan(x) ((x) != (x))
#endif

/**
 * @brief check if the point is inside the image with size with x height
 * @param point
 * @param width
 * @param height
 * @return  if a point represents a real pixel
 */
static inline bool checkBounds(Eigen::Vector2f &point, int width, int height)
{
    return point[0] < width && point[1] < height && !isnan(point[0]) && !isnan(point[1]) && point[0] >= 0 && point[1] >= 0;
}

/**
 * @brief fill the first eight sampling points for one side of an edge
 * @param origin the field center the samples belong to
 * @param direction vector from origin towards the field on the other side of the edge
 * @param otherDirection perpendicular grid vector used for the lateral spread of the samples
 * @param samples output; elements 0..7 are overwritten, so it must already hold at least 8 elements
 *
 * Seven samples form a small fan between 0.3 and 0.45 of the way along direction, spread
 * sideways by up to 0.3 of otherDirection. The eighth lies on the center line at 0.2.
 */
static void assignSamples(Eigen::Vector2f origin,
                          Eigen::Vector2f direction,
                          Eigen::Vector2f otherDirection,
                          std::vector< Eigen::Vector2f > &samples)
{
    // one row per sample: lateral magnitude, lateral sign, distance along direction
    static const float fan[7][3] = {
        {0.15f, -1.f, 0.45f},
        {0.15f, +1.f, 0.45f},
        {0.3f,  -1.f, 0.4f},
        {0.3f,   0.f, 0.4f},
        {0.3f,  +1.f, 0.4f},
        {0.15f, -1.f, 0.3f},
        {0.15f, +1.f, 0.3f}
    };

    for(int k = 0; k < 7; ++k)
    {
        const float lateral = fan[k][0];
        const float side = fan[k][1];
        const float along = fan[k][2];
        samples[k] = otherDirection*lateral*side + origin + direction*along;
    }

    // last sample: on the center line, closest to the origin
    samples[7] = origin + direction*0.2f;
}

/**
 * @brief fill sixteen sampling points: the eight of assignSamples plus eight points around origin
 * @param origin the field center the samples belong to
 * @param direction vector from origin towards the field on the other side of the edge
 * @param otherDirection perpendicular grid vector
 * @param samples output; elements 0..15 are overwritten, so it must already hold at least 16 elements
 *
 * Samples 8..11 are the four diagonal points at +-0.1 along both vectors (one per quadrant).
 * Samples 12..15 are the four axis points at +-0.25 along each vector.
 */
static void assignCenteredSamples(Eigen::Vector2f origin,
                                  Eigen::Vector2f direction,
                                  Eigen::Vector2f otherDirection,
                                  std::vector< Eigen::Vector2f > &samples)
{

    assignSamples(origin, direction, otherDirection, samples);

    // diagonal ring: every sign combination of (otherDirection, direction) exactly once
    samples[8] = origin + otherDirection*0.1f + direction*0.1f;
    samples[9] = origin - otherDirection*0.1f + direction*0.1f;
    samples[10] = origin - otherDirection*0.1f - direction*0.1f;
    samples[11] = origin + otherDirection*0.1f - direction*0.1f; // (+other, -direction) quadrant

    // axis points
    samples[12] = origin + otherDirection*0.25f;
    samples[13] = origin + direction*0.25f;
    samples[14] = origin - otherDirection*0.25f;
    samples[15] = origin - direction*0.25f;
}

/**
 * @brief clamp x into the closed interval [a, b]
 * @param x value to clamp
 * @param a lower bound
 * @param b upper bound (expected to be >= a)
 * @return a if x < a or x is NaN, b if x > b, x otherwise
 *
 * NaN is mapped to the lower bound so that a NaN sample coordinate (which arises when a
 * neighbouring grid point is at infinity) never reaches the image accessors. There it would
 * presumably be converted to an integer pixel index, which is undefined behaviour.
 */
static inline float clamp(float x, float a, float b)
{
    if(!(x >= a)) // true both for x < a and for NaN
    {
        return a;
    }
    return x > b ? b : x;
}

/**
 * @brief generate sampling points on both sides of the edge between p1 and p2
 * @tparam NCHAN the number of channels for edge detection - no role here
 * @param p1 first point
 * @param p2 second point
 * @param otherDirection the direction in the detected grid perpendicular to the direction between p1, p2
 * @param samples1 sampling points near p1, placed towards p2; resized by this call
 * @param samples2 sampling points near p2, placed towards p1; resized by this call
 * @throws UMFException if samples1.size() exceeds both EDGE_DIR_SAMPLE_COUNT_SIMPLE and EDGE_DIR_SAMPLE_COUNT_CENTERED
 *
 * The current size of samples1 selects the pattern: a size up to EDGE_DIR_SAMPLE_COUNT_SIMPLE
 * selects the simple pattern, otherwise a size up to EDGE_DIR_SAMPLE_COUNT_CENTERED selects the
 * centered pattern. Both vectors are then resized to the chosen count.
 */
template <int NCHAN>
void EdgeDirDetector<NCHAN>::getSamplingPoints(Eigen::Vector2f p1,
                                               Eigen::Vector2f p2,
                                               Eigen::Vector2f otherDirection,
                                               std::vector<Eigen::Vector2f> &samples1,
                                               std::vector<Eigen::Vector2f> &samples2)
{
    const bool useSimple = samples1.size() <= EDGE_DIR_SAMPLE_COUNT_SIMPLE;
    const bool useCentered = !useSimple && samples1.size() <= EDGE_DIR_SAMPLE_COUNT_CENTERED;

    if(!useSimple && !useCentered)
    {
        throw UMFException();
    }

    const Eigen::Vector2f towardsP2 = p2 - p1;
    const Eigen::Vector2f towardsP1 = p1 - p2;

    if(useSimple)
    {
        samples1.resize(EDGE_DIR_SAMPLE_COUNT_SIMPLE);
        samples2.resize(EDGE_DIR_SAMPLE_COUNT_SIMPLE);

        assignSamples(p1, towardsP2, otherDirection, samples1);
        assignSamples(p2, towardsP1, otherDirection, samples2);
        return;
    }

    samples1.resize(EDGE_DIR_SAMPLE_COUNT_CENTERED);
    samples2.resize(EDGE_DIR_SAMPLE_COUNT_CENTERED);

    assignCenteredSamples(p1, towardsP2, otherDirection, samples1);
    assignCenteredSamples(p2, towardsP1, otherDirection, samples2);
}


/**
 * @brief construct a detector with default thresholds
 *
 * Every channel gets a field difference threshold of 1. The simple sampling pattern
 * (EDGE_DIR_SAMPLE_COUNT_SIMPLE / EDGE_DIR_SCORE_DECIDER_SIMPLE) is selected; the centered
 * variant is kept commented out.
 */
template <int NCHAN>
EdgeDirDetector<NCHAN>::EdgeDirDetector()
{
    // One default shared by all channels. A fixed-size per-channel table
    // would be read out of bounds for NCHAN > 3.
    const int defaultFieldDiffThreshold = 1;

    for(int i = 0; i < NCHAN; i++)
    {
        this->fieldDiffThreshold[i] = defaultFieldDiffThreshold;
    }

    //this->sampleCount = EDGE_DIR_SAMPLE_COUNT_CENTERED;
    //this->scoreDecider = EDGE_DIR_SCORE_DECIDER_CENTERED;
    this->sampleCount = EDGE_DIR_SAMPLE_COUNT_SIMPLE;
    this->scoreDecider = EDGE_DIR_SCORE_DECIDER_SIMPLE;
}

/**
 * @brief per-channel signed vote comparing the image at samples1 against samples2
 * @param img image to read from; sample coordinates are clamped to [0, width-1] x [0, height-1]
 * @param samples1 first sampling points (at least sampleCount entries)
 * @param samples2 second sampling points (at least sampleCount entries)
 * @return for each channel, the number of the first sampleCount pairs where the first value exceeds
 *         the second by more than fieldDiffThreshold, minus the number where the second exceeds
 *         the first by more than fieldDiffThreshold
 */
template<int NCHAN> template<class T>
Eigen::Matrix<int, NCHAN, 1> EdgeDirDetector<NCHAN>::getScore(Image<T, NCHAN> *img, std::vector<Eigen::Vector2f> &samples1, std::vector<Eigen::Vector2f> &samples2)
{
#ifdef UMF_DEBUG_COUNT_PIXELS
    UMFDebug *dbg = UMFDSingleton::Instance();
#endif

    // last valid pixel coordinate in each axis
    const float maxX = img->width - 1;
    const float maxY = img->height - 1;

    Eigen::Matrix<T, NCHAN, 1> first;
    Eigen::Matrix<T, NCHAN, 1> second;
    Eigen::Matrix<int, NCHAN, 1> votes = Eigen::Matrix<int, NCHAN, 1>::Zero();

    for(unsigned int s = 0; s < this->sampleCount; ++s)
    {
        const Eigen::Vector2f &a = samples1[s];
        const Eigen::Vector2f &b = samples2[s];

        img->get2Der(first, clamp(a[0], 0, maxX), clamp(a[1], 0, maxY));
        img->get2Der(second, clamp(b[0], 0, maxX), clamp(b[1], 0, maxY));
#ifdef UMF_DEBUG_COUNT_PIXELS
        dbg->addPixels(2);
#endif
        const Eigen::Matrix<int, NCHAN, 1> delta = first.template cast<int>() - second.template cast<int>();
        votes += (delta.array() > this->fieldDiffThreshold).template cast<int>().matrix();
        votes -= (-delta.array() > this->fieldDiffThreshold).template cast<int>().matrix();
    }

    return votes;
}

/**
 * @brief weighted 3x3 neighbourhood sum around p using the binomial kernel
 *        [1 2 1; 2 4 2; 1 2 1] (weights sum to 16)
 * @param img image to read from; every tap coordinate is clamped into the image
 * @param p center of the neighbourhood
 * @return per-channel weighted sum (not normalized)
 */
template<class T, int NCHAN>
static Eigen::Matrix<int, NCHAN, 1> binomialSum3x3(Image<T, NCHAN> *img, const Eigen::Vector2f &p)
{
    static const int kernel[3][3] = {
        {1, 2, 1},
        {2, 4, 2},
        {1, 2, 1}
    };

    Eigen::Matrix<int, NCHAN, 1> sum;
    sum.setZero();
    for(int dy = -1; dy <= 1; dy++)
    {
        const float y = clamp(p[1] + dy, 0, img->height-1);
        for(int dx = -1; dx <= 1; dx++)
        {
            const float x = clamp(p[0] + dx, 0, img->width-1);
            sum += kernel[dy + 1][dx + 1]*img->get2De(x, y).template cast<int>();
        }
    }
    return sum;
}

/**
 * @brief like getScore, but each sample value is a 3x3 binomial-smoothed neighbourhood
 * @param img image to read from
 * @param samples1 first sampling points (at least sampleCount entries)
 * @param samples2 second sampling points (at least sampleCount entries)
 * @return per-channel signed vote: +1 for each pair whose smoothed difference exceeds
 *         fieldDiffThreshold, -1 for each pair whose negated difference exceeds it
 */
template<int NCHAN> template<class T>
Eigen::Matrix<int, NCHAN, 1> EdgeDirDetector<NCHAN>::getScoreGauss(Image<T, NCHAN> *img, std::vector<Eigen::Vector2f> &samples1, std::vector<Eigen::Vector2f> &samples2)
{
    Eigen::Matrix<int, NCHAN, 1> score;
    score.setZero();

#ifdef UMF_DEBUG_COUNT_PIXELS
    UMFDebug *dbg = UMFDSingleton::Instance();
#endif

    for(unsigned int i = 0; i < this->sampleCount; i++)
    {
        Eigen::Matrix<int,NCHAN,1> sampleValue1 = binomialSum3x3(img, samples1[i]);
        Eigen::Matrix<int,NCHAN,1> sampleValue2 = binomialSum3x3(img, samples2[i]);
#ifdef UMF_DEBUG_COUNT_PIXELS
        dbg->addPixels(18); // 2 sample points x 9 kernel taps each
#endif
        // divide by the kernel weight sum (16) to return to the image value range
        Eigen::Matrix<int,NCHAN,1> diff = (sampleValue1 - sampleValue2)/16;
        score += (diff.array() > this->fieldDiffThreshold).template cast<int>().matrix();
        score -= (-diff.array() > this->fieldDiffThreshold).template cast<int>().matrix();
    }

    return score;
}


/**
 * @brief Extract edge directions for matching with database
 * @tparam NCHAN the number of channels used to extract edge direction
 * @tparam T the image type
 * @param img The input image with NCHAN channels
 * @param pencil1 The first pencil of lines (iterated as the rows of the grid)
 * @param pencil2 the second pencil of lines (iterated as the columns); may be reversed in place
 * @param mask optional mask forwarded to the debug drawing
 * @param show whether to show debug output (only effective with UMF_DEBUG)
 *
 * This function does the following
 *  -# check if the pencils are correctly aligned (one of them is not reversed somehow) and optionally reverse one direction
 *  -# calculate intersection points for field centers (see red dots in the image)
 *  -# for each field center for each neighbour downwards or to the right:
 *      -# sample the area between fields and compare the points
 *      -# for each channel based on the comparison decide the edge direction
 *      -# the edge directions are shown for each channel in the image below
 *      .
 *  -# store these values and optionally show debug output
 *
 * All edge directions are reset to EDGE_DIRECTION_INVALID on every call. An edge stays invalid
 * in any of these cases:
 *  - one of its end points is outside the image;
 *  - its perpendicular direction is not finite;
 *  - the grid has a single row (horizontal edges) or a single column (vertical edges), so there
 *    is no neighbour to orient the samples.
 * An empty pencil yields an empty grid (rows = cols = 0).
 *
 * \image html 6_edgedir.png
 */
template<int NCHAN> template<class T>
void EdgeDirDetector<NCHAN>::extract(Image<T, NCHAN> *img, std::vector<Eigen::Vector3f> &pencil1, std::vector<Eigen::Vector3f> &pencil2, ImageGray *mask, bool show)
{
    if(pencil1.empty() || pencil2.empty())
    {
        //no grid: publish an empty result instead of calling front()/back() on an empty vector
        this->rows = 0;
        this->cols = 0;
        this->extractionPoints.clear();
        this->edgeDirections.clear();
        return;
    }

    //first check if the pencils are aligned in the good way
    //alignment should be:
    // 1 2..
    //1
    //2
    //... and not:
    // n-1 n-2
    //1
    //2

    //to get this calculate the three intersections firstxfirst; lastxfirst; firstxlast
    Eigen::Vector3f intff = pencil1.front().cross(pencil2.front());
    Eigen::Vector3f intlf = pencil1.back().cross(pencil2.front());
    Eigen::Vector3f intfl = pencil1.front().cross(pencil2.back());

    //homogeneous -> Euclidean coordinates
    intff /= intff[2];
    intlf /= intlf[2];
    intfl /= intfl[2];

    //Get the vectors connecting these points
    Eigen::Vector3f p1line = intfl - intff;
    Eigen::Vector3f p2line = intlf - intff;

    if(p1line.cross(p2line)(2) < 0) //bad alignment we want to loop through p1 as rows and p2 as columns
    {
        std::reverse(pencil2.begin(), pencil2.end());
    }

    this->rows = pencil1.size();
    this->cols = pencil2.size();

    //the field centers are the intersections of every row line with every column line
    this->extractionPoints.resize(this->rows*this->cols);
    for(unsigned int row = 0; row < this->rows; row++)
    {
        for(unsigned int col = 0; col < this->cols; col++)
        {
            Eigen::Vector3f intersection = pencil1[row].cross(pencil2[col]);
            intersection /= intersection[2];
            this->extractionPoints[row*this->cols + col] = Eigen::Vector2f(intersection[0], intersection[1]);
        }
    }

    //edgeDirections holds rows*cols horizontal edges followed by rows*cols vertical edges
    const int verticalOffset = this->rows*this->cols;
    typename Marker<NCHAN>::DirectionType tmp; tmp.setOnes(); tmp *= EDGE_DIRECTION_INVALID;
    //assign() rather than resize(): resize() only initializes newly appended elements, so edges
    //skipped below would otherwise keep the directions computed by a previous call
    this->edgeDirections.assign(2*verticalOffset, tmp);

    std::vector<Eigen::Vector2f> samples1(this->sampleCount);
    std::vector<Eigen::Vector2f> samples2(this->sampleCount);

    //a NaN/infinite perpendicular direction would turn every sample coordinate into NaN/inf
    auto isFinite = [](const Eigen::Vector2f &v) -> bool
    {
        return std::isfinite(v[0]) && std::isfinite(v[1]);
    };

    //sample between `from` and `to` (spread along `across`) and turn the per-channel score into a direction
    auto classifyEdge = [&](const Eigen::Vector2f &from, const Eigen::Vector2f &to, const Eigen::Vector2f &across) -> typename Marker<NCHAN>::DirectionType
    {
        this->getSamplingPoints(from, to, across, samples1, samples2);

        Eigen::Matrix<int, NCHAN, 1> score = this->getScore(img, samples1, samples2);

        typename Marker<NCHAN>::DirectionType result; result.setZero();

        result += ((score.array() > this->scoreDecider).template cast<EdgeType>().matrix())*EDGE_DIRECTION_LEFTUP;
        result += ((score.array() < -this->scoreDecider).template cast<EdgeType>().matrix())*EDGE_DIRECTION_RIGHTDOWN;
        return result;
    };

    for(unsigned int rowI = 0; rowI < this->rows; rowI++)
    {
        for(unsigned int colI = 0; colI < this->cols; colI++)
        {
            int pindex = rowI*this->cols + colI;
            Eigen::Vector2f current = this->extractionPoints[pindex];
            if(!checkBounds(current, img->width, img->height))
            {
                continue;
            }

            //HORIZONTAL EDGE - the last column has no right neighbour and a single row
            //has no vertical neighbour to orient the samples
            if(colI != this->cols - 1 && this->rows > 1)
            {
                Eigen::Vector2f right = this->extractionPoints[pindex+1];
                if(checkBounds(right, img->width, img->height))
                {
                    //the last row has no row below, so use the row above
                    Eigen::Vector2f verticalDirection;
                    if(rowI == this->rows - 1)
                    {
                        verticalDirection = this->extractionPoints[pindex - this->cols] - current;
                    } else {
                        verticalDirection = this->extractionPoints[pindex + this->cols] - current;
                    }

                    if(isFinite(verticalDirection))
                    {
                        this->edgeDirections[pindex] = classifyEdge(current, right, verticalDirection);
                    }
                }
            }

            //VERTICAL EDGE - the last row has no lower neighbour and a single column
            //has no horizontal neighbour to orient the samples
            if(rowI != this->rows - 1 && this->cols > 1)
            {
                Eigen::Vector2f bottom = this->extractionPoints[pindex + this->cols];
                if(checkBounds(bottom, img->width, img->height))
                {
                    //the last column has no column to the right, so use the one to the left
                    Eigen::Vector2f horizontalDirection;
                    if(colI == this->cols - 1)
                    {
                        horizontalDirection = this->extractionPoints[pindex-1] - current;
                    } else {
                        horizontalDirection = this->extractionPoints[pindex+1] - current;
                    }

                    if(isFinite(horizontalDirection))
                    {
                        this->edgeDirections[pindex + verticalOffset] = classifyEdge(current, bottom, horizontalDirection);
                    }
                }
            }
        }
    }

#ifdef UMF_DEBUG
    bool showFieldCenters = false;
    bool showEdgeDirections = true;
    if(show)
    {
        //generate pencils going through the corners
        if(showFieldCenters) this->showFieldCenters(mask);
        if(showEdgeDirections) this->showEdgeDirections(mask);
    }
#endif

}

/**
 * @brief Show the extraction field centers relative to which the sampling points are generated
 * @param mask optional mask (may be NULL). A center that lies inside the mask image on a pixel
 *        other than 255 is not drawn; centers outside the mask image are still drawn.
 *
 * Nothing is drawn when the debug singleton provides no image. Centers with a NaN or infinite
 * coordinate (e.g. intersections of parallel lines) are skipped, because converting them to
 * integer pixel positions would be undefined behaviour.
 */
template<int NCHAN>
void EdgeDirDetector<NCHAN>::showFieldCenters(ImageGray *mask)
{
    UMFDebug *dbg = UMFDSingleton::Instance();
    ImageRGB *imgDbg = dbg->getImage();

    if(imgDbg == nullptr)
    {
        return;
    }

    Eigen::Vector3i color(168, 25, 25);
    int lineWidth = 10;

    for(unsigned int row = 0; row < this->rows; row++)
    {
        for(unsigned int col = 0; col < this->cols; col++)
        {
            Eigen::Vector2f cpos = this->extractionPoints[row*this->cols + col];

            //float -> int conversion of NaN/inf is undefined behaviour
            if(!std::isfinite(cpos[0]) || !std::isfinite(cpos[1]))
            {
                continue;
            }

            if(mask != nullptr && checkBounds(cpos, mask->width, mask->height) && *mask->get2D(cpos[0], cpos[1]) != 255)
            {
                continue;
            }

            //Eigen::Vector3i color(255 - (row*1.0/this->rows + col*1.0/this->cols)*127, row*255.0/this->rows, col*255.0/this->cols);
            drawCircle(imgDbg, cpos.template cast<int>(), lineWidth, color, -1);
        }
    }
}

/**
 * @brief Show the edge directions for each channel
 * @param mask optional mask (may be NULL). When given, an edge is drawn only if both of its end
 *        points lie inside the mask image on pixels equal to 255.
 *
 * Each channel gets its own hue and is drawn shifted perpendicular to the edge.
 * - EDGE_DIRECTION_LEFTUP: drawArrow is called from the right/lower field to the left/upper one.
 * - EDGE_DIRECTION_RIGHTDOWN: drawArrow is called the other way round.
 * - Any other value: drawn with drawEquals.
 *
 * Edges are skipped when any point involved is NaN/infinite. Horizontal edges of a single-row
 * grid and vertical edges of a single-column grid are skipped too: there is no neighbour to
 * derive the perpendicular offset from.
 */
template<int NCHAN>
void EdgeDirDetector<NCHAN>::showEdgeDirections(ImageGray *mask)
{
    UMFDebug *dbg = UMFDSingleton::Instance();
    ImageRGB *imgDbg = dbg->getImage();

    //with an empty grid the "rows - 1" / "cols - 1" loop limits below would wrap around
    if(imgDbg == nullptr || this->rows == 0 || this->cols == 0)
    {
        return;
    }

    std::vector<Eigen::Vector3i> colors(NCHAN);

    for(int i = 0; i < NCHAN; i++)
    {
        Eigen::Vector3f hsv(i*360/NCHAN, 100, 100);
        Eigen::Vector3f rgb = hsv2rgb(hsv);
        colors[i] = Eigen::Vector3i(rgb[0]*255/100, rgb[1]*255/100, rgb[2]*255/100);
    }

    int lineWidth = 3;

    //float -> int conversion of NaN/inf (done when drawing) is undefined behaviour
    auto isFinite = [](const Eigen::Vector2f &v) -> bool
    {
        return std::isfinite(v[0]) && std::isfinite(v[1]);
    };

    //without a mask every edge passes; with a mask both end points must be inside it on a 255 pixel
    auto passesMask = [mask](Eigen::Vector2f a, Eigen::Vector2f b) -> bool
    {
        if(mask == nullptr)
        {
            return true;
        }
        if(!checkBounds(a, mask->width, mask->height) || !checkBounds(b, mask->width, mask->height))
        {
            return false;
        }
        return *mask->get2D(a[0], a[1]) == 255 && *mask->get2D(b[0], b[1]) == 255;
    };

    //draw one edge: `from` is the left/upper field, `to` the right/lower one and `across`
    //the vector to the neighbouring row/column used to separate the channels
    auto drawEdge = [&](Eigen::Vector2f from, Eigen::Vector2f to, Eigen::Vector2f across,
                        const typename Marker<NCHAN>::DirectionType &result)
    {
        across /= 2*NCHAN;

        //shorten the drawn edge by 30% at each end
        Eigen::Vector2f shorten = (to - from)*0.3;
        from += shorten;
        to -= shorten;

        for(int i = 0; i < NCHAN; i++)
        {
            Eigen::Vector2f offset = (i - NCHAN/2)*across;

            switch(result[i])
            {
            case EDGE_DIRECTION_LEFTUP:
                drawArrow(imgDbg, (to + offset).template cast<int>(), (from + offset).template cast<int>(), colors[i], lineWidth);
                break;
            case EDGE_DIRECTION_RIGHTDOWN:
                drawArrow(imgDbg, (from + offset).template cast<int>(), (to + offset).template cast<int>(), colors[i], lineWidth);
                break;
            default:
                drawEquals(imgDbg, (from + offset).template cast<int>(), (to + offset).template cast<int>(), colors[i], lineWidth);
            }
        }
    };

    //horizontal (ignore last column); the perpendicular offset needs a second row
    if(this->rows > 1)
    {
        for(unsigned int row = 0; row < this->rows; row++)
        {
            for(unsigned int col = 0; col < this->cols - 1; col++)
            {
                int pindex = row*this->cols + col;
                Eigen::Vector2f current = this->extractionPoints[pindex];
                Eigen::Vector2f right = this->extractionPoints[pindex+1];

                if(!passesMask(current, right))
                {
                    continue;
                }

                //the last row has no row below, so use the row above
                Eigen::Vector2f neighbour = (row == this->rows - 1)
                        ? this->extractionPoints[pindex - this->cols]
                        : this->extractionPoints[pindex + this->cols];

                if(!isFinite(current) || !isFinite(right) || !isFinite(neighbour))
                {
                    continue;
                }

                drawEdge(current, right, neighbour - current, this->edgeDirections[pindex]);
            }
        }
    }

    int verticalOffset = this->rows*this->cols; //size of the horizontal stuff

    //vertical (ignore last row); the perpendicular offset needs a second column
    if(this->cols > 1)
    {
        for(unsigned int row = 0; row < this->rows - 1; row++)
        {
            for(unsigned int col = 0; col < this->cols; col++)
            {
                int pindex = row*this->cols + col;
                Eigen::Vector2f current = this->extractionPoints[pindex];
                Eigen::Vector2f down = this->extractionPoints[pindex+this->cols];

                if(!passesMask(current, down))
                {
                    continue;
                }

                //the last column has no column to the right, so use the one to the left
                Eigen::Vector2f neighbour = (col == this->cols - 1)
                        ? this->extractionPoints[pindex-1]
                        : this->extractionPoints[pindex+1];

                if(!isFinite(current) || !isFinite(down) || !isFinite(neighbour))
                {
                    continue;
                }

                drawEdge(current, down, neighbour - current, this->edgeDirections[pindex + verticalOffset]);
            }
        }
    }

}


// Explicit instantiations for 1-channel (grayscale) and 3-channel (RGB) detectors.
// Member templates called by extract() (getSamplingPoints, getScore and, with UMF_DEBUG,
// the show* helpers) are instantiated implicitly through it. getScoreGauss is not used by
// extract() and is therefore not instantiated by these lines.
template EdgeDirDetector<1>::EdgeDirDetector(); //for grayscale
template EdgeDirDetector<3>::EdgeDirDetector(); //for RGB

template void EdgeDirDetector<1>::extract(ImageGray *img, std::vector<Eigen::Vector3f> &pencil1, std::vector<Eigen::Vector3f> &pencil2, ImageGray* mask, bool show);
template void EdgeDirDetector<3>::extract(ImageRGB *img, std::vector<Eigen::Vector3f> &pencil1, std::vector<Eigen::Vector3f> &pencil2, ImageGray* mask, bool show);

}
