%% Machine learning with MPS
% Author: <https://www.theorie.physik.uni-muenchen.de/lsvondelft/members/sci_mem/seung_sup-lee/index.html 
% Seung-Sup Lee>
%% 
% Here we demonstrate how MPS techniques can be used in machine learning, 
% with an application to handwriting recognition. We implement the algorithm 
% introduced in Stoudenmire2016 [<https://papers.nips.cc/paper/6211-supervised-learning-with-tensor-networks 
% E. M. Stoudenmire and D. J. Schwab, Advances in Neural Information Processing 
% Systems *29*, 4799 (2016)> or <https://arxiv.org/abs/1605.05775 its arXiv version>]. 
% (The two versions of the paper differ in minor ways; here we follow the notation 
% of the published NeurIPS version.) The goal of this algorithm is to construct 
% an MPS such that the contraction of this "weight" MPS with the feature vectors 
% of an input image measures how close the input is to the data pattern of each label.
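% 
% As a rough illustration of such a contraction, the sketch below evaluates a 
% toy decision function for a 4-pixel "image". Everything in it is hypothetical 
% and chosen only for illustration: the variable names (|x_demo|, |W_demo|, ...), 
% the index order (left bond, physical, right bond), the placement of the label 
% index on the right bond of the last tensor, and the filler tensor entries. It 
% does not necessarily match the conventions used in |ML_MPS_Ex.m|.

x_demo = [0 85 170 255]; % toy gray-scale values, one per site
phi_demo = [cos(x_demo*(pi/2/255)); sin(x_demo*(pi/2/255))];
% phi_demo(:,n) is the 2-dimensional feature vector at site n
dims_demo = [1 2 2 2 3]; % bond dimensions; the last entry is the number of labels
W_demo = cell(1,numel(x_demo));
for itn = (1:numel(x_demo))
    sz_demo = [dims_demo(itn), 2, dims_demo(itn+1)];
    % deterministic filler values (hypothetical; a trained MPS would go here)
    W_demo{itn} = reshape(sin(1:prod(sz_demo)),sz_demo);
end

f_demo = 1; % running row vector; starts as the trivial 1-by-1 left boundary
for itn = (1:numel(x_demo))
    % contract the physical leg with the feature vector, which leaves a
    % matrix acting on the bond indices
    T_demo = zeros(dims_demo(itn),dims_demo(itn+1));
    for its = (1:2)
        T_demo = T_demo + phi_demo(its,itn) * ...
            reshape(W_demo{itn}(:,its,:),[dims_demo(itn) dims_demo(itn+1)]);
    end
    f_demo = f_demo*T_demo; % absorb the site into the running vector
end
disp(f_demo); % 1-by-3 row vector: one decision function value per label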
%% MNIST data
% First, we load the MNIST data of handwritten digits from the .csv files zipped 
% together with this document. In these .csv files, each row corresponds to one handwritten 
% digit. The first column contains the correct label. The remaining columns, with 
% column indices |(1:(28^2))+1|, contain the gray-scale values (from black 0 
% to white 255) of the 28 $\times$ 28 image pixels. For each row, the |(2:(28^2+1))| 
% elements are the concatenation of the image rows, i.e., |[(the_1st_row), (the_2nd_row), 
% ...]|.
% 
% For a quicker demonstration, we use only a subset for training (i.e., optimizing 
% the MPS) and a smaller subset for testing the performance of the trained MPS. (Using 
% the whole dataset would give better classification results, at a higher 
% computational cost.)

clear

Ntrain = 2000; % Number of training images
Ntest = 100; % Number of test images for verification

data_train = csvread('MNIST_train.csv', 0, 0, [0 0 (Ntrain-1) (28*28)]);
data_test  = csvread('MNIST_test.csv',  0, 0, [0 0 (Ntest-1)  (28*28)]);
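%% 
% The row layout described above can be checked on a toy example. The sketch 
% below uses a made-up 2 $\times$ 3 "image" with a made-up label 7; the names 
% |img_demo|, |row_demo|, and |img_back_demo| are introduced here only for illustration. 
% Since MATLAB stores arrays in column-major order, a row of the .csv file is 
% turned back into an image by reshaping and then transposing.

img_demo = [1 2 3; 4 5 6]; % toy image: 2 pixel rows, 3 pixel columns
% one .csv row: [label, the 1st image row, the 2nd image row]
row_demo = [7, reshape(img_demo.',1,[])];
disp(row_demo); % 7 1 2 3 4 5 6

% recover the image: the reshape fills columns first, so transpose afterwards
img_back_demo = reshape(row_demo(2:end),[3 2]).';
disp(isequal(img_back_demo,img_demo)); % displays 1 (true)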
%% 
% The numbers are from 0 to 9.

labels = unique(data_train(:,1));
disp(labels.');
%% 
% Let's visualize the data.

figure;
imagesc(cell2mat(squeeze(num2cell( ...
    permute(reshape(data_train(1:25,2:end),[5 5 28 28]),[4 3 1 2]), ...
    [1 2]))));
% permute the pixel row/column dimensions for visualization, since MATLAB
% is a column-major language
colormap(gray);
title('Training data (Original)','FontSize',14);
%% 
% These gray-scale images are labeled as:

disp(reshape(data_train(1:25,1),[5 5]))
%% 
% In Stoudenmire2016, the images are down-scaled by a factor of 2, i.e., each 
% 2 $\times$ 2 block of pixels is averaged into one pixel. After down-scaling, 
% we can use a shorter MPS of length $14^2 = 196$.

data_train = [data_train(:,1), ...
    reshape(mean(mean( ...
    reshape(data_train(:,(2:end)),[size(data_train,1) 2 14 2 14]), ...
    2),4),[size(data_train,1) 14^2])];
data_test = [data_test(:,1), ...
    reshape(mean(mean( ...
    reshape(data_test(:,(2:end)),[size(data_test,1) 2 14 2 14]), ...
    2),4),[size(data_test,1) 14^2])];
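%% 
% To see that this reshape-and-average pattern indeed averages 2 $\times$ 2 
% pixel blocks, here is a small sanity check on a made-up 4 $\times$ 4 image, 
% down-scaled to 2 $\times$ 2. The names ending in |_demo| are introduced only 
% for this illustration.

img4_demo = reshape(1:16,[4 4]).'; % toy image: [1 2 3 4; 5 6 7 8; ...]
row4_demo = reshape(img4_demo.',1,[]); % concatenation of image rows, as in the .csv files

% same pattern as above, with 14 replaced by 2 and a single image
small_demo = reshape(mean(mean( ...
    reshape(row4_demo,[1 2 2 2 2]), ...
    2),4),[1 2^2]);

% direct block averages, listed in the same "concatenation of rows" order
ref_demo = [mean(mean(img4_demo(1:2,1:2))), mean(mean(img4_demo(1:2,3:4))), ...
    mean(mean(img4_demo(3:4,1:2))), mean(mean(img4_demo(3:4,3:4)))];
disp([small_demo; ref_demo]); % both rows: 3.5 5.5 11.5 13.5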
%% 
% The down-scaled images look like:

figure;
imagesc(cell2mat(squeeze(num2cell( ...
    permute(reshape(data_train(1:25,2:end),[5 5 14 14]),[4 3 1 2]), ...
    [1 2]))));
% permute the pixel row/column dimensions for visualization, since MATLAB
% is a column-major language
colormap(gray);
title('Training data (Down-scaled)','FontSize',14);
%% Generate feature vectors and correct decision function
% Next, each gray-scale pixel is individually mapped onto a two-dimensional vector, 
% similar to a spin-1/2 spinor. In the machine learning context, 
% such vectors are called feature vectors.
% 
% We define the mapping from a gray-scale value $\in [0, 255]$ to a two-dimensional 
% vector so that a completely white pixel is mapped to [0 1] and a 
% completely black pixel to [1 0]. We use Eq. (3) of Stoudenmire2016, which provides 
% a one-to-one correspondence between a pixel value and a vector.

dtmp = data_train(:,2:end)*(pi/2/255);
F_train = permute(cat(3,cos(dtmp),sin(dtmp)),[1 3 2]);
% F_train(m,:,n) is the 2-dimensional feature vector for the n-th pixel (=
% site) and the m-th image

% similarly for test data
dtmp = data_test(:,2:end)*(pi/2/255);
F_test = permute(cat(3,cos(dtmp),sin(dtmp)),[1 3 2]);
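%% 
% As a quick check of this mapping, the sketch below evaluates it for three 
% sample pixel values (black, mid-gray, white). The names |xpix_demo| and |phipix_demo| 
% are introduced only for this illustration. Each feature vector has unit norm, 
% since $\cos^2 + \sin^2 = 1$.

xpix_demo = [0; 127.5; 255]; % black, mid-gray, white
phipix_demo = [cos(xpix_demo*(pi/2/255)), sin(xpix_demo*(pi/2/255))];
disp(phipix_demo);
% rows, up to rounding: [1 0], [0.7071 0.7071], [0 1]
disp(sum(phipix_demo.^2,2).'); % squared norms: all equal to 1 up to rounding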
%% 
% We also construct the matrix of the correct decision function values.

% for training data
y_train = zeros(size(data_train,1),numel(labels));
for itl = (1:numel(labels))
    y_train(data_train(:,1) == labels(itl),itl) = 1;
end
% y_train(m,n) is 1 if the m-th data (i.e. image) is labeled by the n-th 
% label, 0 otherwise.

% similarly for test data
y_test = zeros(size(data_test,1),numel(labels));
for itl = (1:numel(labels))
    y_test(data_test(:,1) == labels(itl),itl) = 1;
end
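%% 
% The same kind of matrix can also be built without an explicit loop. The sketch 
% below compares both constructions on four made-up labels; the names ending 
% in |_demo| are introduced only for this illustration. The vectorized line relies 
% on implicit expansion, which is assumed to be available (MATLAB R2016b or later); 
% on older versions, |bsxfun(@eq,lab_demo,cls_demo.')| is an alternative.

lab_demo = [3; 0; 9; 3]; % toy labels of four images
cls_demo = unique(lab_demo); % sorted distinct labels: [0; 3; 9]

% loop version, as above
y_loop_demo = zeros(numel(lab_demo),numel(cls_demo));
for it_demo = (1:numel(cls_demo))
    y_loop_demo(lab_demo == cls_demo(it_demo),it_demo) = 1;
end

% vectorized version: compare a column of labels against a row of classes
y_vec_demo = double(lab_demo == cls_demo.');
disp(y_vec_demo); % rows: [0 1 0], [1 0 0], [0 0 1], [0 1 0]
disp(isequal(y_loop_demo,y_vec_demo)); % displays 1 (true)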
%% Exercise: Complete |ML_MPS_Ex.m|
% We now arrive at the last exercise in this lecture course (_Hurray!_). 
% There is a function |ML_MPS_Ex.m| zipped together with this document. *Complete 
% the parts which are enclosed by the comments |TODO - Exercise (a)|*, following 
% the description given in Sec. 4 of Stoudenmire2016. Once you complete the function, 
% you can follow the demonstration below.
% 
% Note that there are some important technical details that are not discussed 
% in Stoudenmire2016 but were devised by Seung-Sup to achieve stability and performance. 
% Such parts in the code are marked by the comment |Unpublished; devised by S.Lee|.
%% Machine learning of recognizing handwritten digits
% The prefactor |estep| ($\eta$ in Stoudenmire2016) to the gradient $\Delta 
% B$ is a parameter which can affect the convergence of the algorithm. The choice 
% of the parameter may depend on other parameters such as |Ntrain|, |Nkeep|, etc.

Nkeep = 20;
estep = [0.1 0.3 1 3 10];
[M,cfun,err,cfun_test,err_test] = ...
    ML_MPS_Ex([],F_train,y_train,F_test,y_test,Nkeep,estep);
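%% 
% To get a feeling for how a step size affects convergence, consider plain gradient 
% descent on a one-dimensional toy cost $C(b) = (b-1)^2/2$, with the update $b 
% \to b + \eta \Delta b$ and $\Delta b = -dC/db$. This is a generic illustration 
% only, not the MPS update implemented in |ML_MPS_Ex.m|; the names ending in |_demo| 
% and the step sizes used here are made up for this sketch.

eta_demo = [0.1 1 2.5]; % small, well-matched, and too large step sizes
b_demo = zeros(numel(eta_demo),11); % every run starts from b = 0
for ite = (1:numel(eta_demo))
    for itk = (1:10)
        db_demo = -(b_demo(ite,itk) - 1); % minus the gradient of the toy cost
        b_demo(ite,itk+1) = b_demo(ite,itk) + eta_demo(ite)*db_demo;
    end
end
% distance from the minimum b = 1 after 10 steps shrinks by |1 - eta| per step:
% slow convergence for 0.1, exact after one step for 1, divergence for 2.5
disp(abs(b_demo(:,end) - 1).');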
%% 
% Plot how the cost function per image and the error rate of predicting the correct 
% labels change over the sweeps.

figure;
plot((1:numel(cfun)).'/numel(M), ...
    [cfun(:),cfun_test(:),err(:),err_test(:)], ...
    'LineWidth',1,'LineStyle','-');
set(gca,'LineWidth',1,'FontSize',13,'YScale','linear');
legend({'Cost function (training)','Cost function (test)', ...
    'Error rate (training)','Error rate (test)'},'Location','best');
xlabel('# of sweeps')
grid on;
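%% 
% For orientation, the sketch below shows one way such quantities could be computed 
% from a matrix of decision function values. It assumes a quadratic cost, $\frac{1}{2} 
% \sum (f - y)^2$ divided by the number of images, and takes the label with the 
% largest decision function value as the prediction. Both are assumptions made 
% for this illustration; the normalization and conventions inside |ML_MPS_Ex.m| 
% may differ. The numbers and the names ending in |_demo| are made up.

% decision function values for 3 images (rows) and 3 labels (columns)
fval_demo = [0.9 0.1 0.0; 0.2 0.3 0.5; 0.1 0.8 0.1];
ytrue_demo = [1 0 0; 0 1 0; 0 1 0]; % correct decision function values

% quadratic cost per image
cost_demo = sum(sum((fval_demo - ytrue_demo).^2))/2/size(ytrue_demo,1);

% error rate: fraction of images whose predicted label differs from the correct one
[~,ipred_demo] = max(fval_demo,[],2);
[~,itrue_demo] = max(ytrue_demo,[],2);
err_demo = mean(ipred_demo ~= itrue_demo);

fprintf('cost per image = %.4f, error rate = %.4f\n',cost_demo,err_demo);
% the second image is misclassified, so the error rate is 1/3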
%% 
% The misclassification error rate can be further reduced by tuning |estep|, 
% increasing |Nkeep|, and training on more data.