#include "TrackRepairer.h"

#include <algorithm>
#include <vector>

namespace Motion {


void NoneTrackRepairer::repair(TrackContainer & currentTracks, TrackContainer & lostTracks, TrackContainer & newTracks, const Mat & gray, const Mat & prevGray) {
	lostTracks.clear();
}


/**
 * Creates a repairer that re-matches lost tracks with @p trackMatcher and
 * propagates the ones that stay unmatched.
 *
 * Every argument is routed through its setter, so validation (and any CV_Error
 * raised by it) happens in this order: matcher, minimal length, dead length,
 * median size.
 */
PropagationTrackRepairer::PropagationTrackRepairer(TrackMatcherPtr trackMatcher, unsigned int minLength, unsigned int deadLength, unsigned int medianSize) {
	this->setMatcher(trackMatcher);
	this->setMinLength(minLength);
	this->setDeadLength(deadLength);
	this->setMedianSize(medianSize);
}


/**
 * Tries to re-attach lost tracks to newly detected tracks and extrapolates the
 * position of the lost tracks that remain unmatched.
 *
 * @param currentTracks tracks of the current frame; lost tracks are added to it via
 *                      addContent(lostTracks, 5, true) and it is then used as the
 *                      neighbourhood for estimateTranslation()
 * @param lostTracks    tracks lost so far; on return, presumably only the unmatched,
 *                      propagated tracks that were neither too short nor dead remain
 * @param newTracks     newly detected tracks the lost ones are matched against
 * @param gray          current frame image, forwarded to the matcher
 * @param prevGray      previous frame image, forwarded to the matcher
 *
 * NOTE(review): TrackContainer::removeContent is not visible here. From the way its
 * result is used below it appears to erase the tracks satisfying the predicate from
 * the container and return the erased ones — confirm in TrackContainer.
 */
void PropagationTrackRepairer::repair(TrackContainer & currentTracks, TrackContainer & lostTracks, TrackContainer & newTracks, const Mat & gray, const Mat & prevGray) {
	// only tracks with minimal length will be repaired
	// (the shorter ones are taken out of lostTracks; lostMinLength is never used, so they are dropped)
	TrackContainer lostMinLength = lostTracks.removeContent(TrackContainer::ShorterThan(minLength));

	// some tracks can be already matched (propagated from previous frame)
	matcher->match(lostTracks, newTracks, gray, prevGray);
	// NOTE(review): the meaning of the literal arguments 5 and true is defined by
	// TrackContainer::addContent (presumably "only matched tracks") — confirm and
	// consider replacing them with named constants.
	currentTracks.addContent(lostTracks, 5, true);
	// keep only the lost tracks the matcher did not match: the container returned by
	// removeContent replaces lostTracks entirely
	lostTracks = lostTracks.removeContent(TrackContainer::NotMatched());

	// propagate lost tracks to next frame
	for(unsigned int i = 0; i < lostTracks.size(); i++) {
		Point2f position = lostTracks[i]->getKeypoint().pt;
		Point2f translation(0, 0);

		// move the track by the translation estimated from currentTracks; when no
		// estimate is available fall back to the argument-less propagate()
		if(estimateTranslation(currentTracks, position, translation)) {
		// alternative: extrapolate from the track's own last displacement
		//if(estimateTranslation2(*lostTracks[i], translation)) {
			lostTracks[i]->propagate(position+translation);
		}
		else {
			lostTracks[i]->propagate();
		}
	}

	// remove dead tracks
	// (tracks matching UpdatedBefore(deadLength) are removed; lostUpdatedBefore is never used, so they are dropped)
	TrackContainer lostUpdatedBefore = lostTracks.removeContent(TrackContainer::UpdatedBefore(deadLength));
}


// Replaces the matcher that repair() uses to re-associate lost tracks with new ones.
// No validation is performed on the pointer.
void PropagationTrackRepairer::setMatcher(TrackMatcherPtr newMatcher) {
	matcher = newMatcher;
}


void PropagationTrackRepairer::setMinLength(unsigned int minLength) {
	if(minLength == 0) {
		CV_Error(CV_StsBadArg, "Minimal track length to repair must be > 0");
	}
	
	this->minLength = minLength;
}


void PropagationTrackRepairer::setDeadLength(unsigned int deadLength) {
	if(deadLength == 0) {
		CV_Error(CV_StsBadArg, "Dead track length must be > 0");
	}

	this->deadLength = deadLength;
}


/**
 * Sets how many neighbouring tracks are used for the median translation estimate.
 *
 * @param medianSize number of neighbours; must be in [1, MEDIAN_MAX]
 *                   (MEDIAN_MAX presumably bounds the member buffers arrPts/arrDists — confirm in the header)
 * @throws cv::Exception (CV_StsBadArg) when the value is 0 or exceeds MEDIAN_MAX
 */
void PropagationTrackRepairer::setMedianSize(unsigned int medianSize) {
	// A zero-sized median window has no middle element, so it can never yield an estimate;
	// reject it like the other setters reject zero.
	if(medianSize == 0) {
		CV_Error(CV_StsBadArg, "Median size must be > 0");
	}
	if(medianSize > MEDIAN_MAX) {
		CV_Error(CV_StsBadArg, "Median size exceeds size limit.");
	}

	this->medianSize = medianSize;
}


/**
 * Inserts (@p dist, @p pt) into a bounded buffer kept sorted by distance in
 * descending order, retaining the @p size largest distances seen so far.
 *
 * arrDist[0..size) and arrPts[0..size) are parallel arrays: every shift of a
 * distance moves its point along with it. An entry whose distance is not larger
 * than the current smallest kept distance (arrDist[size-1]) is ignored.
 *
 * @param arrPts  points paired with arrDist, at least @p size elements
 * @param arrDist distances in descending order, at least @p size elements
 * @param size    number of slots in use; 0 makes the call a no-op
 * @param dist    distance of the new entry
 * @param pt      point of the new entry
 */
void PropagationTrackRepairer::insertToArray(Point2f arrPts[], float arrDist[], unsigned int size, float dist, Point2f pt) {
	// Nothing to fill, or the new entry would fall off the end of the buffer.
	if(size == 0 || dist <= arrDist[size-1]) {
		return;
	}

	// Shift smaller-or-equal entries one slot towards the end (dropping the last one),
	// keeping each point together with its distance.
	unsigned int i = size - 1;
	while(i > 0 && dist >= arrDist[i-1]) {
		arrDist[i] = arrDist[i-1];
		arrPts[i] = arrPts[i-1];
		i--;
	}

	arrDist[i] = dist;
	arrPts[i] = pt;
}


bool PropagationTrackRepairer::estimateTranslation(const TrackContainer & tracks, Point2f pos, Point2f & trans) {
	for(unsigned int i = 0; i < medianSize; i++) {
		arrDists[i] = 0;
		arrPts[i] = Point2f(0, 0);
	}
	
	unsigned int count = 0;
	for(unsigned int i = 0; i < tracks.size(); i++) {
		if(tracks[i]->getLength() < 2) {
			continue;
		}

		Point2f ptSrc = tracks[i]->getKeypoint(1).pt;
		Point2f ptDst = tracks[i]->getKeypoint(0).pt;
		
		Point2f trans = ptDst-ptSrc;
		float dist = sqrt(pow(pos.x-ptSrc.x, 2) + pow(pos.y-ptSrc.y, 2));

		insertToArray(arrPts, arrDists, MIN(count+1, medianSize), dist, trans);
		count++;
	}

	if(count < medianSize) {
		return false;
	}

	vector<float> trX;
	vector<float> trY;
	for(unsigned int i = 0; i < medianSize; i++) {
		trX.push_back(arrPts[i].x);
		trY.push_back(arrPts[i].y);
	}

	sort(trX.begin(), trX.end());
	sort(trY.begin(), trY.end());
	
	trans.x = trX[medianSize/2];
	trans.y = trY[medianSize/2];
	return true;
}


/**
 * Extrapolates a track from its own history: the translation is the displacement
 * between its two most recent keypoints (keypoint 1 -> keypoint 0).
 *
 * @param track track to extrapolate
 * @param trans receives the displacement; untouched when false is returned
 * @return true if the track has at least two keypoints
 */
bool PropagationTrackRepairer::estimateTranslation2(const Track & track, Point2f & trans) {
	const bool hasHistory = track.getLength() >= 2;
	if(hasHistory) {
		trans = track.getKeypoint(0).pt - track.getKeypoint(1).pt;
	}
	return hasHistory;
}


}
