#define CATCH_CONFIG_MAIN
#include <inc/catch.hpp>
#include <ed/lattice.h>
#include <la/la_instantiations.h>

TEST_CASE("Test gen_tensor_product_operator")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    Lattice::Base<Type> lattice(2, 2);

    // Single-site Pauli X (bit flip) to be embedded at site 0.
    const Matrix pauli_x(2, 2, {0.0, 1.0, 1.0, 0.0});

    // Expected full-space operator: the Kronecker product X (x) I, i.e. X acts
    // on site 0 (the most significant position) and identity on site 1.
    const Matrix expected(4, 4, {0.0, 0.0, 1.0, 0.0,
                                 0.0, 0.0, 0.0, 1.0,
                                 1.0, 0.0, 0.0, 0.0,
                                 0.0, 1.0, 0.0, 0.0});

    // gen_tensor_product_operator writes its result into the output argument.
    Matrix op;
    lattice.gen_tensor_product_operator(0, pauli_x, op);

    REQUIRE(op == expected);
}

TEST_CASE("Test gen_tensor_product_state")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    Lattice::Base<Type> lattice(2, 2);

    // Expected 4x1 column vectors, one per two-site configuration {a, b}.
    // The test expects configuration {a, b} to land on index 2*a + b,
    // i.e. site 0 is the most significant position.
    const Matrix ket00(4, 1, {1.0, 0.0, 0.0, 0.0});
    const Matrix ket01(4, 1, {0.0, 1.0, 0.0, 0.0});
    const Matrix ket10(4, 1, {0.0, 0.0, 1.0, 0.0});
    const Matrix ket11(4, 1, {0.0, 0.0, 0.0, 1.0});

    // One output object is reused for every configuration, so each call must
    // fully overwrite whatever the previous call left in it.
    Matrix state;

    lattice.gen_tensor_product_state({0, 0}, state);
    REQUIRE(state == ket00);

    lattice.gen_tensor_product_state({0, 1}, state);
    REQUIRE(state == ket01);

    lattice.gen_tensor_product_state({1, 0}, state);
    REQUIRE(state == ket10);

    lattice.gen_tensor_product_state({1, 1}, state);
    REQUIRE(state == ket11);
}

TEST_CASE("Test apply")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    Lattice::Base<Type> lattice(2, 2);

    // Full-space operator for a Pauli X (bit flip) acting on site 0.
    const Matrix pauli_x(2, 2, {0.0, 1.0, 1.0, 0.0});
    Matrix op;
    lattice.gen_tensor_product_operator(0, pauli_x, op);

    // Input product state |0> (x) |0>.
    Matrix state;
    lattice.gen_tensor_product_state({0, 0}, state);

    // Flipping site 0 of |00> yields |10>, which is basis index 2.
    const Matrix expected(4, 1, {0.0, 0.0, 1.0, 0.0});

    REQUIRE(lattice.apply(op, state) == expected);
}

TEST_CASE("Test scalar_product")
{
    using Type = double;
    using Matrix = la_objects::LAMatrix<Type>;

    Lattice::Base<Type> lattice(2, 2);

    // Full-space operator for a Pauli X (bit flip) acting on site 0.
    Matrix op;
    lattice.gen_tensor_product_operator(0, Matrix(2, 2, {0.0, 1.0, 1.0, 0.0}), op);

    // |0> (x) |0>
    Matrix state00;
    lattice.gen_tensor_product_state({0, 0}, state00);

    // |1> (x) |0>
    Matrix state10;
    lattice.gen_tensor_product_state({1, 0}, state10);

    // X on site 0 connects |00> and |10>, so this matrix element is 1.
    // Compare with Approx instead of exact floating-point ==, so the test does
    // not hinge on the accumulation order used inside scalar_product.
    REQUIRE(lattice.scalar_product(state10, op, state00) == Approx(1.0));

    // X flips site 0, so the operator has no diagonal weight on |00>.
    REQUIRE(lattice.scalar_product(state00, op, state00) == Approx(0.0));
}