#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <libxml/tree.h>

#include "./support/Log.h"
#include "./support/SimpleXML.h"


using namespace std;
// Accumulates the evaluation results of one classifier on one test:
// histograms of the classifier responses for positive and negative samples
// plus per-stage counters.
struct TStatistic{
    // classifier response that is mapped to the first histogram bin
    const double minH;
    // classifier response that is mapped to the last histogram bin
    const double maxH;
    // number of bins of positiveHistogram and negativeHistogram
    const int binNumber;
    
    // name of the test these results belong to (used for plot titles and legends)
    std::string testID;

    // responses of positive / negative samples; plain per-bin counts until
    // integrate() turns them into cumulative counts
    std::vector< int> positiveHistogram;
    std::vector< int> negativeHistogram;

    // number of all processed samples and of the positive ones among them
    int           totalSampleCount;
    int           positiveCount;

    // per-stage counters, one element per stage; integrate() makes them cumulative
    std::vector< int>  terminationHistogram;
    std::vector< int>  detectionStageCounts;
    std::vector< int>  FPStageCounts;

    // Creates a statistic with the default range <-80, 150>, 10000 bins,
    // zeroed sample counters and no stages.
    TStatistic()
        : minH( -80), maxH( 150), binNumber( 10000),
          positiveHistogram( binNumber), negativeHistogram( binNumber),
          totalSampleCount( 0), positiveCount( 0)
    {}

    // Creates an empty statistic with the given response range, number of
    // histogram bins, number of stages and test name.
    // The initializers are listed in declaration order, which is the order
    // in which the members are really initialized.
    TStatistic( const double _minH, const double _maxH, const unsigned _binNumber, unsigned _stageCount, std::string _testID = "")
        : minH( _minH), maxH( _maxH), binNumber( _binNumber),
          testID( _testID),
          positiveHistogram( _binNumber), negativeHistogram( _binNumber),
          totalSampleCount( 0), positiveCount( 0),
          terminationHistogram( _stageCount), detectionStageCounts( _stageCount),
          FPStageCounts( _stageCount)
    {}
    
    // Adds the counters of newStat (same range, bins, test and stage count) to this statistic.
    void combine( const TStatistic &newStat);

    // Returns a new empty statistic with the same configuration as this one.
    TStatistic clone( );

    bool saveResult( );
    bool loadResult( );

    // Conversion between a classifier response and a histogram bin index.
    int resultToHistogramPosition( const double result);
    double histogramPositionToResult( const int position);

    // Converts the histograms and stage counters to cumulative sums.
    void integrate();
};


void processResultFile( const string fileName, TStatistic &statistic);


int main(int argc, char ** argv)
{
    string configFilename;
    string logFilename;
     
    // Read command line parameters
    for( int i = 1; i < argc; i++) {
        if( string( argv[ i]) == "-i"){
            configFilename = string(argv[++i]);
        } else if( string( argv[ i]) == "-l") {
            logFilename = argv[ ++i];
        } else if( strcmp(argv[i],"-h") == 0) {
            cout << "evaluator [options]" << endl;
            cout << "  options:" << endl;
            cout << "   -h         this help" << endl;
            cout << "   -i file    input XML config file " << endl;
            cout << "   -l file    the log file name. The format of log file is HTML and is" << endl;
            cout << "              printed on standard output. You must use this if you want to run" << endl;
            cout << "              multiple evaluation instances in the same directory or want to redirect" << endl;
            cout << "              the log output into a file in different directory then the current" << endl;
            cout << "              working directory. If you do not use this and instead redirect" << endl;
            cout << "              standard output using > on command line, it will be expected that" << endl;
            cout << "              the log was redirected into a file in current working directory" << endl;
            cout << "              and also all generated plots will be stored in the current working" << endl;
            cout << "              directory. With -l all plots are stored into a subdirectory" << endl;
            cout << "              with the same name as the log file name." << endl;
           return 0;
        } else {
            cerr << "Warning: parameter \"" << argv[i] << "\" was not recognized." << endl;
        }
    } 
    
    if( configFilename.empty()) { 
        cerr << "Error: Input config file must be specified. Use -h to get more help.\n"; 
        return 0; 
    }

    startLog( logFilename);

    

/////////////////////////////////////////////////////////////////////////////////////
//
/////////////////////////////////////////////////////////////////////////////////////

    // read the configuration
    xmlDocPtr xmlDoc;
    xmlInitParser();
    if( !(xmlDoc = xmlParseFile(configFilename.c_str()))){
        cerr << "Error: Cannot load configuration file \"" << configFilename << "\"." << endl;
        return 0;
    }

    vector< vector< string> > resultFileNames;
    vector< string> classifierNames;
    vector< string> testNames;
    
    xmlNodePtr rootNode = NULL;
    for( xmlNodePtr node = xmlDoc->children; node != NULL; node = node->next) // process all nodes
    {
        if( string( "EvaluatorConfig") == (char *)(node->name) ) {
            rootNode = node;
            break;
        }
    }

    if( rootNode == NULL){
        cerr << "Error: xml element \"EvaluatorConfig\" missing in file \"" << configFilename << "\"." << endl;
        return 0;
    }

    // get the names of the tests
    for( xmlNodePtr node = getNode( "TestName", rootNode); node != NULL; node = getNextNode( "TestName", node->next)){
        testNames.push_back( getAttr( "name", node));
    }

    // get the file names for test of the classifiers
    for( xmlNodePtr classifierNode = getNode( "Classifier", rootNode); classifierNode != NULL; classifierNode = getNextNode( "Classifier", classifierNode->next)){

        classifierNames.push_back( getAttr( "name", classifierNode));
        resultFileNames.push_back( vector< string>());

        for( xmlNodePtr testNode = getNode( "Test", classifierNode); testNode != NULL; testNode = getNextNode( "Test", testNode->next)){
            resultFileNames.back().push_back( getAttr( "fileName", testNode));
        }
    }


    vector< vector< TStatistic *> > statistics( classifierNames.size());

    double minValue = 0;
    double maxValue = 1;
    const int binNumber = 10000;


    // evaluate the performance of the classifiers on the selected tests
    for( unsigned i = 0; i < resultFileNames.size(); i++){
        for( unsigned j = 0; j < resultFileNames[i].size(); j++){
            
            statistics[ i].push_back( new TStatistic( minValue, maxValue, binNumber, 1, testNames[ j]));
            processResultFile( resultFileNames[i][j], *statistics[ i].back());
            statistics[ i].back()->integrate();
        }
    }

    // transpose the statistics for plotting results of multiple classifiers on a single test
    vector< vector< TStatistic *> > transposedStatistics( resultFileNames[0].size());
    for( unsigned j = 0; j <  resultFileNames[0].size(); j++){
        for( unsigned i = 0; i < resultFileNames.size(); i++){
             transposedStatistics[ j].push_back( statistics[ i][ j]);
        }
    }



    const int maximumFPCountShown = 10000;
    const bool logaritmicFP = false;
    const bool logaritmicFN = false;





    // if there are more classifiers do the graphs per test with more classifiers
    if( statistics.size() > 1){

       // create graphs for each classifier with results on all tests
        for( unsigned i = 0; i < transposedStatistics.size(); i++){ // for each classifier

            if( transposedStatistics[ i].empty()){
                continue;
            }

            vector< string> lineNames;
            vector< vector< double> > lastStageFP;
            vector< vector< double> > lastStageFN;
            vector< vector< double> > lastStageFP_x;
            vector< vector< double> > lastStageFN_x;
            vector< vector< double> > DET_X;
            vector< vector< double> > DET_Y;
            vector< vector< double> > ROC_X;
            vector< vector< double> > ROC_Y;
            vector< vector< double> > precision;
            vector< vector< double> > recall;

            // vector of TStatistics results of the classifiers on current test i
            vector< TStatistic *> &stat = transposedStatistics[ i];

            cout << "#" << endl;
            cout << "#   TEST " <<  stat[0]->testID << endl;
            cout << "#" << endl;

            // compute the plot lines for all classifiers on the current test
            for( unsigned testID = 0; testID < stat.size(); testID++){

                // add new line to each of the plots
                lineNames.push_back( classifierNames[ testID]);

                DET_X.push_back( vector<double>());
                DET_Y.push_back( vector<double>());
                ROC_X.push_back( vector<double>());
                ROC_Y.push_back( vector<double>());
                precision.push_back( vector<double>());
                recall.push_back( vector<double>());
                lastStageFP.push_back( vector<double>());
                lastStageFN.push_back( vector<double>());
                lastStageFP_x.push_back( vector<double>());
                lastStageFN_x.push_back( vector<double>());

                // lastStage FP and FN
                for( unsigned j = 0; j < stat[ testID]->negativeHistogram.size(); j++){
                    if( lastStageFP.back().empty() || lastStageFP.back().back() != stat[ testID]->negativeHistogram[ j] / (double) stat[ testID]->FPStageCounts[0]){
                        lastStageFP.back().push_back( stat[ testID]->negativeHistogram[ j] / (double) stat[ testID]->FPStageCounts[0]);
                        lastStageFP_x.back().push_back( stat[ testID]->histogramPositionToResult( j));
                    }
                    
                    if( lastStageFN.back().empty() || lastStageFN.back().back() != (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount){
                        lastStageFN.back().push_back( (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount);
                        lastStageFN_x.back().push_back( stat[ testID]->histogramPositionToResult( j));
                    }
                }

                
                // complete DET, ROC, recall and precision  using the last stage information
                for( unsigned j = 0; j < stat[ testID]->positiveHistogram.size(); j++){

                    double newMiss = (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount;
                    double newFP = (double) stat[ testID]->negativeHistogram[ j];

                    if( newFP != DET_X.back().back() || newMiss != DET_Y.back().back() ){
                        DET_X.back().push_back( newFP);
                        DET_Y.back().push_back( newMiss);

                        if( stat[ testID]->negativeHistogram[ j] < maximumFPCountShown){
                            ROC_X.back().push_back( (double) stat[ testID]->negativeHistogram[ j]);
                            ROC_Y.back().push_back( (double) stat[ testID]->positiveHistogram[ j] / (double) stat[ testID]->positiveCount);
                        }

                        if( (double)(stat[ testID]->positiveHistogram[ j] + stat[ testID]->negativeHistogram[ j]) > 0){
                            precision.back().push_back( stat[ testID]->positiveHistogram[ j] / (double)(stat[ testID]->positiveHistogram[ j] + stat[ testID]->negativeHistogram[ j]));
                        } else {
                            precision.back().push_back( 1.0);
                        }

                        recall.back().push_back( stat[ testID]->positiveHistogram[ j] / (double) stat[ testID]->positiveCount);
                    }
                }
                
                
                for( unsigned j = 0; j < DET_X.back().size(); j++){
                    DET_X.back()[ j] /= stat[ testID]->FPStageCounts[0];
                }
            }

            // plot the selected information into files

            cout << "</pre><p>" << endl;
            string formatingString;
            if( logaritmicFP){
                formatingString +=  "set logscale x\n";
            }
            if( logaritmicFN){
                formatingString +=  "set logscale y\n";
            }
            cout << "<img src=\"" << plotToLog( DET_X, DET_Y, lineNames, stat[0]->testID + " DET", formatingString) << "\"> " << endl;

            cout << "<img src=\"" << plotToLog( ROC_X, ROC_Y, lineNames, stat[0]->testID + " ROC") << "\"> " << endl;

            cout << "<img src=\"" << plotToLog( recall, precision, lineNames, stat[0]->testID + " Recall-Precision") << "\"> " << endl;

            cout << "<img src=\"" << plotToLog( lastStageFP_x, lastStageFP, lineNames, stat[0]->testID + " lastStage negative") << "\"> " << endl;
            cout << "<img src=\"" << plotToLog( lastStageFN_x, lastStageFN, lineNames, stat[0]->testID + " lastStage positive") << "\"> " << endl;

            cout << "</p><pre>" << endl;

            cout << "#" << endl;
            cout << "#   TEST " <<  stat[0]->testID << " DONE " << endl;
            cout << "#" << endl;
           
        }
    } 

    // plot results of each classifier on all tests
    for( unsigned i = 0; i < statistics.size(); i++){ // for each classifier
       
        vector< string> lineNames;
        vector< vector< double> > rejectionRate;
        vector< vector< double> > missRate;
        vector< vector< double> > avgSpeed;
        vector< vector< double> > lastStageFP;
        vector< vector< double> > lastStageFN;
        vector< vector< double> > lastStageFP_x;
        vector< vector< double> > lastStageFN_x;
        vector< vector< double> > DET_X;
        vector< vector< double> > DET_Y;
        vector< vector< double> > ROC_X;
        vector< vector< double> > ROC_Y;
        vector< vector< double> > precision;
        vector< vector< double> > recall;
        
        // vector of TStatistics results of current classifier i on all test
        vector< TStatistic *> &stat = statistics[ i];

        // compute the plot lines for the current classfier on all tests
        for( unsigned testID = 0; testID < stat.size(); testID++){

            // add new line to each of the plots
            lineNames.push_back( stat[ testID]->testID);

            rejectionRate.push_back( vector<double>());
            missRate.push_back( vector<double>());
            avgSpeed.push_back( vector<double>());
            DET_X.push_back( vector<double>());
            DET_Y.push_back( vector<double>());
            ROC_X.push_back( vector<double>());
            ROC_Y.push_back( vector<double>());
            precision.push_back( vector<double>());
            recall.push_back( vector<double>());
            lastStageFP.push_back( vector<double>());
            lastStageFN.push_back( vector<double>());
            lastStageFP_x.push_back( vector<double>());
            lastStageFN_x.push_back( vector<double>());


            // lastStage FP and FN
            for( unsigned j = 0; j < stat[ testID]->negativeHistogram.size(); j++){
                if( lastStageFP.back().empty() || lastStageFP.back().back() != stat[ testID]->negativeHistogram[ j] / (double) stat[ testID]->FPStageCounts[0]){
                    lastStageFP.back().push_back( stat[ testID]->negativeHistogram[ j] / (double) stat[ testID]->FPStageCounts[0]);
                    lastStageFP_x.back().push_back( stat[ testID]->histogramPositionToResult( j));
                }
                
                if( lastStageFN.back().empty() || lastStageFN.back().back() != (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount){
                    lastStageFN.back().push_back( (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount);
                    lastStageFN_x.back().push_back( stat[ testID]->histogramPositionToResult( j));
                }
            }

            // compute the rejection rate
            for( unsigned j = 0; j < stat[ testID]->terminationHistogram.size(); j++){
                rejectionRate.back().push_back( (stat[ testID]->totalSampleCount - stat[ testID]->terminationHistogram[ j]) / (double) stat[ testID]->totalSampleCount);
                missRate.back().push_back( (stat[ testID]->positiveCount - stat[ testID]->detectionStageCounts[ j]) / (double) stat[ testID]->positiveCount);
            }

            // compute avg. speed for all stages
            unsigned lengthSum = 0;
            unsigned lastTerminated = 0;

            for( unsigned j = 0; j < stat[ testID]->terminationHistogram.size(); j++){

                lengthSum += (j + 1) * ( stat[ testID]->terminationHistogram[ j] - lastTerminated);
                unsigned restSum = (j + 1) * ( stat[ testID]->totalSampleCount - stat[ testID]->terminationHistogram[ j]) ;
                lastTerminated = stat[ testID]->terminationHistogram[ j];
    
                avgSpeed.back().push_back( (double) (lengthSum + restSum) / (double) stat[ testID]->totalSampleCount);
            }

            // compute DET, ROC, recall and precision going through all stages
            for( unsigned j = 0; j < stat[ testID]->FPStageCounts.size(); j++){
                DET_X.back().push_back( (double) stat[ testID]->FPStageCounts[ j]);
                DET_Y.back().push_back( (double) (stat[ testID]->positiveCount - stat[ testID]->detectionStageCounts[ j]) / (double) stat[ testID]->positiveCount);   

                if( stat[ testID]->FPStageCounts[ j] < maximumFPCountShown){
                    ROC_X.back().push_back( (double) stat[ testID]->FPStageCounts[ j]);
                    ROC_Y.back().push_back( (double) stat[ testID]->detectionStageCounts[ j] / (double) stat[ testID]->positiveCount);
                }

                precision.back().push_back( stat[ testID]->detectionStageCounts[ j] / (double)(stat[ testID]->detectionStageCounts[ j] + stat[ testID]->FPStageCounts[ j]));
                recall.back().push_back( stat[ testID]->detectionStageCounts[ j] / (double) stat[ testID]->positiveCount);
            }
            
            // complete DET, ROC, recall and precision  using the last stage information
            for( unsigned j = 0; j < stat[ testID]->positiveHistogram.size(); j++){

                double newMiss = (double) (stat[ testID]->positiveCount - stat[ testID]->positiveHistogram[ j]) / (double) stat[ testID]->positiveCount;
                double newFP = (double) stat[ testID]->negativeHistogram[ j];

                if( newFP != DET_X.back().back() || newMiss != DET_Y.back().back() ){
                    DET_X.back().push_back( newFP);
                    DET_Y.back().push_back( newMiss);

                    if( stat[ testID]->negativeHistogram[ j] < maximumFPCountShown){
                        ROC_X.back().push_back( (double) stat[ testID]->negativeHistogram[ j]);
                        ROC_Y.back().push_back( (double) stat[ testID]->positiveHistogram[ j] / (double) stat[ testID]->positiveCount);
                    }

                    if( (double)(stat[ testID]->positiveHistogram[ j] + stat[ testID]->negativeHistogram[ j]) > 0){
                        precision.back().push_back( stat[ testID]->positiveHistogram[ j] / (double)(stat[ testID]->positiveHistogram[ j] + stat[ testID]->negativeHistogram[ j]));
                    } else {
                        precision.back().push_back( 1.0);
                    }
                    
                    recall.back().push_back( stat[ testID]->positiveHistogram[ j] / (double) stat[ testID]->positiveCount);
                }


            }

            // If DET should show relative values, normalize the values
            // ?????????????????????????????????????????????????????????????????????
            // ????????????? is /= stat[ testID]->FPStageCounts[0] OK ??????????????
            // ?????????????????????????????????????????????????????????????????????
            for( unsigned j = 0; j < DET_X.back().size(); j++){
                DET_X.back()[ j] /= stat[ testID]->FPStageCounts[0];
            }

        }


         // plot the selected information into files

        cout << "</pre><p>" << endl;
        

        string formatingString;
        if( logaritmicFP){
            formatingString +=  "set logscale x\n";
        }
        if( logaritmicFN){
            formatingString +=  "set logscale y\n";
        }

        cout << "<img src=\"" << plotToLog( DET_X, DET_Y, lineNames, classifierNames[ i] + " DET", formatingString) << "\"> " << endl;

        cout << "<img src=\"" << plotToLog( ROC_X, ROC_Y, lineNames, classifierNames[ i] + " ROC") << "\"> " << endl;
        
        cout << "<img src=\"" << plotToLog( recall, precision, lineNames, classifierNames[ i] + " Recall-Precision") << "\"> " << endl;

        cout << "<img src=\"" << plotToLog( lastStageFP_x, lastStageFP, lineNames, classifierNames[ i] + " last stage negative") << "\"> " << endl;
        cout << "<img src=\"" << plotToLog( lastStageFN_x, lastStageFN, lineNames, classifierNames[ i] + " last stage positive") << "\"> " << endl;

        cout << "</p><pre>" << endl;

    }

/////////////////////////////////////////////////////////////////////////////////////
//
/////////////////////////////////////////////////////////////////////////////////////

    endLog();
}





void processResultFile( const string fileName, TStatistic &statistic)
{   
    fstream inputFile( fileName.data(), ios_base::in);

    if( inputFile.fail()){
        cerr << "Error: Unable to open file \"" << fileName << "\"." << endl;
        return;
    }
    
    while( inputFile.good()){
        double classID;
        inputFile >> classID; // skip the first column
        if( inputFile.fail()){
            cerr << "Error: Reading file \"" << fileName << "\"." << endl;
            continue;
        }
        inputFile >> classID;
        if( inputFile.fail()){
            cerr << "Error: Reading file \"" << fileName << "\"." << endl;
            continue;
        }
        
        // get the result
        double result;
        inputFile >> result; // skip the third column
        if( inputFile.fail()){
            cerr << "Error: Reading file \"" << fileName << "\"." << endl;
            continue;
        }
        inputFile >> result; 
        if( inputFile.fail()){
            cerr << "Error: Reading file \"" << fileName << "\"." << endl;
            continue;
        }

        statistic.totalSampleCount++;
        if( classID > 0.5){
            statistic.positiveCount++;
            statistic.positiveHistogram[ statistic.resultToHistogramPosition( result)]++;
        } else {
            statistic.negativeHistogram[ statistic.resultToHistogramPosition( result)]++;
        }
    }

}










//=================================================================================================================
// TStatistic
//=================================================================================================================

void TStatistic::combine( const TStatistic &newStat)
{
    assert( this->minH == newStat.minH);
    assert( this->maxH == newStat.maxH);
    assert( this->binNumber == newStat.binNumber);
    assert( this->testID == newStat.testID);
    assert( this->terminationHistogram.size() == newStat.terminationHistogram.size());
 
    this->totalSampleCount += newStat.totalSampleCount;
    this->positiveCount += newStat.positiveCount;

    for( int i = 0; i < binNumber; i++){
        this->positiveHistogram[ i] += newStat.positiveHistogram[ i];
        this->negativeHistogram[ i] += newStat.negativeHistogram[ i];
    }

    for( unsigned i = 0; i < terminationHistogram.size(); i++){
        this->terminationHistogram[ i] += newStat.terminationHistogram[i];
        this->detectionStageCounts[ i] += newStat.detectionStageCounts[i];
        this->FPStageCounts[ i] += newStat.FPStageCounts[i];
    }
}

// Returns a new statistic with the same response range, number of bins,
// number of stages and test name as this one. Only the configuration is
// taken over - all counters and histograms of the result are zero.
TStatistic TStatistic::clone( )
{
    const unsigned stageCount = (unsigned) this->terminationHistogram.size();

    TStatistic emptyTwin( this->minH, this->maxH, this->binNumber, stageCount, this->testID);
    return emptyTwin;
}

// Maps a classifier response to the index of its histogram bin.
//
// minH is mapped to bin 0, maxH to bin binNumber - 1 and the values between
// them linearly, rounded to the nearest bin.
//
// result - classifier response
// returns the bin index in <0, binNumber - 1>; a response outside of the
// histogram range is reported on cerr and clamped to the first or last bin.
//
// The range is checked on the floating point value before it is converted to
// int, because converting a value that does not fit into int (a huge
// response, infinity, NaN) is undefined. NaN is reported and put into bin 0.
int TStatistic::resultToHistogramPosition( const double result)
{
    const double position = (  result - minH) / ( maxH - minH) * ( binNumber - 1) + 0.5;

    // int( position) would be negative exactly for position <= -1;
    // the negated comparison is also true for NaN
    if( !( position > -1.0)){ 
        cerr << "Warning: result out of range. Result: " << result << " min: " << minH << endl;
        return 0;
    } else if( position >= binNumber){
        cerr << "Warning: result out of range. Result: " << result << " max: " << maxH << endl;
        return binNumber - 1;
    } else {
        // position is inside (-1, binNumber), truncation gives <0, binNumber - 1>
        return int( position);
    }
}


// Maps a histogram bin index back to a classifier response.
//
// position - index of the histogram bin
// returns the response at the relative position (position + 0.5) / (binNumber - 1)
// of the range <minH, maxH>.
//
// NOTE(review): resultToHistogramPosition rounds to the nearest bin, so bin i
// is centred at i / (binNumber - 1); because of the + 0.5 this function
// returns a value half a bin above that centre - confirm that it is intended.
double TStatistic::histogramPositionToResult( const int position)
{
    const double relativePosition = ( position + 0.5) / (binNumber - 1.0);
    const double range = maxH - minH;

    return relativePosition * range + minH;
}


void TStatistic::integrate()
{
    int detections = 0;
    int falsePositives = 0;

    for( int i = binNumber - 1; i >= 0; i--){
        detections += positiveHistogram[ i];
        positiveHistogram[i] = detections;

        falsePositives += negativeHistogram[ i];
        negativeHistogram[ i] = falsePositives;
    }

    for( int i = ((int) detectionStageCounts.size()) - 1; i >= 0; i--){
        detections += detectionStageCounts[ i];
        falsePositives += FPStageCounts[ i];

        detectionStageCounts[ i] = detections;
        FPStageCounts[ i] = falsePositives;
    }


    int sampleCount = 0;
    for( unsigned i = 0; i < terminationHistogram.size(); i++){
        sampleCount += terminationHistogram[ i];
        terminationHistogram[ i] = sampleCount;
    }
}