function [M,cfun,err,cfun_test,err_test] = ML_MPS (M,F,y,F_test,y_test,Nkeep,estep)
% < Description >
%
% [M,cfun,err,cfun_test,err_test] = ML_MPS (M,F,y,F_test,y_test,Nkeep,estep)
%
% Machine learning method based on matrix product states (MPS), proposed by
% Stoudenmire2016 [E. M. Stoudenmire and D. J. Schwab, Advances in Neural
% Information Processing Systems 29, 4799 (2016), or arXiv:1605.05775]. Two
% versions of the paper have minor differences; here we follow the notation
% of the published NeurIPS version.
%
% Note that there are some important technical details that are not
% discussed in Stoudenmire2016 and devised by Seung-Sup to achieve
% stability and performance. Such parts in the code are denoted by the
% comment "Unpublished; devised by S.Lee".
%
% < Input >
% M : [cell array] MPS. Each cell contains a rank-3 tensor for each chain
%       site, with leg order left-physical-right. The tensor that carries
%       the label leg has leg order left-physical-right-label; on input it
%       is expected to be the last tensor, M{end}. If given as empty (i.e.
%       []), this function initializes the MPS based on the feature
%       vectors of the training dataset.
% F : [rank-3 tensor] Collection of feature vectors for training dataset.
%       F(m,:,n) is the feature vector for the n-th site (i.e. pixel) and
%       the m-th data (i.e. image).
% y : [matrix] Collection of correct decision functions for training
%       dataset. y(m,n) is 1 if the m-th data (i.e. image) is labeled by
%       the n-th label, 0 otherwise.
% F_test : [rank-3 tensor] Collection of feature vectors for test dataset.
% y_test : [matrix] Collection of correct decision functions for test
%       dataset.
% Nkeep : [numeric] Maximum bond dimension for the MPS.
% estep : [vector] Prefactors to the gradient of B tensor, i.e. \eta
%       mentioned in between Eqs. (7) and (8) in Stoudenmire2016. The
%       length of estep defines the number of sweeps; there will be
%       2*numel(estep) sweeps, left and right. Each value, estep(m),
%       applies to the m-th pair of sweeps.
%
% < Output >
% M : [cell array] Result MPS optimized after sweeps. As the last sweep
%       runs from left to right, the label leg sits on M{end}.
% cfun : [matrix] Record of measured cost functions (see the second
%       paragraph of Sec. 4 of Stoudenmire2016) per dataset at each
%       iteration, for training dataset. Its n-th column corresponds to the
%       n-th sweep. The values are recorded following the linear indexing
%       order of MATLAB. That is, cfun(:) contains the values in the order
%       of occurrences. The matrix has size(F,3) rows, of which only the
%       first size(F,3)-1 (one per bond) are filled; the last row stays NaN.
%       Note that in Stoudenmire2016, the cost function is defined by
%       summing over the training dataset. But here, we use the *average*
%       cost function for each training dataset; that is, mean instead of
%       sum. By doing this, the value of the cost function is better
%       comparable when we use different size of datasets.
% err : [matrix] Record of misclassification error rates for predicting
%       labels at each iteration, for training dataset. 'err' follows the
%       same indexing as 'cfun'.
% cfun_test : [matrix] Record of measured cost functions for test dataset.
% err_test : [matrix] Record of error rate of predicting labels for test
%       dataset.
%
% Written by S.Lee (Jul.17,2019)
% Revised by S.Lee (Jul.08,2020): Implemented a stable way of initializing
%       MPS, and the normalization of feature vectors.

% try % uncomment for debugging

tobj = tic2;

% sanity check of input
if size(F,1) ~= size(y,1)
    error('ERR: # of training datasets in F and y are inconsistent.');
elseif size(F_test,1) ~= size(y_test,1)
    error('ERR: # of test datasets in F_test and y_test are inconsistent.');
elseif size(y_test,2) ~= size(y,2)
    error('ERR: # of labels for training and test datasets are inconsistent.');
elseif (size(F_test,2) ~= size(F,2)) || (size(F_test,3) ~= size(F,3))
    error('ERR: Feature vectors for training and test datasets have inconsistent sizes.');
elseif size(F,3) < 2
    error('ERR: At least two sites are needed for the two-site update.');
elseif ~isempty(M) && (numel(M) ~= size(F,3))
    error('ERR: Length of the MPS is inconsistent with the number of sites in F.');
end

% result matrices
cfun = nan(size(F,3),numel(estep)*2); % cost function for training data
err = nan(size(F,3),numel(estep)*2); % error rate for training data
cfun_test = nan(size(F,3),numel(estep)*2); % cost function for test data
err_test = nan(size(F,3),numel(estep)*2); % error rate for test data

fprintf('Machine learning using MPS\n');
fprintf('  Length = %i, # of training dataset = %.4g, # of test dataset = %.4g\n', ...
    size(F,3),size(F,1),size(F_test,1));
fprintf('  # of labels = %.4g, Nkeep = %.4g, %i x 2 sweeps\n', ...
    size(y,2),Nkeep,numel(estep));


if isempty(M)
% initialize MPS based on the training dataset, and update feature vectors
% in the effective basis
    disptime('Initialize MPS');
    M = cell(1,size(F,3)); % MPS tensors
end
% feature vectors in the effective space spanned by MPS
Flr = cell(1,numel(M)+2); % for training data
% Each cell element Flr{m} is a matrix. Each Flr{m}(n,:) corresponds to the
% feature vectors for the n-th dataset (i.e. n-th image).
% Flr{k+1} holds the environment that ends at site k: the contraction of
% sites 1..k when site k is to the left of the current two-site block, or
% of sites k..end when it is to the right. Flr{1} and Flr{end} stay empty
% and stand for the dummy legs at the two ends of the chain.

for itn = (1:numel(M))
    % contract Flr{itn} from the previous iteration and local feature
    % vectors
    T = reshape(F(:,:,itn),[size(F,1) 1 size(F,2)]);
    if ~isempty(Flr{itn})
        T = T.*Flr{itn};
    end
    % leg order of T: data-left-physical
    
    if isempty(M{itn})
        % [Unpublished; devised by S.Lee]
        % Initialize each tensor based on the training data set. We choose
        % the MPS tensors to better span the space that is spanned by the
        % feature vectors.
        
        if itn < numel(M)
            % all the tensors except for the last one have three legs, as
            % usual: left-physical-right
            V = ML_MPS_rightSV(T(:,:),Nkeep); % find the space spanned by feature vectors
            M{itn} = reshape(V,[size(T,2) size(T,3) size(V,2)]);

        else % itn == numel(M)
            % the last tensor has four legs: left-physical-right-label
            M{itn} = zeros(size(M{itn-1},3),size(F,2),1,size(y,2));
            for itl = (1:size(y,2))
                % keep only the data carrying the itl-th label, and take
                % the dominant direction of their feature vectors
                V = ML_MPS_rightSV(T(:,:).*y(:,itl),1);
                M{itn}(:,:,1,itl) = reshape(V,[size(T,2),size(T,3)]);
            end
        end
    end
    
    if itn < numel(M)
        % Flr{numel(M)+1} is not necessary, as it will be overwritten by a
        % new matrix without being used
        T = sum(sum(T.*reshape(M{itn},[1 size(M{itn})]),2),3);
        
        % [Unpublished; devised by S.Lee] normalize feature vectors
        Flr{itn+1} = ML_MPS_normalize(reshape(T,[size(T,1) size(T,4)]));
    end
end

% for test data
Flr_test = cell(1,numel(M)+2);

for itn = (1:(numel(M)-2)) % skip last two sites, since Flr for them will not be used
    Flr_test{itn+1} = ML_MPS_updateEnv(M{itn},Flr_test{itn},F_test(:,:,itn));
end

disptime('Start sweeping');

for its = (1:numel(estep))
    % left <- right
    for itn = (numel(M):-1:2)
        B = contract(M{itn-1},3,3,M{itn},4,1);
        % leg order of M{itn-1}: left-physical-right
        % leg order of M{itn}: left-physical-right-label
        % leg order of B: left-physical(itn-1)-physical(itn)-right-label
        
        % for test data, compute only cost function and error rate
        [cfun_test(end+1-itn,2*its-1),err_test(end+1-itn,2*its-1)] = ...
            ML_MPS_1step (B,Flr_test{itn-1},F_test(:,:,itn-1),F_test(:,:,itn),Flr_test{itn+2},y_test);

        % compute cost function, error rate, gradient of B
        [cfun(end+1-itn,2*its-1),err(end+1-itn,2*its-1),dB] = ...
            ML_MPS_1step (B,Flr{itn-1},F(:,:,itn-1),F(:,:,itn),Flr{itn+2},y);
        
        % update B tensor
        B = B + dB*estep(its);
        
        % SVD and update M{itn-1}, M{itn}
        [M{itn-1},S,M{itn}] = svdTr(B,5,[1 2 5],Nkeep,[]); % associate the label leg with M{itn-1}
        M{itn-1} = contract(M{itn-1},4,4,diag(S),2,1,[1 2 4 3]); % leg order: left-physical-right-label
        
        % update Flr{itn+1} and Flr_test{itn+1} in accordance with the
        % updated M{itn}; the legs are reversed to right-physical-left
        % since the environment grows from the right end
        Flr{itn+1} = ML_MPS_updateEnv(permute(M{itn},[3 2 1]),Flr{itn+2},F(:,:,itn));
        Flr_test{itn+1} = ML_MPS_updateEnv(permute(M{itn},[3 2 1]),Flr_test{itn+2},F_test(:,:,itn));
    end
    
    disptime(['Sweep #',sprintf('%02i/%02i',2*its-1,2*numel(estep)), ...
        ' | left <- right,  eta = ',sprintf('%.3g',estep(its))]);
    fprintf('  Training: cost fun = %.3e, error rate = %.2f%%\n', ...
        cfun(numel(M)-1,2*its-1),err(numel(M)-1,2*its-1)*100);
    fprintf('     Test : cost fun = %.3e, error rate = %.2f%%\n', ...
        cfun_test(numel(M)-1,2*its-1),err_test(numel(M)-1,2*its-1)*100);

    % left -> right
    for itn = (2:numel(M))
        B = contract(M{itn-1},4,3,M{itn},3,1,[1 2 4 5 3]);
        % leg order of M{itn-1}: left-physical-right-label
        % leg order of M{itn}: left-physical-right
        % leg order of B: left-physical(itn-1)-physical(itn)-right-label
        
        % for test data, compute only cost function and error rate
        [cfun_test(itn-1,2*its),err_test(itn-1,2*its)] = ...
            ML_MPS_1step (B,Flr_test{itn-1},F_test(:,:,itn-1),F_test(:,:,itn),Flr_test{itn+2},y_test);

        % compute cost function, error rate, gradient
        [cfun(itn-1,2*its),err(itn-1,2*its),dB] = ...
            ML_MPS_1step (B,Flr{itn-1},F(:,:,itn-1),F(:,:,itn),Flr{itn+2},y);

        % update B tensor
        B = B + dB*estep(its);
        
        % SVD and update M{itn-1}, M{itn}
        [M{itn-1},S,M{itn}] = svdTr(B,5,[1 2],Nkeep,[]); % associate the label leg with M{itn}
        M{itn} = contract(diag(S),2,2,M{itn},4,1); % leg order: left-physical-right-label
        
        % update Flr{itn} and Flr_test{itn} in accordance with the updated
        % M{itn-1} (leg order: left-physical-right)
        Flr{itn} = ML_MPS_updateEnv(M{itn-1},Flr{itn-1},F(:,:,itn-1));
        Flr_test{itn} = ML_MPS_updateEnv(M{itn-1},Flr_test{itn-1},F_test(:,:,itn-1));
    end
    
    disptime(['Sweep #',sprintf('%02i/%02i',2*its,2*numel(estep)), ...
        ' | left -> right,  eta = ',sprintf('%.3g',estep(its))]);
    fprintf('  Training: cost fun = %.3e, error rate = %.2f%%\n', ...
        cfun(numel(M)-1,2*its),err(numel(M)-1,2*its)*100);
    fprintf('     Test : cost fun = %.3e, error rate = %.2f%%\n', ...
        cfun_test(numel(M)-1,2*its),err_test(numel(M)-1,2*its)*100);
end

toc2(tobj,'-v');
chkmem;

% % % uncomment for debugging
% catch e
%     disp(getReport(e));
%     keyboard
% end

end

function Fenv = ML_MPS_updateEnv (A,Fenv,Floc)
% Extend the effective feature vectors by one site.
%
% < Input >
% A : [rank-3 tensor] MPS tensor whose 1st leg is contracted with Fenv, 2nd
%       leg with the local feature vectors Floc, and whose 3rd leg remains
%       open.
% Fenv : [matrix] Effective feature vectors before the update; Fenv(n,:)
%       is for the n-th data. If empty, the 1st leg of A is treated as a
%       dummy leg of dimension 1.
% Floc : [matrix] Local feature vectors; Floc(n,:) is for the n-th data.
%
% < Output >
% Fenv : [matrix] Updated effective feature vectors, with every row
%       normalized. Fenv(n,:) runs over the 3rd leg of A.

% insert one dimension (leg) to the front, which corresponds to data set
% indices; leg order of T: data-(leg 1 of A)-physical-(leg 3 of A)
T = reshape(A,[1 size(A,1) size(A,2) size(A,3)]);
if ~isempty(Fenv)
    T = sum(T.*Fenv,2);
end
T = reshape(T,[size(T,1) size(T,3) size(T,4)]); % data-physical-(leg 3 of A)
T = sum(T.*Floc,2);
T = reshape(T,[size(T,1) size(T,3)]); % data-(leg 3 of A)

% [Unpublished; devised by S.Lee] normalize feature vectors
Fenv = ML_MPS_normalize(T);

end


function [cfun,err,varargout] = ML_MPS_1step (B,F1,F2,F3,F4,y)
% < Description >
%
% [cfun,err [, dB] ] = ML_MPS_1step (B,F1,F2,F3,F4,y)
%
% Evaluate the decision function for a two-site tensor B, measure the cost
% function and the error rate, and (optionally) compute the gradient of B.
%
% < Input >
% B : [rank-5 tensor] The contraction of two local tensors of the MPS,
%       associated with the current orthogonality center. Its leg order is
%       left-physical(1)-physical(2)-right-label.
% F1, F2, F3, F4 : [matrices] Feature vector data; each row is for each
%       data. F2 and F3 are local feature vectors at two sites associated
%       with the current orthogonality center. F1 (F4) is the contraction
%       of feature vectors at the left (right) parts of chain with the MPS
%       tensors; F1 and F4 are the feature vectors in effective basis. If
%       F1 (F4) is empty, the left (right) leg of B must have dimension 1.
% y : [matrix] Collection of correct decision functions.
%
% < Output >
% cfun, err : [numeric] Cost function per dataset and error rate of
%       predicting labels, respectively. 
% dB : (Optional) [rank-5 tensor] Gradient for the B tensor. It is \Delta B
%       in Eq. (7) of Stoudenmire2016, divided by the number of data. It
%       has the same leg order as B.
%

% leg dimensions of B; queried one by one since size(B) drops trailing
% singleton dimensions
szB = [size(B,1) size(B,2) size(B,3) size(B,4) size(B,5)];

% % evaluate decision function f^l (x) in Eq. (6) of Stoudenmire2016
% Goal: generate matrix fx, such that fx(n,l) means f^l (x_n)

% contract the left leg; the first leg of fx becomes the data set index
if isempty(F1)
    fx = reshape(B,[1 szB(2:5)]); % dummy left leg; expanded over data below
else
    % matrix product avoids building a data-left-physical-physical-right-label array
    fx = F1*reshape(B,[szB(1) prod(szB(2:5))]);
    fx = reshape(fx,[size(F1,1) szB(2:5)]);
end
% leg order of fx: data-physical(1)-physical(2)-right-label

fx = sum(fx.*F2,2); % contract physical(1)
fx = reshape(fx,[size(fx,1) szB(3:5)]); % data-physical(2)-right-label
fx = sum(fx.*F3,2); % contract physical(2)
fx = reshape(fx,[size(fx,1) szB(4:5)]); % data-right-label
if ~isempty(F4)
    fx = sum(fx.*F4,2); % contract right
end
fx = reshape(fx,[size(fx,1) szB(5)]); % data-label

% predicted label = the label with the largest decision function; pick
% y(n,maxid(n)) via linear indexing to count correct predictions
[~,maxid] = max(fx,[],2);
err = 1 - sum(y((1:size(fx,1)).'+(maxid-1)*size(fx,1)))/size(fx,1);

% % deviation of the decision function from the correct value;
% % y_n^l - f^l (x_n) in Fig. 6(d) of Stoudenmire2016
ydiff = y - fx; 

% % cost function per data set
cfun = sum(abs(ydiff(:)).^2)/2/size(y,1);

if nargout > 2
    % % compute the gradient \Delta B in Eq. (7) of Stoudenmire2016
    ndata = size(ydiff,1);
    
    % outer product of feature vectors for each data;
    % leg order of Phi: data-left-physical(1)-physical(2)-right
    Phi = reshape(F2,[ndata 1 szB(2)]).*reshape(F3,[ndata 1 1 szB(3)]);
    if ~isempty(F1)
        Phi = Phi.*F1;
    end
    if ~isempty(F4)
        Phi = Phi.*reshape(F4,[ndata 1 1 1 szB(4)]);
    end
    Phi = reshape(Phi,[ndata prod(szB(1:4))]);

    % [Unpublished; devised by S.Lee]
    % In Eq. (7) of Stoudenmire2016, \Delta B is obtained by the sum over
    % training dataset. But here we use the mean, i.e. dividing by the size
    % of the training dataset. By doing so, we can use similar choices of
    % \eta's (estep) for different training dataset sizes.
    % The sum over data is done as a matrix product, with a non-conjugating
    % transpose (the data are presumably real-valued).
    dB = (Phi.'*ydiff)/ndata;
    
    varargout{1} = reshape(dB,szB);
    
end

end

function V = ML_MPS_rightSV (M,Nkeep)
% [Unpublished; devised by S.Lee]
% Return the right singular vectors of a matrix M that belong to its
% largest singular values, as the columns of V. At most Nkeep vectors are
% returned, ordered by descending singular value.

% Only right singular vectors are wanted and M typically has far more rows
% than columns, so diagonalize the (column-sized) Gram matrix instead of
% running an SVD on M itself.
gram = M'*M;

% adding the Hermitian conjugate makes the input to eig exactly Hermitian;
% it doubles the eigenvalues but leaves their order and the eigenvectors
[evec,eval] = eig(gram+gram');

% rank the eigenvectors from the largest eigenvalue down, then truncate
[~,order] = sort(diag(eval),'descend');
order((Nkeep+1):end) = [];

V = evec(:,order);

end

function F = ML_MPS_normalize(F)
% [Unpublished; devised by S.Lee]
% Rescale every row of matrix F to unit norm. Within the function ML_MPS,
% each row holds the feature vector of one data set.

% Euclidean norm of each row, as a column vector
rownorm = sqrt(sum(abs(F).^2,2));
F = F./rownorm;

% a vanishing row norm yields Inf or NaN entries (divide-by-zero); replace
% every non-finite entry by zero
F(~isfinite(F)) = 0;

end
