#define CATCH_CONFIG_MAIN
#include <inc/catch.hpp>

#include <complex>

#include <la/la_operations.h>
#include <la/la_instantiations.h>

// Single-precision product check: the 2x2 exchange matrix [[0, 1], [1, 0]]
// multiplied by itself must give the 2x2 identity.
TEST_CASE("Test float matrix-matrix multiplication")
{
    using Type = float;
    using Matrix = la_objects::LAMatrix<Type>;

    // The double literals below are all exactly representable as float.
    Matrix swap_a(2, 2, {0.0, 1.0, 1.0, 0.0});
    Matrix swap_b(2, 2, {0.0, 1.0, 1.0, 0.0});

    // The result is default-constructed and then assigned, so assignment from
    // the product is exercised alongside operator*.
    Matrix result;
    result = swap_a * swap_b;

    const Matrix identity(2, 2, {1.0, 0.0, 0.0, 1.0});
    REQUIRE(result == identity);
}

// Double-precision product check: the 2x2 exchange matrix [[0, 1], [1, 0]]
// squared must be the 2x2 identity.
TEST_CASE("Test double matrix-matrix multiplication")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    Matrix lhs(2, 2, {0.0, 1.0, 1.0, 0.0});
    Matrix rhs(2, 2, {0.0, 1.0, 1.0, 0.0});

    // Default-construct, then assign the product (covers assignment too).
    Matrix product;
    product = lhs * rhs;

    const Matrix identity(2, 2, {1.0, 0.0, 0.0, 1.0});
    REQUIRE(product == identity);
}

// Checks the combined expression D = A*B + 2*C. This has the form of a BLAS
// gemm update (alpha = 1, beta = 2); whether the library actually maps it to a
// single gemm call is not visible from this file.
TEST_CASE("Test gemm")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    // All three operands are the exchange matrix [[0, 1], [1, 0]].
    Matrix A(2, 2, {0.0, 1.0, 1.0, 0.0});
    Matrix B(2, 2, {0.0, 1.0, 1.0, 0.0});
    Matrix C(2, 2, {0.0, 1.0, 1.0, 0.0});

    Matrix D;
    D = (A * B) + (Type(2.0) * C);

    // A*B is the identity and 2*C is [[0, 2], [2, 0]], so D = [[1, 2], [2, 1]].
    const Matrix expected(2, 2, {1.0, 2.0, 2.0, 1.0});
    REQUIRE(D == expected);
}

// Full Hermitian eigendecomposition of a 2x2 complex matrix.
// A has off-diagonal entries i and -i (complex conjugates of each other) and a
// zero diagonal, so it is Hermitian with eigenvalues -1 and +1.
// Checks that:
//   * the eigenvalues come back as a real 2x1 column, in the order {-1, +1};
//   * the eigenvector matrix V satisfies V^H * V == I (orthonormal columns).
// Requires <complex> for std::complex.
TEST_CASE("Test matrix diagonalization")
{
    using Type = std::complex<double>;
    // Real scalar type associated with Type (eigenvalues of a Hermitian matrix are real).
    using BaseType = la_objects::LAMatrix<Type>::BaseType;

    la_objects::LAMatrix<Type> A(2, 2, {0.0, Type(0.0, 1.0), Type(0.0, -1.0), 0.0});
    la_objects::LAMatrix<Type> V;
    la_objects::LAMatrix<BaseType> D;

    // Output arguments: D receives the eigenvalues, V the eigenvectors.
    la_operations::evd(A, D, V);

    // NOTE(review): exact floating-point equality on a decomposition result is
    // only safe if LAMatrix::operator== is tolerance-based or the backend
    // returns exactly +/-1 here - confirm before relying on this on new platforms.
    REQUIRE(D == la_objects::LAMatrix<BaseType>(2, 1, {-1.0, 1.0}));

    la_objects::LAMatrix<Type> B;
    B = la_operations::adjoint(V) * V;
    REQUIRE(B == la_objects::LAMatrix<Type>(2, 2, {1.0, 0.0, 0.0, 1.0}));
}

// Partial Hermitian eigendecomposition of the same 2x2 matrix as the full test:
// off-diagonal entries i and -i, zero diagonal, eigenvalues -1 and +1.
// The trailing argument 1 presumably requests a single eigenpair; judging by
// the expected D = {-1}, that is the lowest one. NOTE(review): confirm
// against the la_operations::evd declaration.
// Checks that D is 1x1 and that the single returned eigenvector has unit norm
// (V^H * V is the 1x1 identity). Requires <complex> for std::complex.
TEST_CASE("Test reduced matrix diagonalization")
{
    using Type = std::complex<double>;
    // Real scalar type associated with Type (eigenvalues of a Hermitian matrix are real).
    using BaseType = la_objects::LAMatrix<Type>::BaseType;

    la_objects::LAMatrix<Type> A(2, 2, {0.0, Type(0.0, 1.0), Type(0.0, -1.0), 0.0});
    la_objects::LAMatrix<Type> V;
    la_objects::LAMatrix<BaseType> D;

    la_operations::evd(A, D, V, 1);

    // NOTE(review): exact floating-point comparison; see the full
    // diagonalization test for the same caveat.
    REQUIRE(D == la_objects::LAMatrix<BaseType>(1, 1, {-1.0}));

    la_objects::LAMatrix<Type> B;
    B = la_operations::adjoint(V) * V;
    REQUIRE(B == la_objects::LAMatrix<Type>(1, 1, {1.0}));
}

// Kronecker product of two 2x2 matrices: the 4x4 result consists of 2x2 blocks,
// where block (i, j) equals A(i, j) * B.
TEST_CASE("Test kronecker product")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    // Element lists are column-major: {0, 2, 1, 0} is the matrix |0 1|
    //                                                              |2 0|
    Matrix A(2, 2, {0.0, 2.0, 1.0, 0.0});
    Matrix B(2, 2, {2.0, 0.0, 0.0, 1.0});
    Matrix C;

    // |0 1| (x) |2 0|  =  |0 0 2 0|
    // |2 0|     |0 1|     |0 0 0 1|
    //                     |4 0 0 0|
    //                     |0 2 0 0|
    la_operations::kronecker(A, B, C);

    // The expected matrix from the diagram above, written one column per line.
    const Matrix expected(4, 4, {0.0, 0.0, 4.0, 0.0,
                                 0.0, 0.0, 0.0, 2.0,
                                 2.0, 0.0, 0.0, 0.0,
                                 0.0, 1.0, 0.0, 0.0});
    REQUIRE(C == expected);
}