//==============================================================================
/*! \file
 * OpenMesh Toolkit for mesh analysis    \n
 * Copyright (c) 2011 by Rostislav Hulik     \n
 *
 * Author:  Rostislav Hulik, rosta.hulik@gmail.com  \n
 * Date:    2011/04/11                          \n
 *
 * This file is part of software developed for support of Rostislav Hulik's dissertation thesis at dcgm-robotics@FIT group.
 *
 * This file is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This file is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this file.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Module description:
 * - Module rasterizes vertex neighbourhood on a tangent raster
 * - Result is sent to connected MDSTk channel
 */

#include "ComputeMatricesHW.h"

#include <OMToolkit\IO\OMIO.h>
#include <OMToolkit\OMTypes.h>
#include <OpenMesh\Tools\Utils\Timer.hh>

#include <utility>
#define __NO_STD_VECTOR // Use cl::vector instead of STL version
#include <CL/cl.hpp>

///////////////////////////////////////////////////////////////////////////////////////////////////
// Module constants
///////////////////////////////////////////////////////////////////////////////////////////////////

// Human-readable module description handed to the module constructor in main()
const std::string MODULE_DESCRIPTION("Module that rasterizes vertex neighbourhood on a tangent raster");

// Colon-separated list of the additional command line arguments
const std::string MODULE_ARGUMENTS("size:relative:resolution:xdir:zdir");

// Names of the individual additional arguments
const std::string MODULE_ARG_SIZE("size");
const std::string MODULE_ARG_RES("resolution");
const std::string MODULE_ARG_XDIR("xdir");
const std::string MODULE_ARG_ZDIR("zdir");
const std::string MODULE_ARG_RELATIVE("relative");

// Default values of the numeric and boolean arguments
const double SIZE_DEFAULT(2.0);
const double RES_DEFAULT(3.0);
const bool RELATIVE_DEFAULT(false);

// Named values of the "zdir" argument and its default
const std::string ZDIR_CURVATURE("curvature");
const std::string ZDIR_ZDIR("z");
const std::string ZDIR_NORMALS("normals");
const std::string ZDIR_DEFAULT(ZDIR_ZDIR);

// Named values of the "xdir" argument and its default
const std::string XDIR_CURVATURE("curvature");
const std::string XDIR_NONE("none");
const std::string XDIR_DEFAULT(XDIR_NONE);


///////////////////////////////////////////////////////////////////////////////////////////////////
// OpenCL related stuff - to be moved to a galaxy far far away
///////////////////////////////////////////////////////////////////////////////////////////////////

// Aborts the process when an OpenCL call did not succeed.
//   err  - status code to test against CL_SUCCESS
//   name - label of the failed call, printed in the error message
// On failure the message goes to std::cerr, the function waits for one character on
// standard input (getchar) and then terminates with EXIT_FAILURE. On success it returns
// without doing anything.
inline void
checkErr(cl_int err, const char * name)
{
	if (err == CL_SUCCESS)
		return;

	std::cerr << "ERROR: " << name << " (" << err << ")" << std::endl;

	// Wait for a character on stdin before the process exits
	getchar();
	exit(EXIT_FAILURE);
}



///////////////////////////////////////////////////////////////////////////////////////////////////


// Type of accepted mesh
typedef OMToolkit::Types::ModuleMeshd	MeshT;
// Scalar type the mesh declares as its AttributeScalar
typedef MeshT::AttributeScalar AScalarT;
// Serializable matrix over AScalarT; one such matrix is stored per vertex via mesh.getMatrix()
typedef OMToolkit::Types::OMSerializableMatrix<AScalarT> MatrixT;

///////////////////////////////////////////////////////////////////////////////////////////////////
// Constructor
///////////////////////////////////////////////////////////////////////////////////////////////////
// Passes the description on to the mds::mod::CModule base and enables the additional
// command line arguments listed in MODULE_ARGUMENTS.
OMComputeMatrices::OMComputeMatrices(const std::string& sDescription)
	: mds::mod::CModule(sDescription)
{
	allowArguments(MODULE_ARGUMENTS);
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Destructor
///////////////////////////////////////////////////////////////////////////////////////////////////
// Empty body - the destructor performs no explicit work.
OMComputeMatrices::~OMComputeMatrices() {}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Do on startup
///////////////////////////////////////////////////////////////////////////////////////////////////
bool OMComputeMatrices::startup()
{
	// Disable all OpenMesh errorlogs (for not mix MDSTk log)
	omlog().disable();
	omerr().disable();
	omout().disable();
    
	// Note
    MDS_LOG_NOTE("Module startup");

    // Test of existence of input and output channel
    if( getNumOfInputs() != 1 || getNumOfOutputs() != 1 )
    {
        MDS_CERR('<' << m_sFilename << "> Wrong number of input and output channels" << std::endl);
        return false;
    }

	m_size = SIZE_DEFAULT;
	m_Arguments.value(MODULE_ARG_SIZE, m_size);

	m_resolution = RES_DEFAULT;
	m_Arguments.value(MODULE_ARG_RES, m_resolution);

	m_direction = XDIR_DEFAULT;
	m_Arguments.value(MODULE_ARG_XDIR, m_direction);

	m_relative = RELATIVE_DEFAULT;
	if (m_Arguments.exists(MODULE_ARG_RELATIVE)) m_relative = true;

	if (m_size <= 0.0 || m_resolution <= 0.0 || ((int)m_resolution)%2 == 0)
	{
		MDS_CERR('<' << m_sFilename << "> Wrong size or resolution of a tangent matrix" << std::endl);
		return false;
	}

	m_directionZ = ZDIR_ZDIR;
	m_Arguments.value(MODULE_ARG_ZDIR, m_directionZ);
	
    // O.K.
    return true;
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Main module loop
///////////////////////////////////////////////////////////////////////////////////////////////////
bool OMComputeMatrices::main()
{
    //// Note
    MDS_LOG_NOTE("Module main function");

    //// I/O channels
    mds::mod::CChannel *pIChannel = getInput(0);
    mds::mod::CChannel *pOChannel = getOutput(0);

	 // Is any input?
    if( !pIChannel->isConnected() )
    {
        return false;
    }

    // Wait for data
    if( pIChannel->wait(1000) )
    {
		// Mesh specification and read options
		MeshT mesh;
		OMToolkit::IO::Options opt = OMToolkit::IO::Options::Default;
	
		// Read and save mesh
		if (OMToolkit::IO::readMesh(mesh, *pIChannel, opt))
		{

			MDS_LOG_NOTE("Starting matrix computation... Model: " << mesh.n_vertices() << " vertices.");
			/////////////////////////////////////////////////////////////////////////////////////////////
			// if length is set as relative, we must compute median and multiply it with length
			if (m_relative)
			{
				std::vector<AScalarT> all;
				MeshT::EdgeIter ende = mesh.edges_end();
	
				for (MeshT::EdgeIter edge = mesh.edges_begin(); edge != ende; ++edge)
					all.push_back(mesh.calc_edge_length(edge));
	
				std::sort(all.begin(), all.end());
				m_size = all[all.size()/2] * m_size;
			}
			/////////////////////////////////////////////////////////////////////////////////////////////
			OpenMesh::Utils::Timer timer;
			
			// go computing
			timer.start();
			///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
			// allocate arrays
			MeshT::Scalar *vertices =	new MeshT::Scalar[mesh.n_vertices() * 8];
			int *faces =				new int[mesh.n_faces() * 4 + 1024 * 4]; // needed for GPU memory alingment
			
			// compute normals for vertices
			mesh.request_face_normals();
			mesh.request_vertex_normals();
			mesh.update_normals();
			
			// export into arrays
			OMToolkit::IO::exportVertices(mesh, vertices, true);
			OMToolkit::IO::exportFaces(mesh, faces);

			//---------------------------------------------------------
			// OpenCL takeover - fear not
			//---------------------------------------------------------

			// Generate some test triangles and vertices in order to debug OpenCL
			// best of luck to me
			/*
			vertices[0] = -10;
			vertices[1] = -10;
			vertices[2] = 1;
			vertices[3] = 0;

			vertices[4] = 77;
			vertices[5] = 777;
			vertices[6] = 7777;
			vertices[7] = 0;

			vertices[8] = 0;
			vertices[9] = 10;
			vertices[10] = 2;
			vertices[11] = 0;

			vertices[12] = 0;
			vertices[13] = 0;
			vertices[14] = 1;
			vertices[15] = 0;

			vertices[16] = 10;
			vertices[17] = -10;
			vertices[18] = 3;
			vertices[19] = 0;

			vertices[20] = 0;
			vertices[21] = 0;
			vertices[22] = 1;
			vertices[23] = 0;*/
			/*
			faces[0] = 3;
			faces[1] = 5;
			faces[2] = 9;
			faces[3] = 11;

			faces[4] = 13;
			faces[5] = 15;
			faces[6] = 6;
			faces[7] = 4;

			faces[8] = 2;
			faces[9] = 22;
			faces[10] = 0;
			faces[11] = 1;
			*/
			// end of magic, carry on

			const int matrix_size = 16;
			const int local_size = 32768;
			const int matrices_per_unit = 1;
			const bool single_matrix_mode = true;

			// init stuff
			cl_int err;
			cl::vector< cl::Platform > platformList;
			cl::Platform::get(&platformList);
			checkErr(platformList.size()!=0 ? CL_SUCCESS : -1, "cl::Platform::get");
			std::cerr << "Platform number is: " << platformList.size() << std::endl;
    
			std::string platformVendor;
			platformList[0].getInfo((cl_platform_info)CL_PLATFORM_VENDOR, &platformVendor);
			std::cerr << "Platform is by: " << platformVendor << "\n";
			cl_context_properties cprops[3] = 
				{CL_CONTEXT_PLATFORM, (cl_context_properties)(platformList[0])(), 0};
 
			cl::Context context(
				CL_DEVICE_TYPE_GPU, 
			   cprops,
			   NULL,
			   NULL,
			   &err);
			checkErr(err, "Conext::Context()");

			// allocate output buffer for matrices
			char * outM = new char[mesh.n_vertices() * 4 * matrix_size * matrix_size];
			// vertex buffer
			cl::Buffer verticesCL(
				context,
				CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR,
				mesh.n_vertices() * 8 * 4,
				vertices,
				&err);
			checkErr(err, "Buffer::Buffer(1)");
			// triangle buffer
			cl::Buffer trianglesCL(
				context,
				CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR,
				mesh.n_faces() * 4 * 4,
				faces,
				&err);
			checkErr(err, "Buffer::Buffer(2)");
			// output matrix buffer
			cl::Buffer outCL(
				context,
				CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR,
				mesh.n_vertices() * 4 * matrix_size * matrix_size,
				outM,
				&err);
			checkErr(err, "Buffer::Buffer(3)");

			// devices
			cl::vector<cl::Device> devices;
			devices = context.getInfo<CL_CONTEXT_DEVICES>();
			checkErr(
				devices.size() > 0 ? CL_SUCCESS : -1, "devices.size() > 0");

			// load the program, build
			std::ifstream file("OMTMesh_Kernel.cl");
			checkErr(file.is_open() ? CL_SUCCESS:-1, "OMTMesh_Kernel.cl");
 
			std::string prog(
				std::istreambuf_iterator<char>(file),
				(std::istreambuf_iterator<char>()));
 
			cl::Program::Sources source(
				 1,
				std::make_pair(prog.c_str(), prog.length()+1));
 
			cl::Program program(context, source);
			err = program.build(devices,"");
			checkErr(file.is_open() ? CL_SUCCESS : -1, "Program::build()");

			// kernel
			cl::Kernel kernel(program, "omc_single", &err);
			checkErr(err, "Kernel::Kernel()");
 
			err = kernel.setArg(0, verticesCL); 
			checkErr(err, "Kernel::setArg()");
			err = kernel.setArg(1, trianglesCL); 
			checkErr(err, "Kernel::setArg()");
			err = kernel.setArg(2, outCL); 
			checkErr(err, "Kernel::setArg()");
			err = kernel.setArg(3, mesh.n_vertices()); 
			checkErr(err, "Kernel::setArg()");
			err = kernel.setArg(4, mesh.n_faces() < 448 ? 448 : mesh.n_faces() ); 
			checkErr(err, "Kernel::setArg()");

			// queue
			cl::CommandQueue queue(context, devices[0], 0, &err);
			checkErr(err, "CommandQueue::CommandQueue()");

			cl::Event event;
			err = queue.enqueueNDRangeKernel(
			kernel, 
			cl::NullRange,
			cl::NDRange(64*mesh.n_vertices() ), // !!!!!!!!!!!!!!!!!!!!!!!
			cl::NDRange(64),
			NULL, 
			&event);
		    checkErr(err, "ComamndQueue::enqueueNDRangeKernel()");

			event.wait();    

			// read back data
			event.wait();    
			err = queue.enqueueReadBuffer(
				outCL,
				CL_TRUE,
				0,
				mesh.n_vertices() * 4 * matrix_size * matrix_size,
				outM);
			checkErr(err, "ComamndQueue::enqueueReadBuffer()");

			float buffer[256];

			for (int i = 0; i < 256; i++)
				buffer[i] = ((float *)outM)[i + 30 * 256];
	
			// 
			// output matrices are now in outM array - suck em' out
			// 

			const int m_size = 256;
			
			int i = 0;
			for (MeshT::VertexIter vertex = mesh.vertices_begin(); vertex != mesh.vertices_end(); ++vertex)
			{
				MatrixT outputMatrix;	
				for (int j = 0; j < m_size; ++j)
				{
					outputMatrix(j) = ((float *)outM)[j + i * m_size];
				}
				mesh.getMatrix(mesh.vertex_handle(i)) = outputMatrix;
				++i;
			}
			


			//---------------------------------------------------------
			// End of OpenCL takeover
			//---------------------------------------------------------

			// test mesh2

			//MeshT mesh2;
			
			// request vertex attribute
			//mesh2.request_vertex_normals();

			// import data
			//OMToolkit::IO::importVertices(mesh2, vertices, mesh.n_vertices(), true);
			//OMToolkit::IO::importFaces(mesh2, faces, mesh.n_faces());
			///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
			timer.stop();
						
			MDS_LOG_NOTE("Matrices successfully computed in " << timer.as_string(OpenMesh::Utils::Timer::MSeconds) << ".");

			// write output
			if (!OMToolkit::IO::writeMesh(mesh, *pOChannel))
			{
				MDS_CERR('<' << m_sFilename << "> Failed to write output data" << std::endl);
				return false;
			}
		}
		// Error on input
		else 
		{
			MDS_CERR('<' << m_sFilename << "> Failed to read input mesh data" << std::endl);
			return false;
		}

		return false;
	}
    else
    {
        MDS_LOG_NOTE("Wait timeout");
    }

    // Returning 'true' means to continue processing the input channel
    return true;	
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// On module shutdown
///////////////////////////////////////////////////////////////////////////////////////////////////
// Called on module shutdown; only writes a note into the log.
void OMComputeMatrices::shutdown()
{
	MDS_LOG_NOTE("Module shutdown");
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Writes extended use of this module
///////////////////////////////////////////////////////////////////////////////////////////////////
void OMComputeMatrices::writeExtendedUsage(std::ostream& Stream)
{
    MDS_CERR("Necessary arguments: [-size matrixSize] [-relative] [-resolution matrixResolution] [-xdir XAxisDir]" << std::endl);
    MDS_CERR("Options:" << std::endl);
	MDS_CERR("  -size Specifies matrix size in mesh space." << std::endl);
	MDS_CERR("	  -Argument is double precision number greater than 0.0" << std::endl);
	MDS_CERR("  -relative Use this argument if you want to specify matrix size relatively." << std::endl);
	MDS_CERR("	  -If set, length is not static but is computed as matrixSize * medianOfEdgeLengths" << std::endl);
	MDS_CERR("  -resolution Specifies square matrix dimension in one direction." << std::endl);
	MDS_CERR("	  -Argument is double precision number greater than 0.0." << std::endl);
	MDS_CERR("	  -Must be odd number (for filtration purposes)" << std::endl);
	MDS_CERR("  -xdir Specifies vector, which will be used to align matrix X direction." << std::endl);
	MDS_CERR("	  " << XDIR_CURVATURE << " - Option sets X direction as computed curvature vector." << std::endl);
	MDS_CERR("	  " << XDIR_NONE      << " - Option sets X direction randomly (does not align)." << std::endl);
    MDS_CERR(std::endl);
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Main - executing a module
///////////////////////////////////////////////////////////////////////////////////////////////////
// Program entry point: creates the module, initializes it with the command line and runs
// it when the initialization succeeded. The exit code is always 0.
int main(int argc, char *argv[])
{
	// The module instance is held by a smart pointer
	OMComputeMatricesPtr spModule(new OMComputeMatrices(MODULE_DESCRIPTION));

	// run() is entered only after a successful init()
	const bool initialized = spModule->init(argc, argv);
	if( initialized )
	{
		spModule->run();
	}

	// Console application finished
	return 0;
}

