#include "mex.h"
#include <cmath>
#include <ThreadPool.h>

using namespace std;

// Accumulate weighted 256-bin histograms for a range of selected features.
//
// For each i in [i0, i1), reads column is[i] of the column-major N x F uint8
// matrix X and adds wts[k] to bin X(k, is[i]) of output column i of W, where W
// holds 256 doubles per column. W is accumulated into, not overwritten, so the
// caller must zero it beforehand. Calls on non-overlapping [i0, i1) ranges write
// disjoint columns of W and only read X / wts, so they can run concurrently.
//
// F (number of columns of X) is not used here; the caller is responsible for
// ensuring 0 <= is[i] < F. Does nothing when N <= 0.
void accumWts(const unsigned char * X, const double * wts, const int * is, int i0, int i1, int N, int F, double * W)
{
    (void)F; // kept for signature compatibility with existing callers
    if (N <= 0)
        return;

    const std::size_t nSamples = static_cast<std::size_t>(N);
    for (int i = i0; i < i1; ++i)
    {
        // Offsets are computed in size_t: N * is[i] overflows int once X exceeds ~2 GiB.
        const unsigned char * codes = X + nSamples * static_cast<std::size_t>(is[i]);
        double * hist = W + static_cast<std::size_t>(256) * static_cast<std::size_t>(i);
        for (std::size_t k = 0; k < nSamples; ++k)
            hist[codes[k]] += wts[k];
    }
}

// W = lbpTrainFeature1(X, wts, is, nThreads)
void mexFunction(int nlhs, mxArray ** lhs, int nrhs, const mxArray ** rhs)
{
    const mxArray * X = rhs[0];   // N x F
    const mxArray * wts = rhs[1]; // N x 1
    const mxArray * is = rhs[2];   // 1 x K
    unsigned char * X1 = (unsigned char*)mxGetData(X);
    double * wts1 = (double*)mxGetData(wts);
    int * is1 = (int*)mxGetData(is);
    int nThreads = (int)mxGetScalar(rhs[3]);

    int N = mxGetM(X);
    int F = mxGetN(X);
    int K = mxGetN(is);

    mxArray * W = mxCreateNumericMatrix(256, K, mxDOUBLE_CLASS, mxREAL);
    double * W1 = (double*)mxGetData(W);

    if (nThreads > 1)
    {
        nThreads = min((unsigned)nThreads, std::thread::hardware_concurrency());
        ThreadPool pool(nThreads);
        const int batch = 1024;
        for (int j = 0; j < K; j += batch)
            pool.enqueue(accumWts, X1, wts1, is1, j, min(j+batch,K), N, F, W1);
    }
    else
    {
        accumWts(X1, wts1, is1, 0, K, N, F, W1);
    }

    lhs[0] = W;
}
