function random_forest_demo
% A basic example of training random forest models with Random Forest Toolbox.
%
% This demo shows how to train a simple classifier with the toolbox. The
% example is done on classification of 2D points.
%
% The demo loads X (one sample per row, the first two columns are plotted
% as 2D coordinates) and Y (class labels) from private/data.mat located next
% to this file. It trains a forest on 1000 randomly chosen samples,
% visualizes the per-class predictions on a regular grid, and prints the
% confusion matrix and overall accuracy on the remaining samples.
%
% See also: random_forest_train, random_forest_predict
%
% Author: Roman Juranek <ijuranek@fit.vutbr.cz>, FIT BUT, Brno

close all; clc;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Load data

% Resolve the data file relative to this file (not the current folder) and
% load into a struct so no variables are created implicitly.
data_file = fullfile(fileparts(mfilename('fullpath')), 'private', 'data.mat');
data = load(data_file);
X = data.X;
Y = data.Y;

% Random training subset; all remaining samples form the test set.
tr = randsample(numel(Y), 1000);
ts = setdiff(1:numel(Y), tr);

clf; hold on;
classes = unique(Y)';
colors = hsv(length(classes));
for i = 1:numel(classes)
    % Training indices that belong to class i.
    ind = tr(Y(tr) == classes(i));
    plot(X(ind,1), X(ind,2), '.', 'color', colors(i,:));
end
hold off;
title('Training data');

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Train Random Forest classifier using RFToolbox

% Setup training options - see 'help random_forest_train' for reference
opts = random_forest_train;
opts.ntrees = 10;
opts.min_samples = 4;
opts.subsample = 500;
opts.max_depth = 5;
opts.splits = 500;
opts.verbose = 1;

disp('Training Random Forest');

%profile on
RF = random_forest_train(X', Y, [], tr, opts);
%profile off
%profile viewer

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Visualize predictions

% Evaluate the forest on a regular 128x128 grid over [-10, 10]^2.
% (Named grid_range to avoid shadowing the built-in RANGE.)
grid_range = linspace(-10, 10, 128);
[gx, gy] = meshgrid(grid_range, grid_range);
X_vis = [gx(:) gy(:)];
p_rf = random_forest_predict(RF, X_vis');
vis_results(X(ts,:), p_rf, grid_range, colors);
title('Random Forest classifier');

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Test the classifier

q_rf = random_forest_predict(RF, X(ts,:)');
[~, idx_rf] = max(q_rf, [], 2);
% MAX returns column indices of q_rf, not class labels. Map them back to
% labels before comparing with Y. NOTE(review): this assumes column j of
% the prediction corresponds to classes(j) -- the same assumption the
% visualization makes when it paints column j with colors(j,:).
lbl_rf = classes(idx_rf);
y_true = Y(ts);
% Fix the row/column order to all known classes so the matrix layout does
% not depend on which classes happen to appear in the test set.
C_rf = confusionmat(y_true(:), lbl_rf(:), 'order', classes);
disp('Confusion matrix');
disp(C_rf);
% Overall accuracy: correctly classified samples / all test samples.
disp(['Accuracy:  ' num2str(sum(diag(C_rf))/sum(C_rf(:)))]);
disp(' ');

end % random_forest_demo


function f = vis_results(X, p, range, colors)
% VIS_RESULTS Render class probabilities on a grid and overlay sample points.
%
%   f = vis_results(X, p, range, colors)
%
%   X      - points to overlay; columns 1 and 2 are drawn as white dots.
%   p      - matrix with one row per grid point and one column per class;
%            each column is reshaped to length(range) x length(range).
%   range  - grid coordinates used for both image axes.
%   colors - one RGB row per class; column k of p is painted with colors(k,:).
%
%   Returns the handle of the newly created figure.

n = length(range);
img = zeros(n, n, 3);

% Accumulate every class' probability map, weighted by its color, channel
% by channel.
for k = 1:size(p, 2)
    prob = reshape(p(:,k), [n n]);
    for c = 1:3
        img(:,:,c) = img(:,:,c) + colors(k,c) * prob;
    end
end

% Keep the truecolor image strictly inside (0, 1).
img(img >= 1) = 1 - eps;
img(img <= 0) = eps;

% HOLD is switched on before IMAGESC so the axes keep their current
% (normal) Y direction instead of being reset by IMAGESC.
f = figure;
hold on;
imagesc(range, range, img);
hold on;
plot(X(:,1), X(:,2), 'w.');

lo = min(range);
hi = max(range);
axis([lo hi lo hi]);
end

% END