
// THIS FUNCTION IS UNDER DEVELOPMENT AND DOES NOT WORK YET!


// [thr, err] = optimize_splits(M, Y, W, classes)
#include <mex.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <vector>

using namespace std;

// v * ln(v), extended with its limit 0 for v <= 0, so that empty histogram
// cells contribute nothing to an entropy sum instead of producing NaN.
static inline double xlogx(double v)
{
    return v > 0.0 ? v * std::log(v) : 0.0;
}

// MEX entry point: [thr, err] = optimize_splits(M, Y, W, classes)
//
// For every feature (column of M) the values are quantised into a fixed
// number of bins, a per-class cumulative histogram is built, and the bin
// boundary that minimises the conditional entropy of the class given the side
// of the split ("value below threshold" vs. "value at/above threshold") is
// selected.
//
// Inputs (prhs):
//   [0] M       - nsamples x nftrs matrix, read as 32-bit float (single),
//                 column-major.
//   [1] Y       - at least nsamples labels, read as ONE BYTE per element.
//   [2] W       - sample weights; currently ignored (weighting not implemented).
//   [3] classes - list of distinct label values, read as one byte per element;
//                 the position of a value in this list is its class index.
//   NOTE(review): element types are not checked here. Y and classes are assumed
//   to be 8-bit integer arrays (int8/uint8); a MATLAB char array uses wider
//   elements and would be misread - confirm against callers.
//
// Outputs (plhs):
//   [0] thr - 1 x nftrs single, selected threshold per feature.
//   [1] err - 1 x nftrs single, conditional entropy (in nats) at that
//             threshold; only created when a second output is requested.
//
// Samples whose label is not listed in `classes`, or whose feature value is
// NaN/Inf, are left out of that feature's histogram. A feature with no usable
// sample gets NaN for both outputs.
void mexFunction(int nlhs, mxArray ** plhs, int nrhs, const mxArray ** prhs)
{
    // Too few inputs: return before dereferencing prhs entries that do not
    // exist. No output is assigned on this path.
    // TODO: report this through the MEX error API with a proper message.
    if (nrhs < 4)
        return;

    const float * M = static_cast<const float*>(mxGetData(prhs[0]));
    const std::size_t nsamples = mxGetM(prhs[0]);
    const std::size_t nftrs = mxGetN(prhs[0]);
    const unsigned char * Y = static_cast<const unsigned char*>(mxGetData(prhs[1]));
    const unsigned char * classes = static_cast<const unsigned char*>(mxGetData(prhs[3]));
    const std::size_t nclasses = mxGetNumberOfElements(prhs[3]);

    // Y must provide one label per row of M, otherwise we would read past it.
    if (mxGetNumberOfElements(prhs[1]) < nsamples)
        return;

    const std::size_t nbins = 16; // TODO: make it a parameter (must be >= 2)

    // Create the outputs first, before any C++-side buffer is allocated.
    plhs[0] = mxCreateNumericMatrix(1, nftrs, mxSINGLE_CLASS, mxREAL); // thresholds
    float * thr = static_cast<float*>(mxGetData(plhs[0]));
    float * err = nullptr;
    if (nlhs >= 2)
    {
        // plhs[1] is only touched when the caller asked for a second output.
        plhs[1] = mxCreateNumericMatrix(1, nftrs, mxSINGLE_CLASS, mxREAL); // errors
        err = static_cast<float*>(mxGetData(plhs[1]));
    }

    // Transcode Y so entries denote compact class indices.
    // Suppose classes = [5 2 1]; then label 5 -> 0, 2 -> 1, 1 -> 2:
    //           1 2     5
    // ytab = [x 2 1 x x 0 x x ... ]      (x = -1, label not in `classes`)
    // Labels are indexed as unsigned bytes so values >= 128 cannot produce a
    // negative table index.
    int ytab[256];
    std::fill(ytab, ytab + 256, -1);
    for (std::size_t i = 0; i < nclasses; ++i)
        ytab[classes[i]] = static_cast<int>(i);

    std::vector<int> label(nsamples);
    for (std::size_t sid = 0; sid < nsamples; ++sid)
        label[sid] = ytab[Y[sid]];

    // Per-feature scratch: sample counts laid out as [class][bin]. Integer
    // counts keep the cumulative sums exact; they are normalised only when
    // the entropy is evaluated.
    std::vector<std::size_t> hist(nclasses * nbins);

    for (std::size_t fid = 0; fid < nftrs; ++fid)
    {
        const float * col = M + fid * nsamples; // base of M for this feature

        // Value range over the usable samples of this feature.
        float vmin = INFINITY;
        float vmax = -INFINITY;
        for (std::size_t sid = 0; sid < nsamples; ++sid)
        {
            if (label[sid] < 0 || !std::isfinite(col[sid]))
                continue;
            vmin = std::min(vmin, col[sid]);
            vmax = std::max(vmax, col[sid]);
        }

        if (vmin > vmax)
        {
            // No usable sample: there is nothing to split on.
            thr[fid] = NAN;
            if (err)
                err[fid] = NAN;
            continue;
        }

        // Pad the range by 0.1 on both sides (as the original code did) and
        // cut it into nbins-1 intervals; bin k collects the values that round
        // to k, i.e. values in [lo + (k-0.5)*width, lo + (k+0.5)*width).
        const double lo = static_cast<double>(vmin) - 0.1;
        const double hi = static_cast<double>(vmax) + 0.1;
        const double width = (hi - lo) / static_cast<double>(nbins - 1);

        // Histogram of (class, bin) counts.
        std::fill(hist.begin(), hist.end(), std::size_t(0));
        std::size_t nused = 0;
        for (std::size_t sid = 0; sid < nsamples; ++sid)
        {
            if (label[sid] < 0 || !std::isfinite(col[sid]))
                continue;

            std::size_t bin = 0;
            if (width > 0.0)
            {
                // The position is non-negative, so adding 0.5 and truncating
                // rounds to the nearest bin.
                const double pos = (static_cast<double>(col[sid]) - lo) / width;
                bin = std::min(static_cast<std::size_t>(pos + 0.5), nbins - 1);
            }
            ++hist[static_cast<std::size_t>(label[sid]) * nbins + bin];
            ++nused;
        }

        // Turn each class histogram into a cumulative count.
        for (std::size_t c = 0; c < nclasses; ++c)
        {
            std::size_t * cum = &hist[c * nbins];
            for (std::size_t bin = 1; bin < nbins; ++bin)
                cum[bin] += cum[bin - 1];
        }

        // Evaluate every candidate split "bin <= k" vs. "bin > k".
        // With joint probabilities p(c, side) and side marginals P(side):
        //   H(class | side) = -sum p(c,side) ln p(c,side) + sum P(side) ln P(side)
        // The joint entropy alone is never below the class entropy and reaches
        // it for any split that puts all samples on one side, so it cannot be
        // used to rank splits; the conditional entropy can.
        const double inv_n = 1.0 / static_cast<double>(nused); // nused >= 1 here
        std::size_t best_bin = 0;
        double best_e = INFINITY;
        for (std::size_t bin = 0; bin < nbins; ++bin)
        {
            double e = 0.0;
            double p_left = 0.0;
            double p_right = 0.0;
            for (std::size_t c = 0; c < nclasses; ++c)
            {
                const std::size_t * cum = &hist[c * nbins];
                const double pl = static_cast<double>(cum[bin]) * inv_n;
                const double pr = static_cast<double>(cum[nbins - 1] - cum[bin]) * inv_n;
                e -= xlogx(pl) + xlogx(pr);
                p_left += pl;
                p_right += pr;
            }
            e += xlogx(p_left) + xlogx(p_right);

            if (e < best_e)
            {
                best_bin = bin;
                best_e = e;
            }
        }

        // Upper edge of the selected bin: exactly the boundary between the two
        // sides that were scored above.
        thr[fid] = static_cast<float>(lo + (static_cast<double>(best_bin) + 0.5) * width);
        if (err)
            err[fid] = static_cast<float>(best_e);
    }
}
