#include "PointObserver.h"

#include <algorithm>


/**
 * Build an observer with empty tracking state.
 *
 * Creates the SURF detector/extractor and FLANN matcher, allocates the camera
 * matrices, and opens "output.txt" for writing.
 *
 * @param count stored in max_points
 */
PointObserver::PointObserver(int count)
{
	// scalar bookkeeping: nothing tracked yet, IDs start at 1, frame index starts at 1
	max_points = count;
	counter = 0;
	IDcount = 1;
	frame = 1;
	init = false;

	tracker = FlowTracker();

	// feature pipeline: SURF keypoints + SURF descriptors, matched with FLANN
	detector = FeatureDetector::create( "SURF" );
	extractor = DescriptorExtractor::create( "SURF" );
	matcher = DescriptorMatcher::create( "FlannBased" );

	// colour camera: distortion vector (5x1) and 3x3 intrinsics (allocated, not filled),
	// plus the pose vectors refined by solvePnP, starting at zero
	cam_D.create(5, 1, CV_64F);
	cam_A.create(3, 3, CV_64F);
	cam_R = Mat::zeros(3, 1, CV_64F);
	cam_T = Mat::zeros(3, 1, CV_64F);
	cam_T2 = Mat::zeros(3, 1, CV_64F);

	// presumably the depth camera's distortion/intrinsics — TODO confirm (unused in this file)
	dep_D.create(5, 1, CV_64F);
	dep_A.create(3, 3, CV_64F);

	// 3x3 epipolar matrices (F is written by computeF)
	F.create(3, 3, CV_64F);
	E.create(3, 3, CV_64F);

	write.open ("output.txt");
}

int PointObserver::getDepth(int x, int y, const Mat& depth, Point3f &p) 
{
	if(y > 0 && y < depth.rows && x > 0 && x < depth.cols) 
	{
		//read depth from video
		float rawdepth = (depth.at<Vec3b>( y, x )[0] + 
						depth.at<Vec3b>( y, x )[1] * 256) / 100.0;

		if(rawdepth > 1)
		{
			//init 3D positions
			//compute 3D position of cetrain point
			p.x = (x - cam_A.at<double>(0, 2)) * rawdepth / cam_A.at<double>(0, 0);
			p.y = (y - cam_A.at<double>(1, 2)) * rawdepth / cam_A.at<double>(1, 1);
			p.z = rawdepth;

			return 1;
		} else {
			return 0;
		}
	} else
		return 0;
}

void PointObserver::trackPoints(const Mat& grey, const Mat& depth, int keyframe) 
{
	//initialize if not
	if(!init) 
	{
		//get keypoints
		findKeyPoints(grey);
		printf("loaded %d features\n", keys.size());
				
		//initialize 3D positions
		keys3D.resize(keys.size());
		state3D.resize(keys.size());
		keyID.resize(keys.size());

		for(int a = 0; a < keys.size(); a++) 
		{
			state3D[a] = getDepth((int)keys[a].pt.x, (int)keys[a].pt.y, depth, keys3D[a]);

			//associate ID
			keyID[a] = IDcount;
			IDcount ++;

			//print to file
			//write << keys3D[a].x << " " << keys3D[a].y << " " << keys3D[a].z << " " << frame << " " 
			//	<< (int)keys[a].pt.x << " " << (int)keys[a].pt.y << "\n";
		}

		//elliminate lost points
		for(int a = keys.size()-1; a >= 0; a--) {
			//elliminate lost points
			if(state3D[a] == 0) 
			{
				//elliminate point
				keys.erase(keys.begin() + a);
				keys3D.erase(keys3D.begin() + a);
				state3D.erase(state3D.begin() + a);
				keyID.erase(keyID.begin() + a);
			}
		}

		//initialize flow tracker
		tracker.track(grey, keys, true);
	}
	else {

		if(counter != keyframe)
		{
			//copy last keys
			keys2.resize(keys.size());
			copy(keys.begin(), keys.end(), keys2.begin());
			//track points
			tracker.track(grey, keys, true);

			//check 3D correspondences
			int b = keys.size();
			for(int a = 0; a < b; a++) 
			{
				if(tracker.state[a] == 0 || keys[a].pt.x < 30 || keys[a].pt.x > grey.cols - 30 || 
					keys[a].pt.y < 30 || keys[a].pt.y > grey.rows - 30)
					state3D[a] = 0;
				if(abs(keys[a].pt.x - keys2[a].pt.x) + abs(keys[a].pt.y - keys2[a].pt.y) > 30)
					state3D[a] = 0;
			}
			//elliminate lost points
			for(int a = keys.size()-1; a >= 0; a--) {
				//elliminate lost points
				if(state3D[a] == 0) 
				{
					//elliminate point
					keys.erase(keys.begin() + a);
					keys2.erase(keys2.begin() + a);
					keys3D.erase(keys3D.begin() + a);
					state3D.erase(state3D.begin() + a);
					keyID.erase(keyID.begin() + a);
				}
			}

			counter ++;
		} else {
			//copy last keys
			keys2.resize(keys.size());
			copy(keys.begin(), keys.end(), keys2.begin());

			matchKeyPoints(grey);
			
			counter = 0;
		}
	}

	//compute 3D cam. position
	p2D.clear();
	p3D.clear();
	for(int a = 0; a < keys3D.size(); a++)
	{
		if(state3D[a] == 1) 
		{
			//add point
			p2D.push_back(keys[a].pt);
			p3D.push_back(keys3D[a]);
		}
	}

	//compute PnP
	solvePnP(Mat(p3D), Mat(p2D), cam_A, cam_D, cam_R, cam_T, true);
	//compute real camera position
	Mat R(3, 3, CV_64F);
	Rodrigues(cam_R, R);
	invert(R, R, 0);
	//multiply
	cam_T2 = R * (-cam_T);

	if(keys.size() < 800) 
	{
		//add new points to our structures
		vector<KeyPoint> newkeys;
		detector->detect( grey, newkeys );
		printf("adding new points %d\n", newkeys.size());

		int added = 0;
		int c = 0;
		while(added < 200 && c < newkeys.size())
		{
			//check if the added point isnt same as other poitns
			bool other = true;
			for(int a = 0; a < keys.size(); a++)
			{
				if(abs(keys[a].pt.x - newkeys[c].pt.x) + abs(keys[a].pt.y - newkeys[c].pt.y) < 4) 
				{
					other = false;
					break;
				}
			}
			//alright?
			if(other) 
			{
				//adding
				keys.push_back(newkeys[c]);
				//add state
				Point3f p;
				state3D.push_back(getDepth((int)newkeys[c].pt.x, (int)newkeys[c].pt.y, depth,p));
				//add 3D position
				Mat X(3, 1, CV_64F);
				Mat X2(3, 1, CV_64F);
				X.at<double>(0, 0) = p.x;
				X.at<double>(1, 0) = p.y;
				X.at<double>(2, 0) = p.z;
				Mat R(3, 3, CV_64F);
				Rodrigues(cam_R, R);
				invert(R, R, 0);
				//multiply
				X2 = R * X + cam_T2;

				p.x = X2.at<double>(0, 0);
				p.y = X2.at<double>(1, 0);
				p.z = X2.at<double>(2, 0);
				keys3D.push_back(p);

				added++;

				//associate ID
				keyID.push_back(IDcount);
				IDcount ++;

				//write to file
				//write << p.x << " " << p.y << " " << p.z << " " << frame << " " 
				//	<< (int)newkeys[c].pt.x << " " << (int)newkeys[c].pt.y << "\n";
			}

			c++;
		} // while
	}
	
}

/**
 * Detect keypoints on `grey` into `keys`, describe them into `descs`,
 * and mark the observer as initialised.
 *
 * Per the OpenCV docs, compute() removes keypoints it cannot describe, so
 * afterwards keys[i] corresponds to descs.row(i).
 */
void PointObserver::findKeyPoints(const Mat& grey) 
{
	detector->detect(grey, keys);
	extractor->compute(grey, keys, descs);
	init = true;
}

void PointObserver::matchKeyPoints(const Mat& grey) 
{
	//FeatureDetector
	detector->detect( grey, keys);
	//descriptors
	extractor->compute(grey, keys, descs2);

	//match
	matcher->match(descs, descs2, matches);

	Mat draw;
	grey.copyTo(draw);
	//drawMatches( tracker.prev_grey, keys2, grey, keys, matches, draw );
	for(int a = 0; a < keys.size(); a++)
		circle(draw, keys[a].pt, 1, CV_RGB(255, 255, 255), 3, 8, 0);
	imshow("win", draw);
	waitKey(0);

	vector<KeyPoint> k;
	Mat d(matches.size(), descs2.cols, descs2.type()); 
	tracker.state.resize(matches.size());

	//order the keys to correspond with keys2
	for(int a = 0; a < matches.size(); a++) 
	{
		k.push_back(KeyPoint(keys[matches[a].trainIdx]));
		descs2.row(matches[a].trainIdx).copyTo(d.row(a));
		//determine validity
		if((abs(keys2[a].pt.x - keys[matches[a].trainIdx].pt.x) + abs(keys2[a].pt.y - keys[matches[a].trainIdx].pt.y)) < 25)
			tracker.state[a] = 1;
		else
			tracker.state[a] = 0;
	}
	//copy back
	keys.clear();
	keys = k;
	descs.release();
	descs = d;
}

/**
 * Make the second keypoint set (keys2 / descs2) the current one (keys / descs).
 * descs receives a deep copy, so it does not share data with descs2.
 */
void PointObserver::switchKeyPoints()
{
	keys = keys2;
	// Mat assignment already releases the previous buffer. The former explicit
	// descs.~Mat() ended the object's lifetime before reuse (undefined behaviour)
	// and led to a second destruction later.
	descs = descs2.clone();
}

void PointObserver::computeF()
{
	vector<Point2f> pt1, pt2;
	//filter good matching points
	for(int a = 0; a < matches.size(); a++) 
	{
		int i1 = matches[a].queryIdx;
		int i2 = matches[a].trainIdx;

		if(matches[a].distance < 0.15) 
		{
			//setup the 2D correspondences
			pt1.push_back(keys[1].pt);
			pt2.push_back(keys2[2].pt);
		}
	}

	Mat m1(pt1);
	Mat m2(pt2);

	//compute F matrix
	F = findFundamentalMat(m1, m2, FM_RANSAC, 3.0, 0.99);
}

void PointObserver::loadMatrix(char *name, Mat& mat) 
{
	CvMat *c = cvCreateMat(3, 3, CV_64F);
	c = (CvMat *)cvLoad(name);

	if(c->rows != mat.rows || c->cols != mat.cols)
		printf("Bad matrix size\n");
	else
		mat = c;
}


// Intentionally empty: no explicit cleanup here; any resources are left to the
// members' own destructors.
PointObserver::~PointObserver() {}
