function [T,IND1] = random_forest_train(X, Y, varargin)
% Train a random forest model T using samples X and labels Y.
%
% Usage:
%
%   T = RANDOM_FOREST_TRAIN(X, Y)
%   T = RANDOM_FOREST_TRAIN(X, Y, OPTS)
%   T = RANDOM_FOREST_TRAIN(X, Y, W, OPTS)
%   T = RANDOM_FOREST_TRAIN(X, Y, W, IND0, OPTS)
%   [T,IND1] = RANDOM_FOREST_TRAIN(...)
%   OPTS = RANDOM_FOREST_TRAIN
%
% Inputs:
%   X     M1 x M2 ... MD x N matrix with N training instances (the last
%         dimension indexes the samples), or a struct array with N elements.
%   Y     N x 1 vector of labels of training instances
%   W     N x 1 vector of weights. The weights are set to 1 when [] is passed
%   IND0  Vector of indices of training instances (finite integers between
%         1 and N). Defaults to all samples.
%   OPTS  Options structure. OPTS = [] or fields missing from OPTS take the
%         default values shown in brackets.
%    .verbose     [false] When set to true, function will write info about
%                 the training process (including per-tree progress and
%                 the total training time).
%    .compact     [true] Pass every trained tree through compact_tree.
%    .classes     [] Classes to take into account. When empty, classes are
%                 detected automatically from Y(IND0).
%    .subsample   [0] Number of samples drawn (without replacement) from
%                 IND0 for every tree. 0 uses all of IND0.
%    .ntrees      [5] Number of trees to train.
%    .min_samples [20] Nodes with fewer samples are not split.
%    .max_depth   [8] Maximal tree depth.
%    .splits      [1000] Number of valid candidate splits evaluated in each
%                 node.
%    .split_gen   [{'default',{}}] Split generation functions.
%    .split_eval  [{'default_eval',{}}] Split evaluation functions.
%    .impurity    [{'infogain',{}}] Split error calculation functions.
%    .user_fnc    [{}] User functions executed after a node is trained.
%    .predictor   [{'class_prob',{}}] Function to calculate predictor
%                 in every node.
%    .remove_bad  [true] Not read by this file; candidate splits leaving too
%                 few samples in either branch are always discarded.
%    .sampling    ['uus'] Not read by this file.
%    .batch       [inf] Not read by this file.
%
%   For every tree, the weights of the selected samples are normalized so
%   that the weights of each class in OPTS.classes sum to 1/numel(classes).
%
% Outputs:
%   T     Trained model which can be passed to random_forest_predict.
%         When called without arguments, T is the default OPTS structure.
%   IND1  1 x NTREES cell array of vectors denoting which samples were
%         actually used for training the individual trees.
%
% Functions and Function Parameters Passed as Arguments
% -----------------------------------------------------
% In the options structure, the user can specify multiple functions that
% parametrize the training process. The functions are passed in the
% following format.
%
%     func = {FUNCTION1, {PARAMS1}, FUNCTION2, {PARAMS2}};
%
% The FUNCTION can be a string specifying the function name, or an
% anonymous function suitable for the feval function. PARAMS is a list of
% additional arguments for the function. Example of an impurity function
% specification:
%
%     opts.impurity = {'infogain', {}};
%
% This will use the infogain function with no additional arguments for
% split error calculation. Another example with an anonymous function:
%
%     opts.split_eval = {@(x,i,f)x(i,:)*f'>0,{}};
%
%
% Split Generators and Evaluators
% -------------------------------
% The user can specify one or more split generation and split evaluation
% functions in the options structure. Generators and evaluators are paired
% by position, and one pair is chosen at random in every node.
% Representation of samples and splits is LEFT COMPLETELY to the user. One
% must only make sure that the split generators produce parameters that can
% be used by the evaluator functions, and also that generators and
% evaluators understand the sample format. Formats of the functions are
% as follows.
%
%     [F,SPLITS,ERR] = split_gen(N,X,Y,W,IND,CLASSES,...);
%
% The split generator outputs N split function parameters (one split per
% row of F). X, Y, W, IND denote the samples to take into account.
% Optionally, it can output a binary matrix SPLITS (one row per element of
% IND, one column per split) with the result of every split function on
% every sample, and a vector ERR of errors for every split. If SPLITS or
% ERR is left empty, it is obtained afterwards using the split evaluation
% function and the impurity functions.
%
%     SPLITS = split_eval(X, IND, F, ...);
%
% Impurity functions
% ------------------
%     ERR = impurity(Y, W, IND, CLASSES, SPLITS, ...);
%
% Returns one error per column of SPLITS; the split with the lowest error
% is selected. When several impurity functions are given, one of them is
% chosen at random for every batch of candidate splits.
%
% Predictors
% ----------
%     P = predictor(Y, W, IND, CLASSES, ...);
%
% Evaluated in every node; the result is stored in the node field p. Only
% the first predictor in OPTS.predictor is used.
%
% User Functions
% --------------
%     U = user_fnc(X, Y, W, IND, CLASSES, ...);
%
% Every user function is evaluated in every node; the results are stored
% in the node cell array udata in the order given by OPTS.user_fnc.
%
% See also: random_planes, lin_features_eval, infogain, weighted_sampling,
% random_forest_update, random_forest_predict, random_forest_demo, feval
%
% Author:   Roman Juranek <ijuranek@fit.vutbr.cz>


%%%
%%% Process arguments and set default values
%%%
default_options = struct(...
    'compact', true, ...
    'verbose', false, ...
    'ntrees', 5, ...
    'subsample', 0, ...
    'sampling', 'uus', ...
    'min_samples', 20, ...
    'max_depth', 8, ...
    'splits', 1000, ...
    'batch', inf, ...
    'split_gen', { {'default',{}} }, ...
    'split_eval', { {'default_eval',{}} }, ...
    'remove_bad', true, ...
    'user_fnc', { {} }, ...
    'impurity', { {'infogain',{}} }, ...
    'predictor', { {'class_prob',{}} }, ...
    'classes', [] ...
    );

if nargin == 0
    T = default_options;
    return;
end

switch nargin
    case 2 % X, Y
        W = ones(length(Y),1);
        IND0 = int32(1:length(Y));
        options = default_options;
    case 3 % X, Y, [OPTS]
        W = ones(length(Y),1);
        IND0 = int32(1:length(Y));
        options = varargin{1};
    case 4 % X, Y, [W, OPTS]
        IND0 = int32(1:length(Y));
        W = varargin{1};
        options = varargin{2};
    case 5 % X, Y, [W, IND0, OPTS]
        W = varargin{1};
        IND0 = varargin{2};
        options = varargin{3};
    otherwise
        error('Incorrect parameters');
end

if isempty(W), W = ones(length(Y),1); end

% Accept OPTS = [] or a partial options structure: every missing field
% takes its default value, as promised by the help text.
if isempty(options), options = default_options; end
options = struct(options);
opt_names = fieldnames(default_options);
for k = 1:numel(opt_names)
    if ~isfield(options, opt_names{k})
        options.(opt_names{k}) = default_options.(opt_names{k});
    end
end

%%%
%%% Check input argument type
%%%

if ~exist('X','var') || isempty(X)
    error('Feature vectors are missing');
end
% FIXME for regression
if ~exist('Y','var') || ~isvector(Y) || isempty(Y)
    error('Labels are missing');
end
% FIXME for regression
if ~iscolumn(Y), Y = Y'; end % transpose if needed
if ~iscolumn(W), W = W'; end % transpose if needed

if isstruct(X)
    n_samples = numel(X);
else
    n_samples = size(X,ndims(X));
end

% FIXME for regression
if n_samples ~= numel(Y)
    error('Number of samples must match the number of labels.');
end

if n_samples ~= numel(W)
    error('Number of weights must match the number of samples.');
end

% IND0 must be a finite integer vector with elements in 1..n_samples.
% isvector is tested first so that min/max are never taken on an empty or
% matrix input inside the short-circuit chain.
ind_ok = isvector(IND0) && all(isfinite(IND0)) && all(IND0 == round(IND0)) && ...
    min(IND0) >= 1 && max(IND0) <= n_samples;
if ~ind_ok
    error('IND0 must be finite integer vector with elements between 1 and the number of training samples');
end
IND0 = int32(IND0);
if ~iscolumn(IND0), IND0 = IND0'; end
if ~issorted(IND0), IND0 = sort(IND0); end

if ~iscell(options.impurity) || ~iscell(options.split_gen) || ...
        ~iscell(options.split_eval) || ~iscell(options.user_fnc)
    error('Bad arguments');
end

if numel(options.split_gen) ~= numel(options.split_eval)
    error('Number of split generators must match the number of split evaluators.');
end

% Function lists become 2 x K cells: row 1 functions, row 2 parameter lists.
options.split_gen = reshape(options.split_gen, 2, []);
options.split_eval = reshape(options.split_eval, 2, []);
options.user_fnc = reshape(options.user_fnc, 2, []);
options.impurity = reshape(options.impurity, 2, []);
options.predictor = reshape(options.predictor, 2, []);

%%%
%%% Sample information.
%%% Get active classes and print sample counts.
%%%

if isempty(options.classes), options.classes = unique(Y(IND0)); end
% Force a row vector: "for c = classes" iterates over columns, so a column
% vector would run a single iteration with all classes at once.
options.classes = reshape(options.classes, 1, []);

log_printf(options.verbose, '#samples: ');
for c = options.classes
    log_printf(options.verbose, '%d [%d]; ', sum(Y(IND0)==c), c);
end
log_printf(options.verbose, '\n');

%%%
%%% Independently train the trees. For every tree, samples are
%%% randomly selected from the training set.
%%%

T = struct;
T.classes = options.classes;
T.split_eval = options.split_eval(1,:);
T.split_params = options.split_eval(2,:);
T.n_predictors = 0;
tree = cell(1,options.ntrees);
IND1 = cell(1,options.ntrees);

t_start = tic;

for t = 1:options.ntrees
    log_printf(options.verbose, 'Tree %i/%i in progress\n', t, options.ntrees);

    if options.subsample > 0
        % Sample positions and index IND0 with them: randsample(IND0, k)
        % would read a one-element IND0 as the population size 1:IND0.
        ind = sort(IND0(randsample(numel(IND0), options.subsample)));
    else
        ind = IND0;
    end

    % Per-class weight normalization of the active samples: each class gets
    % total weight 1/numel(classes). Classes with zero total weight are left
    % untouched instead of producing NaN (0/0). Samples outside the active
    % set get weight 0.
    w_act = W(ind);
    y_act = Y(ind);
    for c = options.classes
        in_c = (y_act == c);
        w_sum = sum(w_act(in_c));
        if w_sum ~= 0
            w_act(in_c) = w_act(in_c) / (numel(options.classes) * w_sum);
        end
    end
    norm_W = zeros(size(W));
    norm_W(ind) = w_act;

    log_printf(options.verbose, 'Active set: %i samples; ', length(ind));
    for c = options.classes
        log_printf(options.verbose, '%i: %i (w=%.2f);  ', c, sum(y_act==c), sum(w_act(y_act==c)));
    end
    log_printf(options.verbose, '\n');

    IND1{t} = ind;
    t0 = tic;
    % Root split threshold is a quarter of the active samples; it halves
    % with depth inside learn_node down to options.min_samples.
    root = learn_node(X, Y, norm_W, ind, 0, length(ind)/4, options);
    if options.compact, root = compact_tree(root); end
    tree{t} = root;
    log_printf(options.verbose, '  %.1fs\n', toc(t0));
end % trees

log_printf(options.verbose, 'Total time: %.1fs\n', toc(t_start));

T.tree = tree;
if ~isempty(tree)
    T.n_predictors = length(tree{1}(1).p);
end

end % random_forest_train


function T = leaf(T, X, Y, W, ind, options)
% LEAF  Finalize node T as a terminal node.
%   Logs the number of samples reaching the node (only when
%   options.verbose is set), marks the node with tp = 0 and an empty split,
%   and stores the outputs of the user functions in T.udata.
n_leaf = numel(ind);
log_printf(options.verbose, '(%d)', n_leaf);

T.tp = 0;       % no split generator used
T.split = [];   % no split parameters
T.udata = exec_user_functions(options.user_fnc, X, Y, W, ind, options.classes);
end


% P=predictor(Y,W,IND,CLASSES,...)
% U=user_fnc(X,Y,W,IND,CLASSES,...)
% [F,S,H]=split_gen(N,X,Y,W,IND,CLASSES,...);
% S=split_eval(X,IND,F,...);
% H=impurity(Y,W,IND,CLASSES,S,...);

function T = learn_node(X, Y, W, ind, depth, min_samples, options)
% LEARN_NODE  Recursively train a tree node on the samples IND.
%
%   T = LEARN_NODE(X, Y, W, IND, DEPTH, MIN_SAMPLES, OPTIONS)
%
% Inputs:
%   X, Y, W      training samples, labels and weights
%   IND          indices of the samples reaching this node
%   DEPTH        depth of this node; the root has depth 0
%   MIN_SAMPLES  node-specific threshold: the node is split only when it
%                holds at least MIN_SAMPLES samples, and candidate splits
%                sending fewer than MIN_SAMPLES samples to either branch
%                are discarded. Children receive
%                max(MIN_SAMPLES/2, OPTIONS.min_samples).
%   OPTIONS      options with the function lists reshaped to 2 x K cells
%
% Output node fields:
%   H      histc of Y(IND) over OPTIONS.classes (transposed)
%   p      output of the predictor function
%   udata  outputs of the user functions
%   tp     index of the split generator/evaluator pair; 0 for leaves
%   split  chosen row of split parameters F; [] for leaves
%   l, r   children for split result 1 and 0; internal nodes only

n_samples = length(ind); % number of samples reaching the node
T = struct;
T.H = histc(Y(ind), options.classes)';
T.p = feval(options.predictor{1}, Y, W, ind, options.classes, options.predictor{2}{:});
T.udata = {};

% Stopping criteria: too few samples or maximal depth reached. ">=" also
% stops for non-integer or negative max_depth, which "==" never matches.
% (A pure-node test, length(unique(Y(ind))) == 1, is currently disabled.)
if (n_samples < options.min_samples) || (depth >= options.max_depth)
    T = leaf(T, X, Y, W, ind, options);
    return;
end

% Pick one split generator/evaluator pair at random for this node.
T.tp = randsample(size(options.split_gen,2), 1);

[left_ind,right_ind] = deal([]);
splits = 0;        % valid (finite error) candidates evaluated so far
total_splits = 0;  % all candidates requested so far
batch = 1024;      % candidates per round (options.batch is not consulted)
best_e = inf;

make_split_node = n_samples >= min_samples;
split_ok = false;

while ~split_ok && make_split_node
    [F,f_splits,err] = feval(options.split_gen{1,T.tp}, batch, X, Y, W, ind, ...
        options.classes, options.split_gen{2,T.tp}{:});
    if isempty(f_splits)
        f_splits = feval(options.split_eval{1,T.tp}, X, ind, F, options.split_eval{2,T.tp}{:});
    end

    % Candidates leaving fewer than min_samples samples in a branch are bad.
    bad = sum(f_splits,1)<min_samples | sum(~f_splits,1)<min_samples;

    if isempty(err)
        err = inf(1, size(f_splits,2), 'single');
        % Draw the impurity function even when nothing is scored so that the
        % random number stream is consumed exactly as before.
        err_tp = randsample(size(options.impurity,2), 1);
        if any(~bad)
            tmperr = feval(options.impurity{1,err_tp}, Y, W, ind, options.classes, ...
                f_splits(:,~bad), options.impurity{2,err_tp}{:});
            err(~bad) = tmperr;
        end
    else
        err(bad) = inf;
    end

    % Keep the best candidate found so far.
    [~,id] = min(err);
    if err(id) < best_e
        T.split = F(id,:);
        s = f_splits(:,id);
        left_ind = ind(s == 1);
        right_ind = ind(s == 0);
        best_e = err(id);
    end

    splits = splits + sum(err ~= inf);
    total_splits = total_splits + batch;
    % Done when enough valid candidates were seen and one is usable, or when
    % the hard caps on valid / requested candidates are reached.
    split_ok = (best_e < inf && splits >= options.splits) || ...
        (splits >= 5*options.splits) || (total_splits >= 10*options.splits);
end % Batches

clear f_splits;

if ~make_split_node || best_e == inf
    if depth == 0, warning('Degenerated tree'); end
    T = leaf(T, X, Y, W, ind, options);
else
    T.udata = exec_user_functions(options.user_fnc, X, Y, W, ind, options.classes);
    % Learn subtrees
    log_printf(options.verbose, '[');
    child_min = max(min_samples/2, options.min_samples);
    T.l = learn_node(X, Y, W, left_ind, depth+1, child_min, options);
    T.r = learn_node(X, Y, W, right_ind, depth+1, child_min, options);
    log_printf(options.verbose, ']');
end

end % learn_node


function udata = exec_user_functions(user_fnc, X, Y, W, ind, classes)
% EXEC_USER_FUNCTIONS  Evaluate every user function on the given node data.
%   USER_FNC is a 2 x K cell: row 1 holds the functions, row 2 the cell
%   arrays of their extra arguments. Returns a 1 x K cell with the k-th
%   result of feval(fn_k, X, Y, W, IND, CLASSES, extra_k{:}).
n_fnc = size(user_fnc, 2);
fns = user_fnc(1, :);
extra = user_fnc(2, :);
udata = cell(1, n_fnc);
for k = 1:n_fnc
    args = extra{k};
    udata{k} = feval(fns{k}, X, Y, W, ind, classes, args{:});
end
end

function log_printf(enabled, varargin)
% LOG_PRINTF  Conditional fprintf.
%   Forwards VARARGIN (format string and values) to fprintf when ENABLED
%   evaluates to true; otherwise prints nothing.
if enabled
    fprintf(varargin{:});
end
end