#include "TrackMatcher.h"

#include <opencv2/calib3d/calib3d.hpp>

namespace Motion {


// Default constructor. TrackMatcher holds no state of its own that needs
// initialization here.
TrackMatcher::TrackMatcher() {
	// intentionally empty
}


// Returns the `matched` flag stored on the given track.
// Matchers in this file set it to true/false (non-zero means matched).
int TrackMatcher::getTrackMatched(Track & track) {
	const int matchedFlag = track.matched;
	return matchedFlag;
}


// Writes the given value into the track's `matched` flag.
// Callers in this file pass true (matched) or false (not matched).
void TrackMatcher::setTrackMatched(Track & track, int matched) {
	Track & target = track;
	target.matched = matched;
}


// Configures the matcher through the validating setters, so the constructor
// applies exactly the same argument checks (CV_Error on invalid values).
//   maxDist         - maximal position distance; -1 disables the check
//   minNcc          - minimal NCC score in 0..1; 0 disables the NCC test
//   ransacThreshold - RANSAC threshold; -1 disables RANSAC outlier rejection
BaseTrackMatcher::BaseTrackMatcher(float maxDist, float minNcc, float ransacThreshold) : TrackMatcher() {
	this->setMaxDistance(maxDist);
	this->setMinNcc(minNcc);
	this->setRansacThreshold(ransacThreshold);
}


// Sets the maximal allowed distance between a track and its candidate match.
// -1 is a sentinel that disables the distance check; any other negative
// value is rejected with CV_StsBadArg.
void BaseTrackMatcher::setMaxDistance(float maxDist) {
	const bool isDisabledSentinel = (maxDist == -1);
	if(maxDist < 0 && !isDisabledSentinel) {
		CV_Error(CV_StsBadArg, "Maximal feature distance must be >= 0");
	}
	this->maxDist = maxDist;
}


// Sets the minimal normalized cross-correlation a match must reach.
// Values outside 0..1 are rejected with CV_StsBadArg; a value of 0 turns
// the NCC test off entirely (useNcc == false).
void BaseTrackMatcher::setMinNcc(float minNcc) {
	const bool outOfRange = (minNcc < 0) || (1 < minNcc);
	if(outOfRange) {
		CV_Error(CV_StsBadArg, "Minimal NCC value must be in 0..1");
	}

	this->useNcc = minNcc > 0;
	this->minNcc = minNcc;
}


// Sets the RANSAC threshold used by ransacOutliers().
// -1 disables RANSAC (useRansac == false); any other negative value is
// rejected with CV_StsBadArg.
void BaseTrackMatcher::setRansacThreshold(float threshold) {
	const bool isDisabledSentinel = (threshold == -1);
	if(threshold < 0 && !isDisabledSentinel) {
		CV_Error(CV_StsBadArg, "RANSAC threshold must be >= 0");
	}

	this->useRansac = !isDisabledSentinel;
	this->ransacThreshold = threshold;
}


// Rejects geometrically inconsistent matches with RANSAC.
//
// For every matched track with at least two keypoints, the previous position
// (getKeypoint(1)) and the current position (getKeypoint(0)) form one
// correspondence. A homography is fitted to all correspondences with RANSAC
// (ransacThreshold is the reprojection threshold). Tracks whose correspondence
// is classified as an outlier are marked as not matched.
//
// Does nothing when RANSAC is disabled, when fewer than 4 correspondences
// exist (the minimum for a homography), or when estimation fails.
void BaseTrackMatcher::ransacOutliers(TrackContainer & currentTracks) {
	if(!useRansac) {
		return;
	}

	vector<TrackPtr> ptrs;
	vector<Point2f> srcPoints;
	vector<Point2f> dstPoints;

	// prepare point correspondences (previous -> current) for RANSAC
	for(unsigned int i = 0; i < currentTracks.size(); i++) {
		Track & track = *currentTracks.item(i);
		if(getTrackMatched(track) && track.getLength() >= 2) {
			ptrs.push_back(currentTracks.item(i));
			srcPoints.push_back(track.getKeypoint(1).pt);
			dstPoints.push_back(track.getKeypoint(0).pt);
		}
	}

	if(srcPoints.size() < 4) {
		return;
	}

	// We don't need the homography itself, only the inlier/outlier split.
	// The mask must be vector<uchar>: vector<bool> is bit-packed and cannot be
	// used as an OpenCV OutputArray. findHomography sets non-zero for INLIERS.
	vector<uchar> inlierMask;
	Mat homography = cv::findHomography(srcPoints, dstPoints, cv::RANSAC, ransacThreshold, inlierMask);

	// Newer OpenCV versions return an empty matrix when estimation fails; in
	// that case keep all matches rather than discarding every one of them.
	if(homography.empty() || inlierMask.size() != ptrs.size()) {
		return;
	}

	// set outliers as not matched
	for(unsigned int i = 0; i < inlierMask.size(); i++) {
		if(!inlierMask[i]) {
			setTrackMatched(*ptrs[i], false);
		}
	}
}


// Forwards all thresholds to BaseTrackMatcher, which validates them.
// The optical-flow matcher adds no state of its own.
OpticalFlowMatcher::OpticalFlowMatcher(float maxDist, float minNcc, float ransacThreshold) : BaseTrackMatcher(maxDist, minNcc, ransacThreshold) {
	// intentionally empty
}


// Matches existing tracks into the current frame using pyramidal Lucas-Kanade
// optical flow.
//
// Each track's latest keypoint is tracked from prevFrame to currentFrame.
// A successfully tracked point is accepted when:
//   - the NCC test passes (only if useNcc), and
//   - the distance check passes (only if maxDist >= 0).
// Accepted tracks are extended with the new keypoint and marked matched; all
// others are marked not matched. Detected tracks only supply the timestamp
// (taken from the first one). Finally, RANSAC outlier rejection is applied.
void OpticalFlowMatcher::match(TrackContainer & currentTracks, const TrackContainer & detectedTracks, const Mat & currentFrame, const Mat & prevFrame) {
	const bool nothingToMatch = detectedTracks.size() == 0 || currentTracks.size() == 0;
	if(nothingToMatch) {
		return;
	}

	const int64 frameTime = detectedTracks.item(0)->getLastTimeStamp();
	const unsigned int trackCount = currentTracks.size();

	// Starting positions: the most recent keypoint of every track.
	vector<Point2f> previousPositions;
	previousPositions.reserve(trackCount);
	for(unsigned int t = 0; t < trackCount; ++t) {
		previousPositions.push_back(currentTracks.item(t)->getKeypoint().pt);
	}

	// 31x31 search window, maxLevel 3, stop after 20 iterations or eps 0.03, flags 0.
	vector<Point2f> trackedPositions;
	vector<uchar> found;
	vector<float> trackingError;
	const TermCriteria criteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS, 20, 0.03);
	calcOpticalFlowPyrLK(prevFrame, currentFrame, previousPositions, trackedPositions, found, trackingError, Size(31, 31), 3, criteria, 0);

	for(unsigned int t = 0; t < trackCount; ++t) {
		Track & track = *currentTracks.item(t);
		bool accepted = false;

		if(found[t]) {
			KeyPoint moved(trackedPositions[t].x, trackedPositions[t].y, 15);
			Track candidate(moved, track.getDescriptor(), frameTime);

			// NCC first; the distance test only runs when NCC passed.
			const bool nccOk = !useNcc || track.ncc(candidate, prevFrame, currentFrame) >= minNcc;
			accepted = nccOk && (maxDist < 0 || track.distancePosition(candidate) <= maxDist);

			if(accepted) {
				track.update(moved, track.getDescriptor(), frameTime);
			}
		}

		setTrackMatched(track, accepted);
	}

	// remove incorrect matches by RANSAC, if enabled
	ransacOutliers(currentTracks);
}


}
