% ---------------
% Function: Machine Learning on coral
% Author:   Alan Li
% Project:  Coral Classification with Machine Learning
% NASA Ames Research Center, September 2016
%
% Description: Uses an SVM (for now) to distinguish coral cover based on
% a truth set (which might be derived from UAV products, and hence is not
% exactly "truth"), applied to a training set (usually satellite products).
% This particular type of learning is focused on gradients.
% 
% NOTE: Truth and training images MUST be rectified beforehand (see SIFT_register_down.m)
%
% Input:
%   c_truth         - Truth Data (cell array of Images)
%   c_ori           - Original Images (cell array of Images)
%   c_train         - Training Images (cell array of Images)
%   n               - Number of training points (per class)
%   figureson       - Turn figures on
%
% Output: 
%   cl              - Machine learning model (struct with fields .model,
%                     .avg, .p and .cov)

function cl = coral_learn_remap(c_truth,c_ori,c_train,n,figureson)
% CORAL_LEARN_REMAP Train a PCA + ECOC-SVM classifier on resampled imagery.
%
%   cl = coral_learn_remap(c_truth, c_ori, c_train, n, figureson)
%
%   Each truth image and original image is resized by the scale factor that
%   makes it cover the matching training image in both dimensions, then
%   cropped to the training image's size. n pixels per class (drawn with
%   replacement) are sampled from the resized original images, projected
%   onto their principal components, and used to fit a multiclass SVM
%   (3rd-order polynomial kernel, standardized) via fitcecoc.
%
%   Inputs:
%     c_truth   - cell array of label images; the class set is the set of
%                 distinct values found in c_truth{1}
%     c_ori     - cell array of original images (rows x cols x bands);
%                 the band count is taken from c_ori{1}
%     c_train   - cell array of training images; only their row/column
%                 sizes are used here, as the resampling target
%     n         - number of sampled training points per class
%     figureson - (optional, default false) plot the samples in spectral
%                 and PCA space
%
%   Output:
%     cl.model  - ECOC model returned by fitcecoc
%     cl.avg    - per-band mean of the sampled original features
%     cl.p      - PCA coefficient matrix (from pca)
%     cl.cov    - covariance of the PCA-projected samples

if nargin < 5 || isempty(figureson)
    figureson = false;
end

n_im = numel(c_truth);
if numel(c_ori) ~= n_im || numel(c_train) ~= n_im     % check same # of images
    error('Not same number of images in cells!')
end

% Use the label values actually present rather than assuming 1..n_class,
% so 0-based or non-contiguous label sets also work.
class_labels = unique(c_truth{1}(:));
n_class = numel(class_labels);
n_features = size(c_ori{1},3);

c_ori_down = cell(n_im,1);
c_truth_down = cell(n_im,1);
truth_pts = [];
ori_pts = [];
for i = 1:n_im
    n_rows = size(c_train{i},1);
    n_cols = size(c_train{i},2);
    % The larger of the two ratios guarantees the resized image covers the
    % training image in both dimensions before cropping.
    scalefactor = max(n_rows/size(c_ori{i},1), n_cols/size(c_ori{i},2));
    c_ori_down{i} = imresize(c_ori{i},scalefactor);
    c_ori_down{i} = c_ori_down{i}(1:n_rows, 1:n_cols, :);
    % Labels are categorical: nearest-neighbour resampling keeps every
    % output pixel equal to an existing class value (the default bicubic
    % kernel would blend labels at class boundaries).
    c_truth_down{i} = uint8(imresize(double(c_truth{i}),scalefactor,'nearest'));
    c_truth_down{i} = c_truth_down{i}(1:n_rows, 1:n_cols, :);

    truth_pts = [truth_pts; reshape(c_truth_down{i},[],1)];           %#ok<AGROW>
    ori_pts = [ori_pts; reshape(c_ori_down{i},[],n_features)];        %#ok<AGROW>
end

% Draw n pixel indices per class (with replacement) and stack the samples
% class by class, so rows (k-1)*n+1 : k*n of the outputs belong to class k.
ori_samples = cell(n_class,1);
class_samples = cell(n_class,1);
for k = 1:n_class
    in_class = find(truth_pts == class_labels(k));
    if isempty(in_class)
        error('Class %g has no pixels after resampling; cannot draw training points.', ...
            double(class_labels(k)))
    end
    pick = in_class(randi(numel(in_class),n,1));
    ori_samples{k} = ori_pts(pick,:);
    class_samples{k} = truth_pts(pick);
end
OriFeature = vertcat(ori_samples{:});
TrainClass = vertcat(class_samples{:});

% PCA of the sampled features (pca supersedes the deprecated princomp).
% Mean along dim 1 explicitly so a single sample row is still handled.
X = double(OriFeature);
OriFeature_mean = mean(X,1);
COEFF_ori = pca(X);
p_OriFeature = bsxfun(@minus, X, OriFeature_mean) * COEFF_ori;

% Support vector machine
t = templateSVM('Standardize',1,'KernelFunction','polynomial','PolynomialOrder',3);
str = ['Performing SVM 3rd order polynomial fit with ',num2str(n), ' training points per class...'];
disp('-----------------')
disp(str)
disp('-----------------')
cl.model = fitcecoc(p_OriFeature,TrainClass,'Learners',t);
cl.avg = OriFeature_mean;
cl.p = COEFF_ori;
cl.cov = cov(p_OriFeature);

if figureson
    if n_features < 3 || size(p_OriFeature,2) < 3
        warning('coral_learn_remap: fewer than 3 features/components; skipping 3-D plots.')
    else
        % One colour per class: the fixed palette first, extended with the
        % 'lines' colormap when there are more than four classes.
        customMap = [0 0 0; 1 0.5 0 ; 0 0.5 0 ; 0 0 0.8];
        if n_class > size(customMap,1)
            extra = lines(n_class);
            customMap = [customMap; extra(size(customMap,1)+1:end,:)];
        end

        figure
        hold on
        for k = 1:n_class
            rows = (k-1)*n + (1:n);
            scatter3(OriFeature(rows,1), OriFeature(rows,2), OriFeature(rows,3),8,customMap(k,:),'filled');
        end
        axis equal; hold off; grid on;
        title('Spectral Space')
        xlabel('R');ylabel('G');zlabel('B')

        figure
        hold on
        for k = 1:n_class
            rows = (k-1)*n + (1:n);
            scatter3(p_OriFeature(rows,1), p_OriFeature(rows,2), p_OriFeature(rows,3),8,customMap(k,:),'filled');
        end
        axis equal; hold off; grid on;
        title('PCA Space')
        xlabel('PC1');ylabel('PC2');zlabel('PC3')
    end
end


