/*
 * $Id: mexVectorDetect.cpp 176 2015-07-01 12:38:52Z ijuranek $
 *
 * MEX version of object detector. Compile with lbpdetector.compile in MATLAB.
 *
 * lhs  0   1  2  3              rhs  0     1       2    3      4    5    6            7
 *     [bbs,hs,ns,nf] = mx_lbp_detect(imgs, scales, ftr, alpha, thr, sz, [predictor], [nThreads]);
 *
 * LBP Detector Toolbox
 * Roman Juranek <ijuranek@fit.vutbr.cz>
 * Faculty of Information Technology, BUT, Brno
 *
 */

#include "mex.h"
#include <vector>
#include <list>
#include <cmath>
#include <cassert>
#include <numeric>
#include <algorithm>
#include <ThreadPool.h>
#include <common.h>
#include "lbpdetector.h"

using namespace std;
using namespace LBPDetector;

// Create C structure from parameters obtained from MATLAB.
// Create the C detector structure from parameters obtained from MATLAB.
//
// The MATLAB arrays are reinterpreted as raw C buffers, so their element
// types/sizes are validated first (mexErrMsgTxt aborts the MEX call):
//   ftr   - array with 1-byte elements (uint8), passed as unsigned char*;
//           mxGetN(ftr) is passed as createDetector's count argument
//           (presumably one column per feature -- TODO confirm in lbpdetector.h).
//   alpha - single array, passed as float*.
//   thr   - single array, passed as float*.
//   dim   - int32 (or uint32) array with at least 2 elements; dim[0] and dim[1]
//           are passed as the first two createDetector arguments.
//   predictor - stored in the returned detector's `predictor` field.
// Returns the pointer produced by createDetector; the caller treats NULL as
// an invalid detector, so NULL is passed through without being dereferenced.
Detector * createDetectorFromMxArray(const mxArray * ftr, const mxArray * alpha, const mxArray * thr, const mxArray * dim, const int predictor)
{
    if (mxGetElementSize(ftr) != sizeof(unsigned char))
        mexErrMsgTxt("ftr must be a uint8 array");
    if (!mxIsSingle(alpha) || !mxIsSingle(thr))
        mexErrMsgTxt("alpha and thr must be single arrays");
    // sz[0], sz[1] are read below, so at least two 32-bit integers are required.
    if (!(mxIsInt32(dim) || mxIsUint32(dim)) || mxGetNumberOfElements(dim) < 2)
        mexErrMsgTxt("sz must be a 1x2 int32 array");

    const int * sz = (const int*)mxGetData(dim); // 1 x 2 int
    Detector * D = createDetector(sz[0], sz[1], mxGetN(ftr), (unsigned char*)mxGetData(ftr), (float*)mxGetData(alpha), (float*)mxGetData(thr));
    // createDetector may fail; only touch the result when it is valid.
    if (D)
        D->predictor = predictor;
    return D;
}

void mexFunction(int nlhs, mxArray ** lhs, int nrhs, const mxArray ** rhs)
{
    const mxArray * mx_imgs = rhs[0];
    const mxArray * mx_scales = rhs[1];

    int n_imgs = mxGetNumberOfElements(mx_imgs);
    const float * scales = (const float*)mxGetData(mx_scales);
    
    int predictor = 1;
    if (nrhs > 6) predictor = mxGetScalar(rhs[6]);
	assert(predictor >= 1 && predictor <= 8);
    
    //mexPrintf("%d images\n",n_imgs);
    Detector * D = createDetectorFromMxArray(rhs[2], rhs[3], rhs[4], rhs[5], predictor);
    if (!D)
    {
        mexErrMsgTxt("Invalid detector");
    }
    //mexPrintf("Detector: %dx%d with %d stages\n", D.w, D.h, D.T);
    
    double nThreads = 0;
    if (nrhs > 7) nThreads = mxGetScalar(rhs[7]);
    
    vector<BB> bbs;
    vector<float_vec_8> hs;
    int nf = 0; int ns = 0;

    if (nThreads <= 1)
    {
        // No threads version
        list<Image> ims;
        list<VectorResult> results;
        for (int k = 0; k < n_imgs; ++k)
        {
            const mxArray * img_k = mxGetCell(mx_imgs, k);
            // TODO: check img_k - single matrix
            ims.push_back(Image((float*)mxGetData(img_k), mxGetNumberOfDimensions(img_k), mxGetDimensions(img_k)));
            Image & im = ims.back();
            if (!im.data) break;
            const int max_d0 = im.dims[0] - D->wdim0 - 1;
            const int max_d1 = im.dims[1] - D->wdim1 - 1;
            if (max_d0 <= 0 || max_d1 <= 0) break;
            //mexPrintf("Image: %dx%d px, sacle=%.2f, k=%d\n", im.dim0, im.dim1, scales[k], k);
            VectorResult r = vector_scan_image(std::cref(im), 0, 0, max_d0, max_d1, 1, std::cref(*D), scales[k]);
            bbs.insert(bbs.end(), r.bbs.begin(), r.bbs.end());
            hs.insert(hs.end(), r.hs.begin(), r.hs.end());
            ns += r.ns; nf += r.nf;
        }
    }
    else
    {
        // Threaded version
        nThreads = std::max(1U, std::min(std::thread::hardware_concurrency(), unsigned(nThreads)));
        ThreadPool pool(nThreads);
        const int tile_sz = 256;
        list<Image> ims;
        list< future<VectorResult> > results;
        for (int k = 0; k < n_imgs; ++k)
        {
            const mxArray * img_k = mxGetCell(mx_imgs, k);
            // TODO: check img_k - single matrix
            ims.push_back(Image((float*)mxGetData(img_k), mxGetNumberOfDimensions(img_k), mxGetDimensions(img_k)));
            const Image & im = ims.back();
            if (!im.data) break;
            const int max_d0 = im.dims[0] - D->wdim0 - 1;
            const int max_d1 = im.dims[1] - D->wdim1 - 1;
            if (max_d0 <= 0 || max_d1 <= 0) break;
            //mexPrintf("Image: %dx%d px, sacle=%.2f, k=%d\n", im.dims[0], im.dims[1], scales[k], k);
            for (int d0 = 0; d0 < max_d0; d0+=tile_sz)
                for (int d1 = 0; d1 < max_d1; d1+=tile_sz)
                {
                    //mexPrintf("x=%i:%i; y=%i:%i\n",x, min(x+TILE_SZ-1, max_x), y, min(y+TILE_SZ-1, max_y));
                    results.emplace_back(
                        pool.enqueue(
                            vector_scan_image, std::cref(im), d0, d1, min(d0+tile_sz, max_d0), min(d1+tile_sz, max_d1), 1, std::cref(*D), scales[k] // TODO: parameterize stride
                        )
                    );
                }
        }
        for (auto && result: results)
        {
            VectorResult r = result.get();
            bbs.insert(bbs.end(), r.bbs.begin(), r.bbs.end());
            hs.insert(hs.end(), r.hs.begin(), r.hs.end());
            //mexPrintf("ns=%d, nf=%d, passed: %d\n", r.ns, r.nf, r.bbs.size());
            ns += r.ns; nf += r.nf;
        }
    }
    
    lhs[0] = mxCreateNumericMatrix(5, bbs.size(), mxDOUBLE_CLASS, mxREAL);
    lhs[1] = mxCreateNumericMatrix(D->predictor, bbs.size(), mxDOUBLE_CLASS, mxREAL);
    lhs[2] = mxCreateDoubleScalar(double(ns));
    lhs[3] = mxCreateDoubleScalar(double(nf));
    
    double * bbs_out = (double*)mxGetData(lhs[0]);
    double * hs_out = (double*)mxGetData(lhs[1]);
    for (int i = 0; i < bbs.size(); ++i)
    {
        BB & b = bbs[i];
        float_vec_8 h = hs[i];
        bbs_out[5*i+0] = b.y;
        bbs_out[5*i+1] = b.x;
        bbs_out[5*i+2] = b.height;
        bbs_out[5*i+3] = b.width;        
        bbs_out[5*i+4] = b.scale;

        copy_n(h.begin(), D->predictor, hs_out+D->predictor*i);
        //mexPrintf("%.2f, ", h);
    }
    
    destroyDetector(&D);
    
    return;
}
