% ---------------
% Function: Register/match high resolution image to low resolution image, output low resolution
% Author:   Alan Li
% Project:  Coral Classification with Machine Learning
% NASA Ames Research Center, September 2016
%
% Description: Uses SIFT matching (with RANSAC, via sift_match) to register the
% high resolution image against the low resolution image. For best results, both
% images should cover approximately the same area, with the low resolution image
% covering slightly more.
%
% Dependencies:
%   sift_match.m
%   subtightplot.m (used to lay out the result figure)
%   VLFeat 0.9.20 (can be downloaded at http://www.vlfeat.org/index.html)
%   MATLAB Image Processing Toolbox
%
% Input:
%   imhi            - High resolution image (filepath)
%   imlo            - Low resolution image (filepath)
%   saveon          - Boolean: save the best homography transform (variable Hsave) to
%                     H_<str>_down.mat, where <str> is the name of the low resolution image
%   Hprev           - Previous homography transform to compare against: path to a .mat file
%                     containing a variable Hsave (the sift_match result saved when saveon is true)
%   im_add          - Cell array of file paths of additional images to transform by the same
%                     homography (usually reserved for additional bands)
%
% Output: 
%   Im2_out         - Im2 registered onto the footprint of Im1, at approximately Im2's (low)
%                     resolution, in a cell array: one RGB image normally, or one single-band
%                     image per entry of im_add if im_add is specified
%   Hsave           - Best homography transform as a sift_match result struct (matrix in
%                     Hsave.model), from the low resolution image to the downscaled high
%                     resolution image
% ---------------

function [Im2_out, Hsave] = SIFT_register_down(imhi,imlo, saveon, Hprev, im_add)
% SIFT_register_down  Register a high-resolution image to a low-resolution one.
%
% Runs sift_match (SIFT + RANSAC) nrep times between the low-resolution image
% and a downscaled copy of the high-resolution image. It keeps the homography
% whose resampled result best matches the high-resolution image. It returns
% the low-resolution image resampled (nearest neighbour) onto the footprint of
% the high-resolution image, at approximately low-resolution pixel size.
%
% Inputs:
%   imhi   - file path of the high-resolution image (first 3 channels used)
%   imlo   - file path of the low-resolution image (first 3 channels used)
%   saveon - (optional) if true, save the best sift_match result as variable
%            Hsave in H_<name>_down.mat. <name> is the file name of imlo
%            without folder or extension. Honoured for 3, 4 or 5 inputs.
%   Hprev  - (optional) previous result to compare against: either a path to
%            a .mat file holding a variable Hsave, or that struct itself.
%            Pass [] to skip the comparison.
%   im_add - (optional) cell array of file paths of additional bands to warp
%            with the same homography. They are assumed to share imlo's
%            pixel grid.
%
% Outputs:
%   Im2_out - cell array: {registered RGB image (double)} or, if im_add is
%             non-empty, one registered single-band image (double) per entry.
%   Hsave   - best sift_match result. Hsave.model maps low-resolution pixel
%             coordinates to downscaled high-resolution pixel coordinates.

% Optional inputs. Saving used to be silently ignored when im_add was passed.
savebool = false;
if nargin >= 3 && ~isempty(saveon)
    savebool = saveon;
end
usePrev = nargin >= 4 && ~isempty(Hprev);
hasBands = nargin >= 5 && ~isempty(im_add);

% Read in images, keeping RGB only
Im1 = imread(imhi);     % High resolution image
Im2 = imread(imlo);     % Low resolution image
Im1 = Im1(:,:,1:3);
Im2 = Im2(:,:,1:3);

% Downscale the high resolution image before SIFT matching.
% NOTE(review): the two per-axis ratios are summed and divided by 3 (not 2),
% so Im1_down ends up ~1.5x the size of Im2 rather than the same size. The
% per-homography rescale (footprint_grid) compensates; confirm this is intended.
scalefactor = (size(Im1,1)/size(Im2,1) + size(Im1,2)/size(Im2,2))/3;
Im1_down = imresize(Im1, 1/scalefactor);

% RANSAC is randomized, so repeat the registration and score every candidate.
nrep = 10;
diffresult = zeros(1,nrep);
H = cell(nrep,1);
for k = 1:nrep
    H{k} = sift_match(Im2, Im1_down);
    diffresult(k) = registration_error(Im1_down, Im2, H{k}.model);
end

% Pick the homography with the smallest error; optionally compare with a previous one
[~, idx] = min(diffresult);
if usePrev
    Hcomp = load_previous_homography(Hprev);
    diffcompare = registration_error(Im1_down, Im2, Hcomp.model);
    if diffcompare <= diffresult(idx)
        disp('-----------------')
        disp('Old homography transform gives better solution')
        disp('-----------------')
        H{idx} = Hcomp;
    else
        disp('-----------------')
        disp('New homography transform gives better solution')
        disp('-----------------')
    end
end
Hsave = H{idx};

% Reconstruct the best case match
Im2size = [size(Im2,1) size(Im2,2)];
[scalefactor2, newpts, gridsz] = footprint_grid(Im1_down, Hsave.model);
Im2_reg = sample_nearest(Im2, newpts, gridsz, Im2size);

FigHandle = figure;
set(FigHandle, 'Position', [50, 50, 350, 800]);
subtightplot(1,2,1)
image(imresize(Im1_down,scalefactor2));   % high res, downscaled
axis off; axis equal; axis tight;
subtightplot(1,2,2)
imshow(uint8(Im2_reg))                     % low res, registered
axis equal; axis off; axis tight;

if hasBands
    Im2_out = cell(1, numel(im_add));
    for k = 1:numel(im_add)
        im_band = imread(im_add{k});
        % Only the first channel of each band image is used; bounds follow Im2.
        Im2_out{k} = sample_nearest(im_band(:,:,1), newpts, gridsz, Im2size);
    end
else
    Im2_out = {Im2_reg};
end

if savebool
    disp('-----------------')
    disp('Saving best homography transform...')
    disp('-----------------')
    % fileparts strips folder and extension. The old strfind('.') approach
    % produced 'H__down.mat' for paths such as './data/img.png'.
    [~, str] = fileparts(imlo);
    save(['H_',str,'_down.mat'],'Hsave');
end


function err = registration_error(Im1_down, Im2, Hmodel)
% Mean squared difference between two images, lower is better. The images are
% Im2 warped by Hmodel and Im1_down resized to the same pixel size. Both are
% compared in grayscale, min-max normalised to [0,1]. The error is divided by
% this candidate's own grid size; the old Hprev comparison wrongly used the
% last loop iteration's grid size.
[scalefactor2, newpts, gridsz] = footprint_grid(Im1_down, Hmodel);
Im_reg = sample_nearest(Im2, newpts, gridsz, [size(Im2,1) size(Im2,2)]);
Im1_down2 = double(imresize(Im1_down, scalefactor2));
Im1_down2 = Im1_down2(1:gridsz(1), 1:gridsz(2), :);
reg_bw = rescale_minmax(double(rgb2gray(uint8(Im_reg))));
ref_bw = rescale_minmax(double(rgb2gray(uint8(Im1_down2))));
err = sum((reg_bw(:) - ref_bw(:)).^2) / numel(reg_bw);


function y = rescale_minmax(x)
% Linearly map x so its minimum is 0 and maximum is 1 (NaN if x is constant).
lo = min(x(:));
y = (x - lo) / (max(x(:)) - lo);


function [scalefactor2, newpts, gridsz] = footprint_grid(Im1_down, Hmodel)
% Build the low resolution sampling grid for one homography. The corners of
% Im1_down are projected through inv(Hmodel) into low resolution pixel
% coordinates. scalefactor2 is chosen so that a rectangle with Im1_down's
% aspect ratio has the same area as that footprint. Returns the normalised
% homogeneous low resolution coordinates (3xN) of a regular grid over Im1_down
% at that scale, plus the grid size [rows cols].
rows = size(Im1_down,1);
cols = size(Im1_down,2);
% Corners in order top-left, top-right, bottom-right, bottom-left.
corners = Hmodel \ [1 cols cols 1; 1 1 rows rows; 1 1 1 1];
corners = corners ./ repmat(corners(3,:), 3, 1);
% Shoelace area of the projected quadrilateral. It replaces the acos-based
% Bretschneider formula, which could go complex from rounding.
footprint = polyarea(corners(1,:), corners(2,:));
% Assume the footprint is roughly rectangular with Im1_down's aspect ratio.
l = sqrt(footprint / (cols/rows));
scalefactor2 = l / rows;
[X,Y] = meshgrid(1:1/scalefactor2:cols, 1:1/scalefactor2:rows);
gridsz = size(X);
newpts = Hmodel \ [X(:)'; Y(:)'; ones(1, numel(X))];
newpts = newpts ./ repmat(newpts(3,:), 3, 1);


function out = sample_nearest(img, newpts, gridsz, boundsz)
% Nearest-neighbour lookup of img at (x,y) = newpts(1:2,:).
% A point is set to 0 if its rounded coordinates fall outside
% [1,boundsz(2)] x [1,boundsz(1)] or are NaN. Returns a double array of size
% gridsz(1) x gridsz(2) x size(img,3).
c = round(newpts(1,:));
r = round(newpts(2,:));
valid = c >= 1 & r >= 1 & c <= boundsz(2) & r <= boundsz(1);
nch = size(img,3);
out = zeros(numel(c), nch);
lin = sub2ind([size(img,1) size(img,2)], r(valid), c(valid));
plane = size(img,1) * size(img,2);
for ch = 1:nch
    vals = img(lin + (ch-1)*plane);
    out(valid, ch) = double(vals(:));
end
out = reshape(out, gridsz(1), gridsz(2), nch);


function Hp = load_previous_homography(Hprev)
% Accept either a sift_match result struct or a path to a .mat file holding a
% variable Hsave (as written by SIFT_register_down when saveon is true).
if isstruct(Hprev)
    Hp = Hprev;
else
    S = load(Hprev, 'Hsave');
    if ~isfield(S, 'Hsave')
        error('SIFT_register_down:noHsave', ...
            'File "%s" does not contain a variable named Hsave.', Hprev);
    end
    Hp = S.Hsave;
end