//==============================================================================
/*! \file
 * Medical Data Segmentation Toolkit (MDSTk)    \n
 * Copyright (c) 2003-2005 by Michal Spanel     \n
 *
 * Author:  Michal Spanel, spanel@fit.vutbr.cz  \n 
 * File:    mdsEMAlgorithm.cpp                  \n
 * Section: mSliceSegEM                         \n
 * Date:    2005/11/02                          \n
 *
 * $Id: mdsEMAlgorithm.cpp 345 2007-06-11 13:23:09Z spanel $
 *
 * Description:
 * - Image segmentation algorithm based on a Gaussian mixture model
 *   and the Expectation-Maximization (EM) algorithm.
 */

#include "mdsEMAlgorithm.h"

#include <MDSTk/Base/mdsGlobalLog.h>
#include <MDSTk/Math/mdsLogNum.h>
#include <MDSTk/Image/mdsImageFunctions.h>

#include <float.h>
#include <limits.h>
#include <cmath>


namespace mds
{
namespace seg
{

//==============================================================================
/*
 * Implementation of the CGaussianFunc1D class.
 */

//! Minimal gaussian component standard deviation.
//! - Initialized from CImageEMAlgorithm::DEFAULT_MIN_SIGMA, which is defined
//!   further down in this file.
//! - NOTE(review): presumably used by CGaussianFunc1D as a lower bound when
//!   a sigma is set; confirm against the class declaration.
double CGaussianFunc1D::m_dMinSigma = CImageEMAlgorithm::DEFAULT_MIN_SIGMA;


//==============================================================================
/*
 * Implementation of the mds::CImageEMAlgorithm class.
 */
//! Value used in this file to initialize CGaussianFunc1D::m_dMinSigma.
const double CImageEMAlgorithm::DEFAULT_MIN_SIGMA       = 100.0;
//! NOTE(review): not referenced in this file; presumably the default value
//! of the constructor parameter dMinAddChange - confirm in the header.
const double CImageEMAlgorithm::DEFAULT_MIN_ADD_CHANGE  = 0.001;
//! Threshold of the relative log-likelihood change passed to iterateEM().
const double CImageEMAlgorithm::MIN_CHANGE              = 1.0e-6;
//! Multiple of sigma by which splitComponent() shifts the two new means.
const double CImageEMAlgorithm::SPLITTING_COEFF         = 0.5;
//! Upper bound on the E-step/M-step iterations of the inner loop of iterateEM().
const tSize CImageEMAlgorithm::MAX_ITERS                = 10;
//! Number of passes of the outer loop of iterateEM().
const tSize CImageEMAlgorithm::MAX_ITERS2               = 10;


/*!
 * Constructor.
 * - NumOfComponents is the requested number of mixture components;
 *   operator()() treats zero as "estimate the number automatically".
 * - dMinAddChange is the relative log-likelihood change below which
 *   operator()() stops adding further components.
 * - The histogram and the support maps cover the whole pixel range
 *   reported by mds::img::CPixelTraits<tPixel>.
 */
CImageEMAlgorithm::CImageEMAlgorithm(tSize NumOfComponents, double dMinAddChange)
    : m_dMinAddChange(dMinAddChange)
    , m_NumOfComponents(NumOfComponents)
    , m_PixelMin(mds::img::CPixelTraits<tPixel>::getPixelMin())
    , m_PixelMax(mds::img::CPixelTraits<tPixel>::getPixelMax())
    , m_Span(tSize(m_PixelMax - m_PixelMin + 1))
    , m_Count(0)
    , m_Components(NumOfComponents)
    , m_Maps(NumOfComponents, m_Span)
    , m_Histogram(m_PixelMin, m_PixelMax)
{
    MDS_ASSERT(m_NumOfComponents >= 0);

    // No image has been processed yet, hence no pixel count to invert
    m_dInvCount = 0.0;

    // Reciprocal of the number of histogram bins
    m_dInvSpan = 1.0 / double(m_Span);
}


/*!
 * Segments the source image and writes component indexes to the destination.
 * - Only the area common to both images is labelled; the histogram is
 *   computed from the source image as passed in.
 * - If the number of components is positive, the components are initialized
 *   from the image, EM is run once and the image is labelled.
 * - Otherwise the method starts with a single component and keeps splitting
 *   the component returned by findComponent() until the relative change
 *   of the log-likelihood drops below m_dMinAddChange. The destination
 *   then holds the labelling made before the last split.
 * - In the second mode m_NumOfComponents is overwritten.
 * - Always returns true.
 */
bool CImageEMAlgorithm::operator()(const tImage& SrcImage, tImage& DstImage)
{
    // Area shared by the source and the destination image
    tSize SizeX = MIN(SrcImage.getXSize(), DstImage.getXSize());
    tSize SizeY = MIN(SrcImage.getYSize(), DstImage.getYSize());

    // Histogram of the input image
    m_Histogram(SrcImage);

    // Number of pixels counted in the histogram and its safe reciprocal
    m_Count = tSize(m_Histogram.getTotalCount());
    m_dInvCount = 0.0;
    if( m_Count > 0 )
    {
        m_dInvCount = 1.0 / m_Count;
    }

    // Fixed number of components
    if( m_NumOfComponents > 0 )
    {
        initComponents(SrcImage, SizeX, SizeY);
        iterateEM(MIN_CHANGE);
        segmentImage(SrcImage, DstImage, SizeX, SizeY);
        return true;
    }

    // Unknown number of components: begin with a single one
    m_NumOfComponents = 1;
    m_Maps.create(m_NumOfComponents, m_Span);
    m_Components.create(m_NumOfComponents);
    initFirstComponent();

    // Log-likelihood accepted in the previous round
    double dPrevLogLikelihood = 1.0;

    for( ;; )
    {
        // Fit the current mixture
        double dCurrent = iterateEM(MIN_CHANGE);

        // Relative improvement against the previous round
        double dRelChange = mds::math::getAbs(dCurrent / dPrevLogLikelihood - 1.0);
        if( dRelChange < m_dMinAddChange )
        {
            break;
        }

        // Accept this mixture: label the image and remember the likelihood
        segmentImage(SrcImage, DstImage, SizeX, SizeY);
        dPrevLogLikelihood = dCurrent;

        // Component to be split, chosen before the mixture grows
        tSize SplitIndex = findComponent();

        // The new component is appended at the end
        tSize AddedIndex = m_NumOfComponents;
        ++m_NumOfComponents;

        // Enlarge the support maps
        m_Maps.create(m_NumOfComponents, m_Span);

        // Enlarge the vector of components and restore the old parameters
        tComponents Backup(m_Components);
        m_Components.create(m_NumOfComponents);
        m_Components = Backup;

        // Divide the chosen component between its old and the new slot
        splitComponent(SplitIndex, AddedIndex);
    }

    return true;
}


void CImageEMAlgorithm::processEStep()
{
    // Clear all support maps
    clearSupportMaps();

    // For each histogram bin
    for( tSize j = 0; j < m_Span; ++j )
    {
        // Representative of the current bin
        double dPixel = m_Histogram.getLowerBound(j);

        // The number of pixels classified in the bin
        double dCount = m_Histogram.getCount(j);

        // Add probabilities to the support maps
        double dSum = 0.0;
        tSize i;
        for( i = 0; i < m_NumOfComponents; ++i )
        {
            double dTemp = dCount * m_Components(i).getWeightedValue(dPixel);
            m_Maps(i,j) += dTemp;
            dSum += dTemp;
        }

        // Invert the computed sum
        double dInvSum = (dSum > 0.0) ? 1.0 / dSum : 1.0;

        // Divide the value in each support map
        for( i = 0; i < m_NumOfComponents; ++i )
        {
            m_Maps(i,j) *= dInvSum;
        }
    }
}


void CImageEMAlgorithm::processMStep()
{
    // Total number of pixels classified in histogram
    double dInvTotalCount = mds::math::getSum<double>(m_Histogram.getHistogram());
    dInvTotalCount = (dInvTotalCount > 0.0) ? 1.0 / dInvTotalCount : 1.0;

    // For each segment form new values of its parameters
    for( tSize i = 0; i < m_NumOfComponents; ++i )
    {
        double dComponentMean = m_Components(i).getMean();
        
        // Initialize required sums
        double dPSum = 0.0, dMSum = 0.0, dSSum = 0.0;
        
        // For each histogram bin
        for( tSize j = 0; j < m_Span; ++j )
        {
            // Representative of the current histogram bin
            double dPixel = m_Histogram.getLowerBound(j);
            
            // Number of pixels classified in the same bin
            double dCount = m_Histogram.getCount(j);
 
            dPSum += dCount * m_Maps(i, j);
            dMSum += dCount * dPixel * m_Maps(i, j);
 
            double dTemp = dPixel - dComponentMean;
            dSSum += dCount * dTemp * dTemp * m_Maps(i, j);
        }

        // Safe invert of the sum
        double dInvPSum = (dPSum > 0.0) ? 1.0 / dPSum : 1.0;
 
        // Final values of weight, mean and standard deviation    
        m_Components(i).setWeight(dPSum * dInvTotalCount);
        m_Components(i).setMean(dMSum * dInvPSum);
        m_Components(i).setSigma(sqrt(dSSum * dInvPSum));
    }
}


double CImageEMAlgorithm::iterateEM(double dMinChange)
{
    // Initial value of the log-likelihood function
    double dLogLikelihoodFunc = 1.0;

    // Second level iterations
    for( tSize j = 0; j < MAX_ITERS2; ++j )
    {
        // Re-initialization of similar components
        checkForSimilarComponents();
        
        // Initial value of the log-likelihood function
        dLogLikelihoodFunc = 1.0;

        // Iterate while the function converges
        for( tSize i = 0; i < MAX_ITERS; ++i )
        {
            // E-step
            processEStep();
    
            // M-step
            processMStep();
    
            // Evaluate the log-likelihood function
            double dNewValue = computeLogLikelihood();
    
            // Eestimate change of the log-likelihood function
            double dDelta = mds::math::getAbs(dNewValue / dLogLikelihoodFunc - 1.0);

#ifdef EM_LOGGING_ENABLED
            MDS_LOG_NOTE("iterateEM()");
            MDS_LOG_NOTE("  Log-likelihood Function = " << dNewValue);
            MDS_LOG_NOTE("  Delta = " << dDelta);
#endif // EM_LOGGING_ENABLED

            // Estimate changes
            if( dDelta < dMinChange )
            {
                break;
            }
    
            // Update the current value
            dLogLikelihoodFunc = dNewValue;
        }
    }

#ifdef EM_LOGGING_ENABLED
    MDS_LOG_NOTE("iterateEM()");
    for( tSize i = 0; i < m_NumOfComponents; ++i )
    {
        MDS_LOG_NOTE("  Component " << i << ":");
        MDS_LOG_NOTE("    Weight = " << m_Components(i).getWeight());
        MDS_LOG_NOTE("    Mean = " << m_Components(i).getMean());
        MDS_LOG_NOTE("    Sigma = " << m_Components(i).getSigma());
    }
#endif // EM_LOGGING_ENABLED

    // Final log-likelihood function value
    return dLogLikelihoodFunc;
}


//! Resets the support maps by calling zeros() on m_Maps.
//! - Called by processEStep() before the maps are accumulated again.
void CImageEMAlgorithm::clearSupportMaps()
{
    m_Maps.zeros();
}


/*!
 * Sets the mean of the i-th component to the mean pixel value of a randomly
 * placed square region of the source image.
 * - The origin is drawn from [0, XSize - 1] x [0, YSize - 1] and the edge
 *   length from [5, 15].
 * - NOTE(review): the region may reach beyond the image border; whether
 *   the subimage constructor of tImage clips it is not visible here.
 */
void CImageEMAlgorithm::initComponentMean(tSize i,
                                          const tImage& SrcImage,
                                          tSize XSize,
                                          tSize YSize
                                          )
{
    // Limits of the region edge length
    static const tSize SMALLEST_EDGE = 5;
    static const tSize LARGEST_EDGE = 15;

    // Random origin and edge length (drawn in this order)
    tSize OriginX = tSize(m_Uniform.random(0, XSize - 1));
    tSize OriginY = tSize(m_Uniform.random(0, YSize - 1));
    tSize Edge = tSize(m_Uniform.random(SMALLEST_EDGE, LARGEST_EDGE));

    // Reference to the chosen part of the source image
    tImage Window(SrcImage, OriginX, OriginY, Edge, Edge, mds::REFERENCE);

    // Its mean value becomes the component mean
    double dRegionMean = mds::img::getMean<double>(Window);
    m_Components(i).setMean(dRegionMean);
}


//! Assigns the weight dWeight to the i-th mixture component.
void CImageEMAlgorithm::initComponentWeight(tSize i, double dWeight)
{
    m_Components(i).setWeight(dWeight);
}


/*!
 * Derives the initial sigma of the i-th component from the distance between
 * its mean and the closest mean of any other component.
 * - The distance is scaled by 1 / (2 * NUM_OF_DIMENSIONS).
 * - If there is no other component, the distance stays at DBL_MAX.
 */
void CImageEMAlgorithm::initComponentSigma(tSize i)
{
    static const double dScale = 1.0 / (2 * NUM_OF_DIMENSIONS);

    double dOwnMean = m_Components(i).getMean();

    // Smallest distance to the mean of another component
    double dNearest = DBL_MAX;
    for( tSize k = 0; k < m_NumOfComponents; ++k )
    {
        if( k == i )
        {
            continue;
        }

        double dDistance = mds::math::getAbs(dOwnMean - m_Components(k).getMean());
        if( dDistance < dNearest )
        {
            dNearest = dDistance;
        }
    }

    m_Components(i).setSigma(dScale * dNearest);
}


void CImageEMAlgorithm::initFirstComponent()
{
    MDS_ASSERT(m_NumOfComponents == 1);

    // Compute mean value of all input values
    double dMean = 0.0;
    tSize i;
    for( i = 0; i < m_Span; ++i )
    {
        double dCount = m_Histogram.getCount(i);
        dMean += dCount * m_Histogram.getLowerBound(i);
    }
    dMean *= m_dInvCount;

    // Compute variance of all input values
    double dVariance = 0.0;
    for( i = 0; i < m_Span; ++i )
    {
        double dCount = m_Histogram.getCount(i);
        double dTemp = m_Histogram.getLowerBound(i) - dMean;

        dVariance += dCount * dTemp * dTemp;
    }
    dVariance *= m_dInvCount;

    // Initialize the mixture component
    m_Components(0).setMean(dMean);
    m_Components(0).setSigma(sqrt(dVariance));
    m_Components(0).setWeight(1.0);
}


/*!
 * Initializes all mixture components for a known number of components.
 * - Every component gets the same weight 1 / m_NumOfComponents and a mean
 *   taken from a small, randomly chosen image region.
 * - The sigmas are assigned in a separate pass: initComponentSigma() reads
 *   the means of all other components, so every mean has to be set before
 *   the first sigma is derived. Interleaving both steps would base the
 *   sigma of a component on means that have not been initialized yet.
 * - XSize and YSize limit the area from which the regions are drawn.
 */
void CImageEMAlgorithm::initComponents(const tImage& SrcImage,
                                       tSize XSize,
                                       tSize YSize
                                       )
{
    MDS_ASSERT(m_NumOfComponents > 0);

    // Estimate initial weight of all components
    double dWeight = 1.0 / m_NumOfComponents;

    tSize i;

    // First pass: means and weights of all components
    for( i = 0; i < m_NumOfComponents; ++i )
    {
        initComponentMean(i, SrcImage, XSize, YSize);
        initComponentWeight(i, dWeight);
    }

    // Second pass: sigmas, now that all means are known
    for( i = 0; i < m_NumOfComponents; ++i )
    {
        initComponentSigma(i);
    }

#ifdef EM_LOGGING_ENABLED
    MDS_LOG_NOTE("initComponents()");
    for( tSize j = 0; j < m_NumOfComponents; ++j )
    {
        MDS_LOG_NOTE("  Component " << j << ":");
        MDS_LOG_NOTE("    Weight = " << m_Components(j).getWeight());
        MDS_LOG_NOTE("    Mean = " << m_Components(j).getMean());
        MDS_LOG_NOTE("    Sigma = " << m_Components(j).getSigma());
    }
#endif // EM_LOGGING_ENABLED
}


double CImageEMAlgorithm::computeLogLikelihood()
{
    mds::math::CLogNum<double> Result = 1.0;

    // For each histogram bin
    for( tSize j = 0; j < m_Span; ++j )
    {
        // Representative of the current bin
        double dPixel = m_Histogram.getLowerBound(j);

        // The number of pixels classified in the bin
        tSize Count = m_Histogram.getCount(j);

        // Sum of probabilities
        double dSum = 0.0;
        for( tSize i = 0; i < m_NumOfComponents; ++i )
        {
            dSum += m_Components(i).getWeightedValue(dPixel);
        }

        for( tSize c = 0; c < Count; ++c )
        {
            Result *= dSum;
        }
    }

    return Result.get(mds::math::LOG_VALUE);
}


/*double CImageEMAlgorithm::getNormCorrelation(tSize i, tSize j)
{
    MDS_ASSERT(i < m_NumOfComponents && j < m_NumOfComponents);
    
    double dMean1 = m_Components(i).getMean();
    double dSigma1 = m_Components(i).getSigma();
 
    double dMean2 = m_Components(j).getMean();
    double dSigma2 = m_Components(j).getSigma();
    
    double dS = dMean1 * dMean2 + dSigma1 * dSigma2;
    double dS1 = dMean1 * dMean1 + dSigma1 * dSigma1;
    double dS2 = dMean2 * dMean2 + dSigma2 * dSigma2;
    
    return dS / sqrt(dS1 * dS2);
}*/


/*!
 * Checks whether the components i and j are separated well enough.
 * - Returns true if the distance of both means is at least c times the
 *   square root of the larger of both sigmas.
 * - NOTE(review): an earlier, disabled variant of the threshold was
 *   c * sqrt(max(sigma1^2, sigma2^2)); confirm which one is intended.
 */
bool CImageEMAlgorithm::areSeparated(tSize i, tSize j, int c)
{
    MDS_ASSERT(i < m_NumOfComponents && j < m_NumOfComponents);

    // Parameters of the first component
    double dMeanA = m_Components(i).getMean();
    double dSigmaA = m_Components(i).getSigma();

    // Parameters of the second component
    double dMeanB = m_Components(j).getMean();
    double dSigmaB = m_Components(j).getSigma();

    // Required and actual distance of the means
    double dRequired = c * sqrt(mds::math::getMax(dSigmaA, dSigmaB));
    double dDistance = mds::math::getAbs(dMeanA - dMeanB);

    return (dRequired <= dDistance);
}


/*!
 * Returns the degree of separation for a given number of components.
 * - The result is 1 for up to four components and grows by one for every
 *   two additional components (integer division).
 */
int CImageEMAlgorithm::getDegreeOfSeparation(tSize NumOfComponents)
{
    // Number of components up to which the degree stays at one
    static const int BASE_COUNT = 4;

    // Number of additional components per one more degree
    static const int STEP = 2;

    if (NumOfComponents > BASE_COUNT)
    {
        return 1 + (NumOfComponents - BASE_COUNT) / STEP;
    }

    return 1;
}


/*!
 * Splits the component Index into two components stored at Index
 * and NewIndex.
 * - Both keep the original sigma and get one half of the original weight.
 * - The means are moved by SPLITTING_COEFF * sigma upwards (Index)
 *   and downwards (NewIndex).
 */
void CImageEMAlgorithm::splitComponent(tSize Index, tSize NewIndex)
{
    MDS_ASSERT(Index < m_NumOfComponents && NewIndex < m_NumOfComponents);

    // Read all parameters of the source before anything is modified
    double dSourceMean = m_Components(Index).getMean();
    double dSourceSigma = m_Components(Index).getSigma();
    double dSourceWeight = m_Components(Index).getWeight();

    // Shift of the means and the weight shared by both halves
    double dShift = SPLITTING_COEFF * dSourceSigma;
    double dHalfWeight = 0.5 * dSourceWeight;

    // Upper half replaces the source component
    m_Components(Index).setMean(dSourceMean + dShift);
    m_Components(Index).setSigma(dSourceSigma);
    m_Components(Index).setWeight(dHalfWeight);

    // Lower half is stored in the other slot
    m_Components(NewIndex).setMean(dSourceMean - dShift);
    m_Components(NewIndex).setSigma(dSourceSigma);
    m_Components(NewIndex).setWeight(dHalfWeight);
}


/*!
 * Returns the index of the component having the largest sigma.
 * - The first such component wins if several share the largest value.
 */
tSize CImageEMAlgorithm::findComponent()
{
    MDS_ASSERT(m_NumOfComponents > 0);

    // Nothing to compare in a mixture of one component
    if( m_NumOfComponents == 1 )
    {
        return 0;
    }

    tSize Widest = 0;
    double dWidestSigma = m_Components(0).getSigma();

    for( tSize k = 1; k < m_NumOfComponents; ++k )
    {
        double dCandidate = m_Components(k).getSigma();
        if( dCandidate > dWidestSigma )
        {
            dWidestSigma = dCandidate;
            Widest = k;
        }
    }

    return Widest;
}


bool CImageEMAlgorithm::checkForSimilarComponents()
{
    if( m_NumOfComponents < 2 )
    {
        return false;
    }

    // Calculate degree of separation
    int iC = getDegreeOfSeparation(m_NumOfComponents);

    // Check similarity of all possible pairs of components
    for( tSize i = 0; i < m_NumOfComponents; ++i )
    {
        for( tSize j = i + 1; j < m_NumOfComponents; ++j )
        {
            // Normalized correlation of both components
//            double dNC = getNormCorrelation(i, j);
//            if( dNC > MAX_CORRELATION )

            // Degree of separation
            if( !areSeparated(i, j, iC) )
            {
                // Splitting of the largest component
                tSize k = findComponent();
                if( k != j )
                {
                    splitComponent(k, j);
                }
                else
                {
                    splitComponent(k, i);
                }

                // Terminate the function
                // - Just one component could be reinitialized at the moment
                return true;
            }
        }
    }

    // No change
    return false;
}


void CImageEMAlgorithm::segmentImage(const tImage& SrcImage,
                                     tImage& DstImage,
                                     tSize XSize,
                                     tSize YSize
                                     )
{
    // Classify input image pixels
    for( tSize y = 0; y < YSize; ++y )
    {
        for( tSize x = 0; x < XSize; ++x )
        {
            // Get index of the histogram bin
            tSize Index = m_Histogram.getIndex(SrcImage(x,y));

            // Find segment whose probability is maximal
            tSize Max = 0;
            for( tSize i = 1; i < m_NumOfComponents; ++i )
            {
                if( m_Maps(i, Index) > m_Maps(Max, Index) )
                {
                    Max = i;
                }
            }
            DstImage(x, y) = tPixel(Max);
        }
    }
}


} // namespace seg
} // namespace mds

