/*
 * bow_generator.cpp
 *
 *  Created on: 12.10.2011
 *      Author: isvoboda
 */

#include "bow.hpp"

namespace ppd
{

	//--------------------- Class Extractor ------------------------------------

	/**
	 * Default constructor: creates a SURF detector and a SURF descriptor
	 * extractor through OpenCV's string-based factories.
	 * NOTE(review): the factories may hand back an empty Ptr when "SURF" is not
	 * available in the linked OpenCV build — confirm the providing module is
	 * initialised before constructing an Extractor.
	 */
	Extractor::Extractor()
		: featureDetector(cv::FeatureDetector::create("SURF"))
		, descriptorExtractor(cv::DescriptorExtractor::create("SURF"))
	{}

	/**
	 * Builds an Extractor around a caller-supplied detector/extractor pair.
	 * The smart pointers are copied, so ownership is shared with the caller.
	 */
	Extractor::Extractor(const cv::Ptr<cv::FeatureDetector> featureDetector, const cv::Ptr<cv::DescriptorExtractor> descriptorExtractor)
		: featureDetector(featureDetector)
		, descriptorExtractor(descriptorExtractor)
	{
	}

	/**
	 * Replaces the detector and the descriptor extractor used by this object.
	 */
	void Extractor::set(const cv::Ptr<cv::FeatureDetector> featureDetector, const cv::Ptr<cv::DescriptorExtractor> descriptorExtractor)
	{
		Extractor::featureDetector = featureDetector;
		Extractor::descriptorExtractor = descriptorExtractor;
	}

	/**
	 * @return the descriptor extractor currently in use (a shared handle, not a copy
	 *         of the algorithm).
	 */
	cv::Ptr<cv::DescriptorExtractor> Extractor::getDescriptorExtractor(void)
	{
		cv::Ptr<cv::DescriptorExtractor> current = descriptorExtractor;
		return current;
	}

	/**
	 * @return the feature detector currently in use (a shared handle, not a copy
	 *         of the algorithm).
	 */
	cv::Ptr<cv::FeatureDetector> Extractor::getFeatureDetector(void)
	{
		cv::Ptr<cv::FeatureDetector> current = featureDetector;
		return current;
	}

	/**
	 * Detects keypoints in @p image, optionally keeps only those inside
	 * @p mask, and computes descriptors for the remaining keypoints.
	 *
	 * @param image        input image
	 * @param descriptors  output: descriptors computed by the configured
	 *                     DescriptorExtractor for the kept keypoints
	 * @param mask         optional rectangle; when non-NULL, keypoints outside
	 *                     it are discarded before description (see get_masked())
	 */
	void Extractor::extract(const cv::Mat& image, cv::Mat& descriptors, const cv::Rect_<int>* mask) const
	{
		// Delegate to detect()/compute() so the masking rule is defined in one
		// place only and the filtered keypoints are not copied around.
		std::vector<cv::KeyPoint> keypoints;
		this->detect(image, keypoints, mask);
		this->compute(image, keypoints, descriptors);
	}

	/**
	 * Runs the feature detector on @p image and stores the result in
	 * @p keypoints. When @p mask is non-NULL, keypoints outside the rectangle
	 * are dropped (see get_masked()).
	 */
	void Extractor::detect(const cv::Mat& image, std::vector<cv::KeyPoint>& keypoints, const cv::Rect_<int>* mask) const
	{
		featureDetector->detect(image, keypoints);
		if (mask == NULL)
			return;

		std::vector<cv::KeyPoint> inside;
		get_masked(keypoints, mask, inside);
		keypoints.swap(inside);
	}

	/**
	 * Computes descriptors for the given keypoints of @p image.
	 * @p keypoints is taken by non-const reference because OpenCV's compute()
	 * takes it that way (it may drop keypoints it cannot describe).
	 */
	void Extractor::compute(const cv::Mat& image, std::vector<cv::KeyPoint>& keypoints, cv::Mat& descriptors) const
	{
		const cv::Ptr<cv::DescriptorExtractor>& extractor = descriptorExtractor;
		extractor->compute(image, keypoints, descriptors);
	}

	/**
	 * Copies into @p maskedKeypoints the keypoints whose rounded position lies
	 * inside @p mask.
	 *
	 * The rectangle is treated as the half-open region
	 * [x, x + width) x [y, y + height), the same convention as
	 * cv::Rect_::contains(). A NULL mask keeps every keypoint.
	 *
	 * @param keypoints        candidate keypoints
	 * @param mask             region of interest (may be NULL)
	 * @param maskedKeypoints  output; previous contents are discarded
	 */
	void Extractor::get_masked(const std::vector<cv::KeyPoint>& keypoints, const cv::Rect_<int>* mask, std::vector<cv::KeyPoint>& maskedKeypoints) const
	{
		if (mask == NULL)
		{
			maskedKeypoints = keypoints;
			return;
		}

		maskedKeypoints.clear();
		for (std::vector<cv::KeyPoint>::const_iterator it = keypoints.begin(); it != keypoints.end(); ++it)
		{
			// Round both coordinates first, then test all four edges on the
			// rounded point (the old expression rounded boolean results instead
			// of coordinates for two of the comparisons).
			const cv::Point_<int> position(cvRound(it->pt.x), cvRound(it->pt.y));
			if (mask->contains(position))
				maskedKeypoints.push_back(*it);
		}
	}

	// Destructor: no explicit cleanup is performed here.
	Extractor::~Extractor(void) {}


	//--------------------- BOW Methods Definitions ----------------------------

	// Creates a BOW without a matcher; the member-matcher extract() overload
	// does nothing until one is supplied via set_DescriptorMatcher().
	BOW::BOW() {}

	// Creates a BOW that matches descriptors against the vocabulary held by
	// @p dmatcher (shared handle).
	BOW::BOW(cv::Ptr<cv::DescriptorMatcher> dmatcher)
		: dmatcher(dmatcher)
	{
	}

	/**
	 * Builds a normalised bag-of-words histogram for one image.
	 *
	 * @param descriptors          local descriptors of the image, one per row
	 * @param dmatcher             matcher whose first train matrix holds the
	 *                             vocabulary (one visual word per row)
	 * @param feature              output: 1 x clusterCount CV_32FC1 histogram;
	 *                             bin w = (#descriptors matched to word w) /
	 *                             descriptors.rows. Released when no histogram
	 *                             can be built (empty descriptors, missing or
	 *                             untrained matcher).
	 * @param pointIdxsOfClusters  optional output, indexed by descriptor:
	 *                             pointIdxsOfClusters[descriptor_index] = word_index
	 *                             (-1 for descriptors without a match). Cleared
	 *                             when no histogram can be built.
	 */
	void BOW::extract(const cv::Mat& descriptors, cv::Ptr<cv::DescriptorMatcher> dmatcher, cv::Mat& feature, std::vector<int>* pointIdxsOfClusters)const
	{
		if (descriptors.empty() || dmatcher.empty() || dmatcher->empty())
		{
			// Do not leave results of a previous call in the outputs.
			feature.release();
			if (pointIdxsOfClusters)
				pointIdxsOfClusters->clear();
			return;
		}

		// The vocabulary size is taken from the first train matrix only. A
		// matcher trained on several matrices reports per-matrix trainIdx values,
		// which this histogram does not model (guarded by the assert below).
		const int clusterCount = dmatcher->getTrainDescriptors()[0].rows;

		if (pointIdxsOfClusters)
		{
			// Unlike OpenCV's BOWImgDescriptorExtractor (one list of descriptor
			// indices per word), store the word index at each descriptor's index.
			pointIdxsOfClusters->assign(descriptors.rows, -1);
		}

		std::vector<cv::DMatch> matches;
		dmatcher->match(descriptors, matches);

		cv::Mat imgDescriptor(1, clusterCount, CV_32FC1, cv::Scalar::all(0.0));
		float* bins = imgDescriptor.ptr<float>(0);
		for (size_t i = 0; i < matches.size(); i++)
		{
			const int queryIdx = matches[i].queryIdx; // descriptor index
			const int trainIdx = matches[i].trainIdx; // word index
			// Matches are expected one per query row, in row order.
			CV_Assert(queryIdx == static_cast<int>(i));
			// Prevent writing outside the histogram.
			CV_Assert(trainIdx >= 0 && trainIdx < clusterCount);

			bins[trainIdx] += 1.f;
			if (pointIdxsOfClusters)
				(*pointIdxsOfClusters)[queryIdx] = trainIdx;
		}

		// Normalise by the number of descriptors of the image.
		imgDescriptor /= descriptors.rows;
		feature = imgDescriptor;
	}

	/**
	 * Convenience overload that uses the matcher stored in this object.
	 * When no matcher has been set, nothing is done and the outputs are left
	 * untouched.
	 */
	void BOW::extract(const cv::Mat& descriptors, cv::Mat& feature, std::vector<int>* pointIdxOfClusters)const
	{
		if (dmatcher.empty())
			return;
		extract(descriptors, dmatcher, feature, pointIdxOfClusters);
	}

	// Replaces the matcher (and therefore the vocabulary) used by extract().
	void BOW::set_DescriptorMatcher(cv::Ptr<cv::DescriptorMatcher> dmatcher)
	{
		BOW::dmatcher = dmatcher;
	}

	/**
	 * @return number of visual words, i.e. the rows of the matcher's first
	 *         train matrix; 0 when no matcher is set or the matcher holds no
	 *         train descriptors (previously this case indexed an empty vector).
	 */
	int BOW::get_voc_size(void) const
	{
		if (this->dmatcher.empty() || this->dmatcher->empty())
			return 0;
		return static_cast<int>(this->dmatcher->getTrainDescriptors()[0].rows);
	}

	//--------------------- Extractor Manager ----------------------------------

	// Default working resolution 640x480: extract_data() downscales images
	// whose width or height exceeds it.
	ExtractorManager::ExtractorManager()
		: resolution(cv::Point_<int>(640, 480))
	{}

	/**
	 * Scales @p input uniformly (aspect ratio preserved) so that it fits inside
	 * a size.x (width) by size.y (height) box, and writes the result to @p output.
	 *
	 * The scale factor is min(size.x / width, size.y / height). The previous
	 * code chose the factor by comparing width and height, which could leave
	 * one side larger than the box, or even upscale (e.g. 600x500 -> 640x533
	 * for a 640x480 box). Each output side is at least one pixel; an empty
	 * input yields an empty output.
	 */
	void ExtractorManager::resize(const cv::Mat& input, const cv::Point_<int>& size,
			cv::Mat& output) const
	{
		if (input.empty())
		{
			output.release();
			return;
		}

		const float ratioX = static_cast<float>(size.x) / input.cols;
		const float ratioY = static_cast<float>(size.y) / input.rows;
		const float ratio = (ratioX < ratioY) ? ratioX : ratioY;

		int newWidth = static_cast<int>(input.cols * ratio);
		int newHeight = static_cast<int>(input.rows * ratio);
		// cv::resize rejects a zero-sized destination.
		if (newWidth < 1)
			newWidth = 1;
		if (newHeight < 1)
			newHeight = 1;

		// Lanczos interpolation is comparatively expensive.
		cv::Mat resized;
		cv::resize(input, resized, cv::Size(newWidth, newHeight), 0.0, 0.0, cv::INTER_LANCZOS4);
		output = resized;
	}

	/**
	 * Extracts local descriptors from one image. The image is first downscaled
	 * with resize() when its width or height exceeds this->resolution.
	 *
	 * @param image    input image
	 * @param ext      detector/descriptor pair to use
	 * @param feature  output descriptors (see Extractor::extract())
	 * @param mask     optional region of interest, assumed to be in @p image
	 *                 coordinates. When the image is downscaled the rectangle
	 *                 is scaled by the same factors. Previously it was applied
	 *                 unscaled to keypoints of the resized image.
	 */
	void ExtractorManager::extract_data(const cv::Mat& image, Extractor& ext, cv::Mat& feature, const cv::Rect_<int>* mask) const
	{
		cv::Mat tmpImage;
		if (image.cols > this->resolution.x || image.rows > this->resolution.y)
			this->resize(image, this->resolution, tmpImage);
		else
			tmpImage = image;

		// Keypoints are detected on tmpImage, so express the mask in its frame.
		cv::Rect_<int> scaledMask;
		const cv::Rect_<int>* effectiveMask = mask;
		if (mask != NULL && image.cols > 0 && image.rows > 0
				&& (tmpImage.cols != image.cols || tmpImage.rows != image.rows))
		{
			const double sx = static_cast<double>(tmpImage.cols) / image.cols;
			const double sy = static_cast<double>(tmpImage.rows) / image.rows;
			// Scale both corners so the rectangle's extent stays consistent.
			const int x0 = cvRound(mask->x * sx);
			const int y0 = cvRound(mask->y * sy);
			const int x1 = cvRound((mask->x + mask->width) * sx);
			const int y1 = cvRound((mask->y + mask->height) * sy);
			scaledMask = cv::Rect_<int>(x0, y0, x1 - x0, y1 - y0);
			effectiveMask = &scaledMask;
		}

		ext.extract(tmpImage, feature, effectiveMask);
	}

	/**
	 * Computes one bag-of-words histogram per object of @p reader and stacks
	 * them into @p features, which is reset first. Each row comes from an
	 * object whose histogram could be built.
	 *
	 * Objects are processed in parallel (OpenMP, dynamic schedule). Rows are
	 * appended in object-index order, so the result does not depend on thread
	 * scheduling. Objects whose image cannot be loaded, or that yield no
	 * histogram, are reported on std::cerr and skipped. Previously,
	 * load-failure messages were buffered but never printed.
	 *
	 * Objects with flag == 1 use their bounding box as the keypoint mask.
	 */
	void ExtractorManager::extract_data(DataSetReader& reader, Extractor& ext, BOW& bow_ext, cv::Mat& features) const
	{
		features = cv::Mat();
		const unsigned int size = reader.get_number_of_objects();
		// One slot per object: filled in parallel, concatenated in order below.
		std::vector<cv::Mat> bow_features(size);

		#pragma omp parallel for schedule(dynamic)
		for (unsigned int i = 0; i < size; i++)
		{
			cv::Mat image;
			image = reader[i];
			const ppd::Object* object = reader.get(i);

			if (image.data == NULL)
			{
				// Report immediately: `continue` skips the critical section below.
				#pragma omp critical(feature_add)
				std::cerr << "Couldn't load the image: " << object->get_file_name()
						<< ". " << "Object num: " << i << std::endl;
				continue;
			}

			const int threadID = omp_get_thread_num();
			const int numThreads = omp_get_num_threads();

			cv::Mat feature;
			//flag == 1 - the object contains an image of the wanted class
			if (object->get_flag() == 1)
			{
				// NOTE(review): assumes every object with flag == 1 is a BoundedObject.
				cv::Rect_<int> mask = (static_cast<const BoundedObject*>(object)->get_bounding_box());
				this->extract_data(image, ext, feature, &mask);
			}
			else
				this->extract_data(image, ext, feature);

			// Each iteration writes only its own slot, so no locking is needed here.
			bow_ext.extract(feature, bow_features[i]);

			#pragma omp critical(feature_add)
			{
				std::cout << "Thread: " << threadID << " of " << numThreads
						<< " | processing: " << (i + 1) << "/" << size << std::endl;

				if (bow_features[i].cols <= 0 || bow_features[i].rows <= 0)
					std::cerr << "No BOW descriptor extracted from image: "
							<< reader.get_file_name(i) << std::endl;
			}
		}

		for (unsigned int i = 0; i < size; i++)
		{
			if (bow_features[i].cols > 0 && bow_features[i].rows > 0)
				features.push_back(bow_features[i]);
		}
	}

	/**
	 * Extracts local descriptors from every object of @p reader and appends
	 * their rows to @p features. @p features is not cleared first.
	 *
	 * Objects are processed in parallel (OpenMP, dynamic schedule), so the
	 * order in which each image's rows are appended depends on scheduling.
	 * Images that cannot be loaded, or that yield no descriptors, are reported
	 * on std::cerr and skipped. Previously, load-failure messages were
	 * buffered but never printed.
	 *
	 * Objects with flag == 1 use their bounding box as the keypoint mask.
	 */
	void ExtractorManager::extract_data(DataSetReader& reader, Extractor& ext, cv::Mat& features) const
	{
		const unsigned int size = reader.get_number_of_objects();

		#pragma omp parallel for schedule(dynamic)
		for (unsigned int i = 0; i < size; i++)
		{
			cv::Mat image;
			image = reader[i];
			const ppd::Object* object = reader.get(i);

			if (image.data == NULL)
			{
				// Report immediately: `continue` skips the critical section below.
				#pragma omp critical(feature_add)
				std::cerr << "Couldn't load the image: " << object->get_file_name()
						<< ". " << "Object num: " << i << std::endl;
				continue;
			}

			const int threadID = omp_get_thread_num();
			const int numThreads = omp_get_num_threads();

			cv::Mat feature;
			//flag == 1 - the object contains an image of the wanted class
			if (object->get_flag() == 1)
			{
				// NOTE(review): assumes every object with flag == 1 is a BoundedObject.
				cv::Rect_<int> mask = (static_cast<const BoundedObject*>(object)->get_bounding_box());
				this->extract_data(image, ext, feature, &mask);
			}
			else
				this->extract_data(image, ext, feature);

			// features is shared between threads: append under the lock.
			#pragma omp critical(feature_add)
			{
				std::cout << "Thread: " << threadID << " of " << numThreads
						<< " | processing: " << (i + 1) << "/" << size << std::endl;

				if (feature.cols <= 0 || feature.rows <= 0)
					std::cerr << "No descriptor extracted from image: "
							<< reader.get_file_name(i) << std::endl;
				else
					features.push_back(feature);
			}
		}
	}

	/**
	 * Extracts local descriptors from one image and converts them into a
	 * bag-of-words histogram. The image is first downscaled with resize() when
	 * its width or height exceeds this->resolution.
	 *
	 * @param image    input image
	 * @param ext      detector/descriptor pair to use
	 * @param bow_ext  BoW encoder (uses its own matcher)
	 * @param feature  output histogram; released first, so it stays empty when
	 *                 no histogram can be built
	 * @param mask     optional region of interest, assumed to be in @p image
	 *                 coordinates. When the image is downscaled the rectangle
	 *                 is scaled by the same factors. Previously it was applied
	 *                 unscaled to keypoints of the resized image.
	 */
	void ExtractorManager::extract_data(const cv::Mat& image, Extractor& ext, BOW& bow_ext, cv::Mat& feature, const cv::Rect_<int>* mask) const
	{
		feature.release();

		cv::Mat tmpImage;
		if (image.cols > this->resolution.x || image.rows > this->resolution.y)
			this->resize(image, this->resolution, tmpImage);
		else
			tmpImage = image;

		// Keypoints are detected on tmpImage, so express the mask in its frame.
		cv::Rect_<int> scaledMask;
		const cv::Rect_<int>* effectiveMask = mask;
		if (mask != NULL && image.cols > 0 && image.rows > 0
				&& (tmpImage.cols != image.cols || tmpImage.rows != image.rows))
		{
			const double sx = static_cast<double>(tmpImage.cols) / image.cols;
			const double sy = static_cast<double>(tmpImage.rows) / image.rows;
			// Scale both corners so the rectangle's extent stays consistent.
			const int x0 = cvRound(mask->x * sx);
			const int y0 = cvRound(mask->y * sy);
			const int x1 = cvRound((mask->x + mask->width) * sx);
			const int y1 = cvRound((mask->y + mask->height) * sy);
			scaledMask = cv::Rect_<int>(x0, y0, x1 - x0, y1 - y0);
			effectiveMask = &scaledMask;
		}

		cv::Mat tmp_feature;
		ext.extract(tmpImage, tmp_feature, effectiveMask);
		bow_ext.extract(tmp_feature, feature);
	}


	/**
	 * Sets the working resolution. extract_data() downscales images whose
	 * width exceeds @p x or whose height exceeds @p y.
	 */
	void ExtractorManager::set_resolution(const unsigned int x,
			const unsigned int y)
	{
		const int width = static_cast<int>(x);
		const int height = static_cast<int>(y);
		this->resolution = cv::Point_<int>(width, height);
	}

	// Destructor: no explicit cleanup is performed here.
	ExtractorManager::~ExtractorManager() {}

	//--------------------- VOC Methods Definitions ----------------------------

	// Default vocabulary size: 1024 visual words (change with set_voc_size()).
	VOC::VOC()
		: voc_size(1024)
	{}

	void VOC::set_voc_size(unsigned int number_of_words)
	{
		this->voc_size = number_of_words;
	}

	/**
	 * Builds the visual vocabulary and stores the cluster centres in this->voc.
	 * Descriptors are extracted from every object of @p reader and clustered
	 * into voc_size words with k-means: k-means++ seeding, 3 attempts, stopping
	 * after 1000 iterations or when the epsilon 5.0 is reached.
	 *
	 * Reports on std::cerr and leaves this->voc unchanged when no descriptors
	 * were extracted, or when there are fewer descriptors than requested words.
	 * cv::kmeans would otherwise fail its N >= K assertion.
	 */
	void VOC::createVOC(DataSetReader& reader, Extractor& ext, ExtractorManager& extractorManager)
	{
		cv::Mat descriptors;
		extractorManager.extract_data(reader, ext, descriptors);
		if (descriptors.empty())
		{
			std::cerr << "No Descriptors extracted to create VOC, operation failed!" << std::endl;
			return;
		}

		const int wordCount = static_cast<int>(this->voc_size);
		if (wordCount <= 0 || descriptors.rows < wordCount)
		{
			std::cerr << "Not enough descriptors (" << descriptors.rows
					<< ") to create a VOC of " << wordCount
					<< " words, operation failed!" << std::endl;
			return;
		}

		// NOTE(review): an earlier setting was noted as "2000,5,3x" (presumably
		// iterations, epsilon, attempts) — confirm which values are intended.
		cv::BOWKMeansTrainer kmeans_trainer(wordCount,
				cv::TermCriteria(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, 5.0), 3,
				cv::KMEANS_PP_CENTERS);
		this->voc = kmeans_trainer.cluster(descriptors);
	}

	/**
	 * Writes the vocabulary to @p file_name_voc_storage as a sequence named
	 * "VOC" containing the cluster-centre matrix (the layout read_voc()
	 * iterates over). Reports on std::cerr when the file cannot be opened;
	 * previously that failure wrote nothing and went unnoticed.
	 */
	void VOC::saveVOC(const std::string& file_name_voc_storage)
	{
		cv::FileStorage voc_fs(file_name_voc_storage, cv::FileStorage::WRITE);
		if (!voc_fs.isOpened())
		{
			std::cerr << "Couldn't open file to save the VOC: "
					<< file_name_voc_storage << std::endl;
			return;
		}
		voc_fs << "VOC" << "[";
		voc_fs << this->voc;
		voc_fs << "]";
		voc_fs.release();
	}

	// Destructor: no explicit cleanup is performed here.
	VOC::~VOC() {}

	void read_voc(const std::string& voc_path, cv::Mat& voc)
	{
		cv::FileStorage voc_fs(voc_path, cv::FileStorage::READ);
		cv::FileNode voc_node = voc_fs["VOC"];
		cv::FileNodeIterator it = voc_node.begin(), it_end = voc_node.end();
		int index = 0;
		for (; it != it_end; it++, index++)
		{
			cv::Mat clusterCenter;
			(*it) >> clusterCenter;
			voc.push_back(clusterCenter);
		}

	}

}

