% ------------------------------------------------------------
% Function: Machine Learning on coral
% Author:   Alan Li
% Project:  Coral Classification with Machine Learning
% NASA Ames Research Center, September 2016
%
% Description: Predict coral cover based upon previous model
%
% Dependencies: DBSCAN.m
%               Matlab Machine Learning Toolbox
%
% Input:
%   cl              - Machine learning model (2 class)
%   c_test          - Test Images (cell array of Images)
%   r               - Half the side length of the square sliding window
%                     (windows are 2r x 2r pixels, spaced r/2 apart)
%   figureson       - Turn figures on or off (currently unused; the function draws no figures)
%   cl_multi        - Machine learning model (4 class); optional
%
% Output:
%   im_predict          - Cell array of predicted 2-class maps (1 or 2 per pixel)
%   im_multi_predict    - Cell array of predicted 4-class maps (0 if cl_multi is omitted)
%
% ------------------------------------------------------------
function [im_predict,im_multi_predict] = coral_predict_remap(cl, c_test, r, figureson, cl_multi)
% CORAL_PREDICT_REMAP  Sliding-window coral / non-coral prediction.
%
% For every image in c_test:
%   1. Whole-image calibration: pixels whose normalised Sobel gradient lies
%      in [0.4, 0.8] are clustered with DBSCAN; surroundingpts3 splits the
%      pixels around each cluster into two sides. The colour covariance of
%      balanced random draws from both sides (in the image's sign-aligned
%      principal-component basis) is compared with cl.cov to obtain a
%      per-component scaling factor.
%   2. A 2r x 2r window is slid over the image with stride r/2. In each
%      window the gradient threshold g is lowered from 0.5 until boundary
%      samples cover more than 28% of the window, or at least half of the
%      window passes the threshold. The median projection of balanced draws
%      of those samples re-centres the window's pixels. The pixels are then
%      scaled by the factor from step 1 and classified with cl.model (and
%      cl_multi.model when given).
%   3. Per-pixel score differences (score 2 - score 1) are summed over all
%      windows: a positive total gives class 2, otherwise class 1.
%      Multi-class labels are averaged over the windows covering each pixel
%      and rounded.
%
% figureson is accepted for interface compatibility but is not used.
% Fields read: cl.p, cl.cov, cl.model, cl.avg and, when given,
% cl_multi.model, cl_multi.avg.

do_multi = (nargin == 5);
n_im = length(c_test);
im_predict = cell(n_im,1);
im_multi_predict = cell(n_im,1);
nr = 100;   % number of balanced random draws averaged per estimate

if do_multi
    % Shift between the two models' centring, expressed in cl's PCA basis.
    % Depends only on the models, so it is computed once.
    offset = cl.p'*(cl.avg'-cl_multi.avg');
end

for k = 1:n_im
    % Work in double so the colour projections (matrix products with the
    % PCA basis) are also valid for integer-typed images.
    test_im = double(c_test{k});
    [n_rows, n_cols, ~] = size(test_im);
    test_im_mat = reshape(test_im,[],3);
    im_gray = rgb2gray(uint8(test_im));

    % ---- Whole-image calibration -----------------------------------------
    [Gmag, ~] = imgradient(im_gray,'sobel');
    grad_norm = double(reshape(Gmag,[],1))./max(max(Gmag));
    grad_idx = find(grad_norm >= 0.4 & grad_norm <= 0.8);
    [J,I] = ind2sub([n_rows n_cols],grad_idx);
    IDX = DBSCAN([J I],1,3);

    COEFF = align_pc_signs(princomp(test_im_mat), cl.p);
    [nocoralptsidx, coralptsidx] = collect_boundary_pts(IDX, J, I, test_im, COEFF(:,1), 1);

    if isempty(nocoralptsidx) || isempty(coralptsidx)
        % One side has no boundary samples (randi cannot draw from an empty
        % list): fall back to the covariance of the whole image.
        covrand = cov(test_im_mat*COEFF);
    else
        covrand = zeros(3,3);
        for inr = 1:nr
            randtotal = balanced_draw(nocoralptsidx, coralptsidx);
            covrand = cov(test_im_mat(randtotal,:)*COEFF)+covrand;
        end
        covrand = 1/nr*covrand;
    end
    factor = sqrt(diag(cl.cov./covrand));   % per-component scaling, whole image

    % ---- Sliding-window classification -----------------------------------
    count_times = zeros(n_rows,n_cols);   % # of windows covering each pixel
    score_total = zeros(n_rows,n_cols);   % summed (score 2 - score 1)
    if do_multi
        labelm_total = zeros(n_rows,n_cols);   % summed multi-class labels
    end

    disp('-----------------')
    for i = 1:floor((n_cols-r)/(r/2))+2
        for j = 1:floor((n_rows-r)/(r/2))+2
            xmin = max(i*r/2-r+1,1);
            xmax = min(i*r/2+r,n_cols);
            ymin = max(j*r/2-r+1,1);
            ymax = min(j*r/2+r,n_rows);
            win_h = ymax-ymin+1;
            win_w = xmax-xmin+1;
            npix = win_h*win_w;

            win_im = test_im(ymin:ymax,xmin:xmax,:);
            sample_im = reshape(win_im,[],3);

            % Normalise the window's gradient to [0,1]; a flat window (peak 0)
            % would otherwise produce 0/0 = NaN everywhere.
            sample_grad = reshape(Gmag(ymin:ymax,xmin:xmax),[],1);
            peak = max(sample_grad);
            if peak > 0
                sample_grad_norm = double(sample_grad)./peak;
            else
                sample_grad_norm = zeros(size(sample_grad));
            end

            COEFF = align_pc_signs(princomp(sample_im), cl.p);

            % Lower the gradient threshold g until boundary samples cover more
            % than 28% of the window, or at least half of the window passes the
            % threshold. A window without any gradient has nothing to find, so
            % the search is skipped (it would never terminate otherwise).
            g = 0.5;
            nocoralptsidx = [];
            coralptsidx = [];
            nsamplepts = 0;
            ptsflag = ~(peak > 0);
            while nsamplepts/npix <= 0.28 && ~ptsflag
                temp_grad_idx = find(sample_grad_norm >= g & sample_grad_norm <= 1);
                [J,I] = ind2sub([win_h win_w],temp_grad_idx);
                IDX = DBSCAN([J I],1,3);
                [nocoralptsidx, coralptsidx] = collect_boundary_pts(IDX, J, I, win_im, COEFF(:,1), 4);
                nsamplepts = length(coralptsidx) + length(nocoralptsidx);

                if length(temp_grad_idx)/npix >= 0.5
                    ptsflag = true;
                end
                % Only a too-small sample keeps the loop running, so the
                % threshold only ever needs to move down.
                g = g-0.01;
            end

            if isempty(nocoralptsidx) || isempty(coralptsidx)
                % No usable boundary samples: centre on the whole window.
                murand = median(sample_im*COEFF,1);
            else
                murand = 0;
                for inr = 1:nr
                    randtotal = balanced_draw(nocoralptsidx, coralptsidx);
                    murand = murand + median(sample_im(randtotal,:)*COEFF);
                end
                murand = murand/nr;
            end

            % Centre the window in PCA space, then rescale per component.
            p_sample_mu = bsxfun(@minus, sample_im*COEFF, murand(1:3));
            p_sample_im = bsxfun(@times, p_sample_mu, factor');

            [~, scores_sample] = predict(cl.model,p_sample_im);
            if do_multi
                multi_in = bsxfun(@times, bsxfun(@plus, p_sample_mu, offset'), factor');
                label_multi = predict(cl_multi.model, multi_in);
                labelm_total(ymin:ymax,xmin:xmax) = labelm_total(ymin:ymax,xmin:xmax) + double(reshape(label_multi,win_h,[]));
            end

            score_total(ymin:ymax,xmin:xmax) = score_total(ymin:ymax,xmin:xmax) + reshape(scores_sample(:,2) - scores_sample(:,1),win_h,[]);
            count_times(ymin:ymax,xmin:xmax) = count_times(ymin:ymax,xmin:xmax) + 1;

            % NOTE: i steps along x (columns) and j along y (rows).
            str = ['Currently processing partition row: ',num2str(i), ' column: ',num2str(j), '...'];
            disp(str)
        end
    end
    disp('-----------------')

    % Positive summed score difference -> class 2, otherwise class 1.
    score_total(score_total > 0) = 2;
    score_total(score_total <= 0) = 1;
    im_predict{k} = score_total;

    if do_multi
        im_multi_predict{k} = round(labelm_total./count_times);
    else
        im_multi_predict{k} = 0;
    end
end
end

function coeff = align_pc_signs(coeff, ref)
% Flip every column of coeff whose dot product with the matching column of
% the reference basis ref is negative, so component signs agree with ref.
for pp = 1:size(coeff,2)
    if sum(coeff(:,pp).*ref(:,pp)) < 0
        coeff(:,pp) = -coeff(:,pp);
    end
end
end

function [nocoral_idx, coral_idx] = collect_boundary_pts(IDX, J, I, img, coeff, figuren)
% For every DBSCAN cluster in IDX (positive labels 1..max; the noise label
% is ignored), gather the two index lists surroundingpts3 returns for the
% pixels around that cluster. The lists are concatenated in label order and
% hold linear indices into img(:,:,1).
nocoral_idx = [];
coral_idx = [];
for ii = 1:max([0; IDX(:)])
    members = find(IDX == ii);
    cluster_xy = [I(members) J(members)];
    [pts1, pts2] = surroundingpts3(cluster_xy, cluster_xy, img, coeff, figuren);
    nocoral_idx = [nocoral_idx; pts1]; %#ok<AGROW>
    coral_idx = [coral_idx; pts2]; %#ok<AGROW>
end
end

function sel = balanced_draw(nocoral_idx, coral_idx)
% Draw min(#nocoral, #coral) entries with replacement from each (non-empty)
% index list and stack them, non-coral first. The non-coral draw happens
% before the coral draw.
n = min(length(nocoral_idx), length(coral_idx));
pick_nocoral = randi(length(nocoral_idx), n, 1);
pick_coral = randi(length(coral_idx), n, 1);
sel = [nocoral_idx(pick_nocoral); coral_idx(pick_coral)];
end

    function result = surroundingpts(X,setpts,xmax,ymax)
    % SURROUNDINGPTS  8-connected neighbours of a point set.
    %   result = surroundingpts(X, setpts, xmax, ymax) returns the unique,
    %   row-sorted 8-neighbours of the rows of X, keeping only those that:
    %     - are not rows of setpts;
    %     - have column 1 in 1..ymax;
    %     - have column 2 in 1..xmax.
    %   Returns [] when X has no rows.
        offsets = [0 1; 1 1; 1 0; 1 -1; 0 -1; -1 -1; -1 0; -1 1];
        npoints = size(X,1);
        result = [];
        if npoints == 0
            return
        end
        noff = size(offsets,1);
        % Build every (point, offset) candidate at once and filter/dedupe a
        % single time, instead of re-running unique on a growing array for
        % each point.
        owner = reshape(repmat(1:npoints, noff, 1), [], 1);
        cand = X(owner,:) + repmat(offsets, npoints, 1);
        keep = ~ismember(cand,setpts,'rows') ...
            & cand(:,1) > 0 & cand(:,1) <= ymax ...
            & cand(:,2) > 0 & cand(:,2) <= xmax;
        result = unique(cand(keep,:),'rows');
    end

    function result = surroundingpts2(pts,setpts,fullimg,coeff,figuren)
    % SURROUNDINGPTS2  Verdict on the 8-neighbourhood of a cluster.
    %
    %   pts, setpts - N x 2 / M x 2 pixel coordinates as [x y] (column, row);
    %                 neighbours that appear in setpts are ignored
    %   fullimg     - H x W x 3 image
    %   coeff       - 3 x 1 projection vector
    %   figuren     - figure number that receives the vote scatter plot
    %
    %   Each in-bounds neighbour is projected onto coeff and compared with
    %   its cluster pixel: a larger projection votes -1, a smaller one +1.
    %   The sign of each pixel's summed vote is scattered onto figure
    %   figuren (yellow = -1, black = +1; hold on).
    %   result = -1 when the median positive difference exceeds the median
    %   |negative difference| (or there are no negative differences);
    %   +1 when it is smaller (or there are no positive differences);
    %   otherwise 0.
        ring = [0 1; 1 1; 1 0; 1 -1; 0 -1; -1 -1; -1 0; -1 1];
        [ny, nx, ~] = size(fullimg);
        pixels = reshape(fullimg,[],3);
        votes = zeros(ny*nx,1);
        posdiff = [];
        negdiff = [];

        for n = 1:size(pts,1)
            nb = bsxfun(@plus, pts(n,:), ring);
            nb = nb(~ismember(nb,setpts,'rows'),:);
            inside = nb(:,1) > 0 & nb(:,1) <= nx & nb(:,2) > 0 & nb(:,2) <= ny;
            nb = nb(inside,:);
            lin = sub2ind([ny,nx], nb(:,2), nb(:,1));

            ref = reshape(fullimg(pts(n,2),pts(n,1),:),[],3)*coeff;
            vals = pixels(lin,:)*coeff;
            delta = vals - ref;
            posdiff = [posdiff; delta(delta > 0)]; %#ok<AGROW>
            negdiff = [negdiff; delta(delta < 0)]; %#ok<AGROW>

            isHigh = vals > ref;
            isLow = vals < ref;
            hiIdx = sub2ind([ny,nx], nb(isHigh,2), nb(isHigh,1));
            loIdx = sub2ind([ny,nx], nb(isLow,2), nb(isLow,1));
            votes(hiIdx) = votes(hiIdx) - 1;
            votes(loIdx) = votes(loIdx) + 1;
        end

        votes = sign(votes);
        [yNeg, xNeg] = ind2sub([ny nx], find(votes == -1));
        [yPos, xPos] = ind2sub([ny nx], find(votes == 1));
        figure(figuren)
        hold on
        scatter(xNeg,yNeg,'y','filled')
        scatter(xPos,yPos,'k','filled')

        medPos = median(posdiff);
        medNeg = median(abs(negdiff));
        if medPos > medNeg || isempty(negdiff)
            result = -1;
        elseif medPos < medNeg || isempty(posdiff)
            result = 1;
        else
            result = 0;
        end
    end

    % Finds surrounding points of a cluster
    function [result1,result2] = surroundingpts3(pts,setpts,fullimg,coeff,figuren)
    % SURROUNDINGPTS3  Split the pixels around a cluster by colour projection.
    %
    %   pts      - N x 2 cluster pixels as [x y] = [column row]
    %   setpts   - [x y] rows excluded from the neighbourhood (callers in
    %              this file pass the cluster itself)
    %   fullimg  - H x W x 3 image; its pixels are multiplied by coeff, so it
    %              must be a floating-point type
    %   coeff    - 3 x 1 projection vector
    %   figuren  - unused (kept for interface compatibility)
    %
    %   Each cluster pixel votes on its in-bounds neighbours within a 5x5
    %   square minus the centre and the four corners. A neighbour whose
    %   projection is larger than the cluster pixel's gets -1; a smaller one
    %   gets +1. For every pixel with a non-zero net vote, the sign of the
    %   3x3 sum of vote signs around it decides its side; pixels with no net
    %   vote, or a zero 3x3 sum, are left out.
    %
    %   result1  - linear indices (into H x W) on the -1 side
    %   result2  - linear indices (into H x W) on the +1 side
        addpts = [0 1; 1 1; 1 0; 1 -1; 0 -1; -1 -1; -1 0; -1 1; ...
                  0 2; 1 2; 2 1; 2 0; 2 -1; 1 -2; 0 -2; -1 -2; -2 -1; -2 0; -2 1; -1 2];

        npoints = size(pts,1);
        max_y = size(fullimg,1);
        max_x = size(fullimg,2);
        totalmap = zeros(max_y*max_x,1);
        fullimgind = reshape(fullimg,[],3);

        for nn=1:npoints
            temppts = repmat(pts(nn,:),size(addpts,1),1) + addpts;
            temppts(ismember(temppts,setpts,'rows'),:) = [];
            temppts(temppts(:,1) <= 0 | temppts(:,1) > max_x | temppts(:,2) <=0 | temppts(:,2) > max_y, :) = [];
            indidx = sub2ind([max_y,max_x],temppts(:,2),temppts(:,1));

            refpt = reshape(fullimg(pts(nn,2),pts(nn,1),:),[],3)*coeff;
            cmppt = fullimgind(indidx,:)*coeff;
            lpts = temppts(cmppt > refpt,:);
            dpts = temppts(cmppt < refpt,:);
            lidx = sub2ind([max_y,max_x], lpts(:,2), lpts(:,1));
            didx = sub2ind([max_y,max_x], dpts(:,2), dpts(:,1));
            totalmap(lidx) = totalmap(lidx) - 1;
            totalmap(didx) = totalmap(didx) + 1;
        end

        % Smooth the vote signs over a 3x3 neighbourhood, but only keep a
        % verdict where the pixel itself received a net vote.
        votesign = sign(totalmap);
        smoothed = reshape(conv2(reshape(votesign,max_y,[]),ones(3,3),'same'),[],1);
        smoothed(votesign == 0) = 0;

        result1 = find(smoothed < 0);
        result2 = find(smoothed > 0);
    end