/*
 * $Id: lbpdetector.cpp 176 2015-07-01 12:38:52Z ijuranek $
 *
 * API for object detection in images. This is common to both MATLAB and C++
 * versions of detector.
 *
 * LBP Detector Toolbox
 * Roman Juranek <ijuranek@fit.vutbr.cz>
 * Faculty of Information Technology, BUT, Brno
 *
 * TODO:
 * * Optimize ftrs
 * * Mask or ROI
 * * #stages to eval
 * * Eval more detectors on the same image
 */

#include <cassert>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <map>
#include <memory>
#include <tuple>
#include <vector>
#include "common.h"
#include "lbpdetector.h"
#ifdef MATLAB_MEX_FILE
    #include "mex.h"
#endif
using namespace std;

namespace LBPDetector {

///////////////////////////////////////////////////////////////////////////////    
/// Caching subsystem
    
/// Abstract cache of LBP feature codes used while scanning one image.
/// Implementations map (position0, position1, plane) to a linear entry index
/// (getEntry) and store one code per entry; a negative getValue() result means
/// "not cached yet".
class FeatureCache
{
protected:
    long hits {0};    // getValue() calls that found a stored code
    long misses {0};  // getValue() calls that found nothing
public:
    /// Factory: returns a DenseCache when `active`, otherwise a DummyCache that
    /// never stores anything. The caller owns the returned object.
    static FeatureCache * createCache(int x0, int x1, int sz0, int sz1, int cs, bool active = false);
    virtual ~FeatureCache() = default;
    /// Stored code at entry `ind`, or a negative value when nothing is cached.
    virtual int getValue(int ind) = 0;
    /// Store code `val` at entry `ind`.
    virtual void update(int ind, int val) = 0;
    /// Linear entry index of position (f0, f1) in cache plane `c`.
    virtual int getEntry(int f0, int f1, int c) const = 0;
    long getHits() const { return hits; }
    long getMisses() const { return misses; }
    /// Fraction of lookups that were hits; 0 when no lookup has happened yet
    /// (previously 0/0 produced NaN, e.g. for an empty scan range).
    float getEffeciency() const
    {
        const long total = hits + misses;
        return (total > 0) ? hits / float(total) : 0.0f;
    }
};

class DummyCache: public FeatureCache
{
public:
    int getValue(int ind) { misses++; return -1; }
    void update(int ind, int val) {}
    int getEntry(int f0, int f1, int c) const { return 0; }
};

class DenseCache: public FeatureCache
{
    int ofs0, ofs1;
    int size;
    int * data;
    int strides[3];
public:
    DenseCache(int x0, int x1, int sz0, int sz1, int cs)
    :ofs0(x0), ofs1(x1), size(sz0 * sz1 * cs)
    {
        data = new int[size];
        fill(data, data+size, -1);
        strides[0] = 1;
        strides[1] = sz0;
        strides[2] = sz0*sz1;
    }
    ~DenseCache()
    {
        delete [] data;
    }
    int getValue(int ind)
    {
        (data[ind]<0) ? misses++ : hits++;
        return data[ind];
    }
    void update(int ind, int val)
    {
        data[ind] = val;
    }
    int getEntry(int f0, int f1, int c) const
    {
        f0 -= ofs0; f1 -= ofs1; 
        return f0 * strides[0] + f1 * strides[1] + c * strides[2];
    }
};

/// Build the feature cache for one image scan.
/// With `active`, returns a DenseCache whose region starts at (x0, x1), has
/// extent sz0 x sz1 and holds cs planes. Otherwise returns a DummyCache that
/// stores nothing. Ownership passes to the caller, who must delete it.
FeatureCache * FeatureCache::createCache(int x0, int x1, int sz0, int sz1, int cs, bool active)
{
    if (!active)
        return new DummyCache();
    return new DenseCache(x0, x1, sz0, sz1, cs);
}

///////////////////////////////////////////////////////////////////////////////
// Feature evaluation

// Weight added to an LBP code for each of the 8 neighbours that compares
// greater than the centre. Neighbour k contributes 2^k, so codes lie in [0, 255].
// Kept integral to avoid an int->float->int round trip per bit in the hot path.
constexpr int weights[8] = {1, 2, 4, 8, 16, 32, 64, 128};
// Alternative sparse weighting (kept for reference):
//const float weights[8] =    {0, 1, 0, 2,  4,  0,  8,   0};

/// Sum of a w-by-h block of pixels whose first element is `im`.
/// Elements within one block row are read contiguously (im[0..w-1]). `im_w`
/// is the element distance between the starts of consecutive rows.
/// Returns 0 when w or h is not positive.
template <class Tpx>
static Tpx sum_block1(const Tpx * im, const int im_w, int w, int h)
{
    Tpx total = 0;
    int row = 0;
    while (row < h)
    {
        for (int col = 0; col < w; ++col)
            total += im[col];
        im += im_w;
        ++row;
    }
    return total;
}

// Block LBP: compares the sums of the eight f2-by-f3 blocks surrounding a
// central block with the central block's sum.
//   f0, f1 - position of the top-left block along image dims 0 and 1
//   f2, f3 - block extent along dims 0 and 1
//   f4     - channel (offset along dim 2)
// weights[k] is added to the returned code for every neighbour block k whose
// sum is greater than the centre's.
// NOTE(review): sum_block1 reads dim 0 contiguously, so this assumes
// im.strides[0] == 1.
int eval_lbp1(const Image & im, int f0, int f1, int f2, int f3, int f4)
{
    const float * const origin = im.data + im.strides[0]*f0 + im.strides[1]*f1 + im.strides[2]*f4;
    const int step0 = im.strides[0] * f2;
    const int step1 = im.strides[1] * f3;
    const int rowStride = im.strides[1];

    // (dim-0, dim-1) block coordinates of each neighbour, in weight order.
    static const int nb[8][2] = {{0,0},{0,1},{0,2},{1,0},{1,2},{2,0},{2,1},{2,2}};

    const float centre = sum_block1<float>(origin + step0 + step1, rowStride, f2, f3);
    int code = 0;
    for (int k = 0; k < 8; ++k)
    {
        const float * const block = origin + nb[k][0]*step0 + nb[k][1]*step1;
        if (sum_block1<float>(block, rowStride, f2, f3) > centre)
            code += weights[k];
    }
    return code;
}

// Single-pixel LBP over the 3x3 pixel neighbourhood whose top-left pixel is
// (f0, f1) in channel f4. The block-size arguments are ignored (implicitly 1),
// so the result equals eval_lbp2(im, f0, f1, 1, 1, f4).
// weights[k] is added for every neighbour k greater than the centre pixel.
int eval_lbp1_1x1(const Image & im, int f0, int f1, int, int, int f4)
{
    const float * const base = im.data + im.strides[0]*f0 + im.strides[1]*f1 + im.strides[2]*f4;
    // Step by the image's real strides. Hard-coding 1 for dim 0 was only
    // correct when dim 0 is contiguous, and disagreed with eval_lbp2.
    const int d0_ofs = im.strides[0];
    const int d1_ofs = im.strides[1];

    // (dim-0, dim-1) pixel coordinates of each neighbour, in weight order.
    static const int nb[8][2] = {{0,0},{0,1},{0,2},{1,0},{1,2},{2,0},{2,1},{2,2}};

    const float c = base[d0_ofs + d1_ofs];
    int lbp = 0;
    for (int k = 0; k < 8; ++k)
        if (base[nb[k][0]*d0_ofs + nb[k][1]*d1_ofs] > c)
            lbp += weights[k];
    return lbp;
}


// Sparse single-pixel LBP in channel f4. Pixels are sampled on a 3x3 grid
// anchored at (f0, f1), spaced f2 apart along dim 0 and f3 apart along dim 1.
// The eight outer samples are compared with the centre sample, and weights[k]
// is added to the returned code for every sample k greater than the centre.
int eval_lbp2(const Image & im, int f0, int f1, int f2, int f3, int f4)
{
    const float * const origin = im.data + im.strides[0]*f0 + im.strides[1]*f1 + im.strides[2]*f4;
    const int step0 = im.strides[0] * f2;
    const int step1 = im.strides[1] * f3;

    // (dim-0, dim-1) grid coordinates of each neighbour, in weight order.
    static const int grid[8][2] = {{0,0},{0,1},{0,2},{1,0},{1,2},{2,0},{2,1},{2,2}};

    const float centre = origin[step0 + step1];
    int code = 0;
    for (int k = 0; k < 8; ++k)
        if (origin[grid[k][0]*step0 + grid[k][1]*step1] > centre)
            code += weights[k];
    return code;
}

///////////////////////////////////////////////////////////////////////////////
// Detector evaluation

/// Outcome of evaluating the detector on a single window position.
struct ScalarResponse
{
    bool valid = false;   // true when the window passed every cascade stage
    float h = 0.0f;       // accumulated classifier response
    unsigned nf = 0;      // number of features (weak classifiers) evaluated
};

// Table strides of a Detector. Each weak classifier owns szAlpha response
// values in D.alpha, indexed by its LBP code, and szFtr bytes in D.ftr
// (offset along dim 0, offset along dim 1, block size 0, block size 1, channel).
static constexpr size_t szAlpha = 256;
static constexpr size_t szFtr = 5;

/// Evaluate the soft-cascade detector D on the window anchored at (x, y).
///
/// CacheType is the concrete FeatureCache subclass behind `cache`. It may be
/// given either as the class or as a reference to it, e.g. DenseCache&.
/// Calling through the concrete type lets the compiler devirtualise the cache
/// calls when that class is final.
/// fids[t] is the cache offset of feature t relative to the window origin
/// (see getFids).
///
/// Returns the accumulated response h and the number of features evaluated.
/// `valid` is true only if h never dropped below D.cascThr.
template <class CacheType>
static ScalarResponse evalLBPScalarDetector(const Image & im, int x, int y, const Detector & D, const vector<int> & fids, FeatureCache & cache)
{
    // Named cast instead of a C-style cast. The caller guarantees that the
    // dynamic type of `cache` matches CacheType.
    CacheType & C = static_cast<CacheType &>(cache);
    ScalarResponse R;
    const unsigned char * f_t = D.ftr;   // szFtr bytes per feature: ofs0, ofs1, size0, size1, channel
    const float * a_t = D.alpha;         // szAlpha responses per feature, indexed by LBP code
    const int base_ind = C.getEntry(x, y, 0);
    for (int t = 0; t < D.T; ++t, f_t += szFtr, a_t += szAlpha)
    {
        const int ind = base_ind + fids[t];
        int lbp = C.getValue(ind);
        if (lbp < 0)  // not cached yet: compute and remember
        {
            lbp = D.FtrEvaluator[t](im, x + f_t[0], y + f_t[1], f_t[2], f_t[3], f_t[4]);
            C.update(ind, lbp);
        }
        R.h += a_t[lbp];
        if (R.h < D.cascThr)  // early rejection by the soft cascade
        {
            R.nf = t + 1;
            R.valid = false;
            return R;
        }
    }
    R.nf = D.T;
    R.valid = true;
    return R;
}


///////////////////////////////////////////////////////////////////////////////
// Image scanning

/// Cache entry of every feature of D for a window anchored at (x0, y0).
/// evalLBPScalarDetector adds these offsets to getEntry(x, y, 0) of the
/// current window to locate each feature's cache slot. Feature t uses
/// cache plane D.CachePlane[t].
static vector<int> getFids(int x0, int y0, const FeatureCache & cache, const Detector & D)
{
    vector<int> offsets;
    offsets.reserve(D.T);
    for (int t = 0; t < D.T; ++t)
    {
        const unsigned char * const ftr = D.ftr + t * szFtr;
        offsets.push_back(cache.getEntry(x0 + ftr[0], y0 + ftr[1], D.CachePlane[t]));
    }
    return offsets;
}

ScalarResult scalarScanImage(const Image & im, const int x0, const int y0, const int x1, const int y1, const int stride, const Detector & D, const float scale, bool useCache)
{
    ScalarResult res;
    // Construct cache
    FeatureCache * cache = FeatureCache::createCache(x0, y0, (x1-x0)+D.wdim0, (y1-y0)+D.wdim1, D.numCachePlanes, useCache);
    // Configure evaluation function suitable for the cache
    typedef ScalarResponse (*DetectorEvalFnc)(const Image&, int, int, const Detector&, const vector<int>&, FeatureCache&);
    DetectorEvalFnc eval;
    if (useCache) eval = evalLBPScalarDetector<DenseCache&>;
    else eval = evalLBPScalarDetector<DummyCache&>;

    vector<int> fids = getFids(x0, y0, *cache, D);

    for (int x = x0; x < x1; x+=stride)
    {
        for (int y = y0; y < y1; y+=stride)
        {
            ScalarResponse R = eval(im, x, y, D, fids, *cache);
            if (R.valid)
            {
                res.bbs.push_back(BB(x, y, D.wdim0, D.wdim1, scale));
                res.hs.push_back(R.h);
            }
            res.nf += R.nf;
        }
    }
    res.ns = ((x1 - x0)/stride) * ((y1 - y0)/stride);
    res.cacheEffeciency = cache->getEffeciency();
    delete cache;
    return res;
}


void scalarScanImage(const Image & im, const int x0, const int y0, const int x1, const int y1, const int stride, const Detector & D, const float scale, bool useCache, Image & result)
{
    // TODO: Check im and result size
    // Construct cache
    FeatureCache * cache = FeatureCache::createCache(x0, y0, (x1-x0)+D.wdim0, (y1-y0)+D.wdim1, D.numCachePlanes, useCache);
    // Configure evaluation function suitable for the cache
    typedef ScalarResponse (*DetectorEvalFnc)(const Image&, int, int, const Detector&, const vector<int>&, FeatureCache&);
    DetectorEvalFnc eval;
    if (useCache) eval = evalLBPScalarDetector<DenseCache&>;
    else eval = evalLBPScalarDetector<DummyCache&>;

    vector<int> fids = getFids(x0, y0, *cache, D);

    for (int x = x0; x < x1; x+=stride)
    {
        for (int y = y0; y < y1; y+=stride)
        {
            ScalarResponse R = eval(im, x, y, D, fids, *cache);
            float * pResult = result.data + result.strides[0]*x + result.strides[1]*y;
            *pResult = (R.valid) ? R.h : -INFINITY;
        }
    }
    delete cache;
}

///////////////////////////////////////////////////////////////////////////////
// Detector management

/// Free the detector pointed to by *D and reset *D to nullptr, so the caller
/// keeps no dangling pointer and a repeated call is a no-op.
/// Does nothing when D or *D is null.
/// The ftr/alpha tables are deleted only when d->dynamic is set. createDetector
/// leaves it false because those tables are caller-supplied.
void destroyDetector(Detector ** D)
{
    if (!D || !*D) return;
    Detector * d = *D;
    if (d->dynamic)
    {
        delete [] d->ftr;
        delete [] d->alpha;
    }
    delete d;
    *D = nullptr;
}

/// Prepare the runtime tables of D:
///   - choose a feature evaluator for each of its D.T features;
///   - give every distinct (width, height, channel) combination its own
///     feature-cache plane, numbered in order of first appearance;
///   - set D.numCachePlanes.
/// Returns false, leaving D unchanged, when D.T is negative or D.ftr is null
/// while D.T > 0.
bool initDetector(Detector & D)
{
    if (D.T < 0 || (D.T > 0 && !D.ftr)) return false;

    D.FtrEvaluator.resize(D.T);
    D.CachePlane.resize(D.T);
    // Key on the exact (w, h, c) triple. The former 100*w + 10*h + c hash
    // collided as soon as h or c reached 10 (e.g. (1,1,10) vs (1,2,0)). That
    // made unrelated features share a plane and read each other's codes.
    std::map<std::tuple<int,int,int>, int> cachePlanes;
    const unsigned char * f_t = D.ftr;
    for (int t = 0; t < D.T; ++t, f_t += szFtr)
    {
        const int w = f_t[2]; const int h = f_t[3]; const int c = f_t[4];
        Detector::FtrEvalFunc func = eval_lbp2;
        // if (w == 1 && h == 1) func = eval_lbp1_1x1;
        D.FtrEvaluator[t] = func;
        // Features with equal block size and channel yield identical codes at the
        // same position, so they share a plane. insert() keeps an existing
        // assignment and otherwise takes the next free plane index.
        const int nextPlane = static_cast<int>(cachePlanes.size());
        D.CachePlane[t] = cachePlanes.insert(std::make_pair(std::make_tuple(w, h, c), nextPlane)).first->second;
    }
    D.numCachePlanes = static_cast<int>(cachePlanes.size());
    return true;
}

/// Create a detector that borrows caller-owned tables. `ftrs` holds T features
/// of szFtr bytes, `hs` holds T response tables of szAlpha floats, the window
/// size is sz0 x sz1 and the soft-cascade threshold is cascThr.
/// dynamic is false, so destroyDetector will not free the tables.
/// Returns nullptr if initDetector rejects the detector.
Detector * createDetector(int sz0, int sz1, int T, unsigned char * ftrs, float * hs, float cascThr)
{
    // Value-initialise with `()`. Unless Detector has a user-provided
    // constructor, members not assigned below start zeroed rather than
    // indeterminate.
    Detector * D = new Detector();
    D->T = T;
    D->wdim0 = sz0;
    D->wdim1 = sz1;
    D->ftr = ftrs;
    D->alpha = hs;
    D->cascThr = cascThr;
    D->dynamic = false;
    if (!initDetector(*D))
    {
        destroyDetector(&D);
        D = nullptr;
    }
    return D;
}

} // namespace
