#include "mex.h"
#include <climits>
#include <cmath>
#include <cassert>
#include <ThreadPool.h>
#include "common.h"

using namespace std;

// Undocumented functions from lbpdetector.cpp, re-declared here so this file
// can call them. The signatures must match the definitions in lbpdetector.cpp
// exactly (C++ name mangling encodes the parameter types), otherwise the MEX
// file fails to link.
namespace LBPDetector{
    // Presumably evaluates one LBP feature on `im`; f0..f4 are the five
    // per-feature parameters. Not called in the code shown in this file.
    int eval_lbp1(const Image & im, int f0, int f1, int f2, int f3, int f4);
    // Evaluator used by calcFeatures: f0..f4 come from one 5-byte column of P,
    // and the int result is stored into the uint8 output matrix.
    int eval_lbp2(const Image & im, int f0, int f1, int f2, int f3, int f4);
}

// Evaluates every feature described by P on samples [x0, x1) of the stack X.
//
// X      : float data; sample x starts at X + dims[0]*dims[1]*dims[2]*x, i.e.
//          samples are stored back to back.
// dims   : the first three entries give the extents of one sample; they are
//          passed unchanged to the Image constructor.
// nX     : total number of samples = number of rows of `result`.
// x0, x1 : half-open range of samples processed by this call.
// P      : nP feature descriptors of 5 bytes each, stored contiguously.
// nP     : number of features = number of columns of `result`.
// result : nX-by-nP column-major output; entry (x, p) is result[nX*p + x].
//
// Only rows x in [x0, x1) of `result` are written, so calls with disjoint
// ranges touch disjoint output bytes. Running them concurrently additionally
// assumes Image / eval_lbp2 are safe to use from several threads (not visible
// here — confirm).
void calcFeatures(float * X, const int * dims, int nX, int x0, int x1, const unsigned char * P, int nP, unsigned char * result)
{
    // Offsets are computed in size_t: the element count of the whole stack (and
    // of the result matrix) can exceed INT_MAX even when every extent fits in int.
    const size_t sampleSize = static_cast<size_t>(dims[0])
                            * static_cast<size_t>(dims[1])
                            * static_cast<size_t>(dims[2]);
    const size_t nRows = static_cast<size_t>(nX);

    for (int x = x0; x < x1; ++x)
    {
        Image im(X + sampleSize * static_cast<size_t>(x), 3, dims);
        for (int p = 0; p < nP; ++p)
        {
            const unsigned char * f = P + 5 * static_cast<size_t>(p);
            // eval_lbp2 returns int but the output matrix is uint8; the value is
            // kept modulo 256 (presumably LBP codes already fit in 0..255 — confirm).
            result[nRows * static_cast<size_t>(p) + static_cast<size_t>(x)] =
                static_cast<unsigned char>(LBPDetector::eval_lbp2(im, f[0], f[1], f[2], f[3], f[4]));
        }
    }
}


// lbp = mexLBP(X, P, nThreads)
void mexFunction(int nlhs, mxArray ** lhs, int nrhs, const mxArray ** rhs)
{
    const mxArray * X = rhs[0];
    const mxArray * P = rhs[1];
    int nThreads = (int)mxGetScalar(rhs[2]);
    
    const mwSize * dims = mxGetDimensions(X);
    mwSize ndims = mxGetNumberOfDimensions(X);
    
    assert(nlhs == 3 || nlhs == 5);
    assert(ndims==4);
    assert(mxIsSingle(X));
    assert(mxIsUint8(P));
    assert(mxGetM(P)==5);
    
    int nSamples = dims[3];
    int nP = mxGetN(P);
    float * xData = (float*)mxGetData(X);
    unsigned char * pData = (unsigned char*)mxGetData(P);
    /*
     * mexPrintf("samples: %d %d %d %d\n", dims[0], dims[1], dims[2], dims[3]);
     * mexPrintf("nSamples: %d; nP: %d\n", nSamples, nP);
     */
    mxArray * lbp = mxCreateNumericMatrix(nSamples, nP, mxUINT8_CLASS, mxREAL);
    unsigned char * lbpData = (unsigned char*)mxGetData(lbp);
    
    if (nThreads > 1)
    {
        ThreadPool pool(nThreads);
        int batch = 64; // # samples for one thread
        for (int i = 0; i < nSamples; i+=batch)
            pool.enqueue(calcFeatures, xData, dims, nSamples, i, min(i+batch,nSamples), pData, nP, lbpData);
    } // Threads join automatically here
    else
    {
        // Single thread version
        calcFeatures(xData, dims, nSamples, 0, nSamples, pData, nP, lbpData);
    }
    
    lhs[0] = lbp;
}
