//==============================================================================
/*! \file
 * Medical Data Segmentation Toolkit (MDSTk)    \n
 * Copyright (c) 2003-2006 by Michal Spanel     \n
 *
 * Author:  Michal Spanel, spanel@fit.vutbr.cz  \n
 * File:    mdsEMAlgorithmTEST.cpp              \n
 * Section: libMathTEST                         \n
 * Date:    2006/09/15                          \n
 *
 * $Id:$
 *
 * Description:
 * - Testing of the mds::math::CMaxLikelihoodByEM template.
 */

#include <MDSTk/Base/mdsSetup.h>
#include <MDSTk/Math/mdsRandom.h>
#include <MDSTk/Math/mdsVector.h>
#include <MDSTk/Math/mdsVectorFunctions.h>

// Enable logging
#define EM_LOGGING_ENABLED

#include <MDSTk/Math/Algorithm/mdsEM.h>

// STL
#include <iostream>


//==============================================================================
/*
 * Global constants.
 */

//! Total count of randomly generated input samples.
const mds::tSize NUM_OF_SAMPLES(1000);

//! Parameters (mean M, sigma S) of the three Gaussians used to generate data.
//! Gaussian 1.
const double M1(0.0);
const double S1(5.0);
//! Gaussian 2.
const double M2(-30.0);
const double S2(5.0);
//! Gaussian 3.
const double M3(50.0);
const double S3(10.0);


//==============================================================================
/*!
 * Prints a given vector.
 */
void printVector(mds::math::CFVector& v)
{
    std::cout.setf(std::ios_base::fixed);
    std::cout << "  ";
    for( mds::tSize i = 0; i < v.getSize(); i++ )
    {
        std::cout << v(i) << " ";
    }
    std::cout << std::endl;
}


//==============================================================================
/*!
 * Waits until the user presses Enter.
 * - Consumes characters from std::cin up to and including the next '\n'.
 * - Also returns when std::cin reaches end-of-file or fails (get() then
 *   returns eof()), so the program does not spin forever when stdin is
 *   closed or redirected from an empty source.
 */
void keypress()
{
    typedef std::istream::traits_type tTraits;
    for( ;; )
    {
        const tTraits::int_type Ch = std::cin.get();
        if( tTraits::eq_int_type(Ch, tTraits::eof())
            || tTraits::eq_int_type(Ch, tTraits::to_int_type('\n')) )
        {
            break;
        }
    }
}


//==============================================================================
/*!
 * main()
 */
int main(int argc, const char *argv[])
{
    // Init global log
    MDS_LOG_INIT_STDERR;
//    MDS_LOG_INIT_FILE("temp.log");

    // Random numbers generator
    mds::math::CNormalPRNG Random;
    std::cout << "Random number generator parameters:" << std::endl;
    std::cout << "  Gaussian 1: Mean = " << M1 << ", Sigma = " << S1 << std::endl;
    std::cout << "  Gaussian 2: Mean = " << M2 << ", Sigma = " << S2 << std::endl;
    std::cout << "  Gaussian 3: Mean = " << M3 << ", Sigma = " << S3 << std::endl;

    // One third of all samples
    mds::tSize THIRD = NUM_OF_SAMPLES / 3;

    // Generate random data
    std::cout << "Vector 1" << std::endl;
    mds::math::CFVector v1(NUM_OF_SAMPLES, 1);
    for( mds::tSize i = 0; i < NUM_OF_SAMPLES; ++i )
    {
        if( i < THIRD )
        {
            v1(i) = float(Random.random(M1, S1));
        }
        else if( i < 2 * THIRD )
        {
            v1(i) = float(Random.random(M2, S2));
        }
        else
        {
            v1(i) = float(Random.random(M3, S3));
        }
    }
    printVector(v1);
    keypress();

    mds::math::CMaxLikelihoodByEM<mds::math::CFVector,1> Clustering;
//    if( !Clustering.execute(v1, 3) )
    if( !Clustering.execute(v1) )
    {
        std::cout << "Error: EM algorithm failed!" << std::endl;
        return 0;
    }
    
    mds::math::CMaxLikelihoodByEM<mds::math::CFVector,1>::tComponent c;
    std::cout << "Number of components: " << Clustering.getNumOfComponents() << std::endl;
    for( mds::tSize i = 0; i < Clustering.getNumOfComponents(); ++i )
    {
        c = Clustering.getComponent(i);
        std::cout << "  Component mean " << i << ": " << c.getMean() << std::endl;
    }

    mds::math::CMaxLikelihoodByEM<mds::math::CFVector,1>::tVector v;
    std::cout << "Membership function:"<< std::endl;
    Clustering.getMembership(50, v);
    std::cout << "  Sample 150: " << v << std::endl;
    Clustering.getMembership(150, v);
    std::cout << "  Sample 450: " << v << std::endl;
    Clustering.getMembership(250, v);
    std::cout << "  Sample 750: " << v << std::endl;

    return 0;
}

