#define CATCH_CONFIG_MAIN
#include <inc/catch.hpp>
#include <inc/la_tools.h>

TEST_CASE("Test commutation relations")
{
    // Checks the algebra of the single-site operators generated by
    // Lattice::Spin on a 3-site spin-1/2 lattice, using sites 0 and 1:
    //  * operators acting on different sites commute;
    //  * on the same site, the commutators carry a factor 2i
    //    (e.g. [sx, sy] = 2i sz), i.e. the Pauli-matrix normalization;
    //  * consequently each operator squares to the identity on its site.
    // Each commutator is formed as `b*a + (-1)*buf` with `buf = a*b`, so the
    // quantity checked is [b, a]. Exact `==` is used on the assumption that
    // all matrix entries involved are exactly representable (0, +-1, +-i).
    typedef typename Lattice::Spin::MatrixType MatrixType;
    typedef MatrixType::NumberType T;

    Lattice::Spin lattice(0.5, 3);

    MatrixType sx1, sy1, sz1, sx2, sy2, sz2, id, zero(lattice.get_dimension(), lattice.get_dimension());
    lattice.gen_sx(0, sx1);
    lattice.gen_sy(0, sy1);
    lattice.gen_sz(0, sz1);
    lattice.gen_sx(1, sx2);
    lattice.gen_sy(1, sy2);
    lattice.gen_sz(1, sz2);
    lattice.gen_identity(id);

    MatrixType buf;

    // Different sites commute: sx on site 0 against every operator on site 1.
    buf = sx1 * sx2;
    REQUIRE(MatrixType(sx2*sx1 + T(-1.0)*buf) == zero);
    buf = sx1 * sy2;
    REQUIRE(MatrixType(sy2*sx1 + T(-1.0)*buf) == zero);
    buf = sx1 * sz2;
    REQUIRE(MatrixType(sz2*sx1 + T(-1.0)*buf) == zero);

    // Different sites commute: sy on site 0 against every operator on site 1.
    buf = sy1 * sx2;
    REQUIRE(MatrixType(sx2*sy1 + T(-1.0)*buf) == zero);
    buf = sy1 * sy2;
    REQUIRE(MatrixType(sy2*sy1 + T(-1.0)*buf) == zero);
    buf = sy1 * sz2;
    REQUIRE(MatrixType(sz2*sy1 + T(-1.0)*buf) == zero);

    // Different sites commute: sz on site 0 against every operator on site 1.
    buf = sz1 * sx2;
    REQUIRE(MatrixType(sx2*sz1 + T(-1.0)*buf) == zero);
    buf = sz1 * sy2;
    REQUIRE(MatrixType(sy2*sz1 + T(-1.0)*buf) == zero);
    buf = sz1 * sz2;
    REQUIRE(MatrixType(sz2*sz1 + T(-1.0)*buf) == zero);

    // Same site (0), commutators with sx.
    buf = sx1 * sx1;
    REQUIRE(MatrixType(sx1*sx1 + T(-1.0)*buf) == zero);                        // [sx, sx] = 0
    buf = sx1 * sy1;
    REQUIRE(MatrixType(sy1*sx1 + T(-1.0)*buf) == MatrixType(T(0.0, -2.0)*sz1)); // [sy, sx] = -2i sz
    buf = sx1 * sz1;
    REQUIRE(MatrixType(sz1*sx1 + T(-1.0)*buf) == MatrixType(T(0.0, 2.0)*sy1));  // [sz, sx] = 2i sy

    // Same site (0), commutators with sy.
    buf = sy1 * sx1;
    REQUIRE(MatrixType(sx1*sy1 + T(-1.0)*buf) == MatrixType(T(0.0, 2.0)*sz1));  // [sx, sy] = 2i sz
    buf = sy1 * sy1;
    REQUIRE(MatrixType(sy1*sy1 + T(-1.0)*buf) == zero);                        // [sy, sy] = 0
    buf = sy1 * sz1;
    REQUIRE(MatrixType(sz1*sy1 + T(-1.0)*buf) == MatrixType(T(0.0, -2.0)*sx1)); // [sz, sy] = -2i sx

    // Same site (0), commutators with sz.
    buf = sz1 * sx1;
    REQUIRE(MatrixType(sx1*sz1 + T(-1.0)*buf) == MatrixType(T(0.0, -2.0)*sy1)); // [sx, sz] = -2i sy
    buf = sz1 * sy1;
    REQUIRE(MatrixType(sy1*sz1 + T(-1.0)*buf) == MatrixType(T(0.0, 2.0)*sx1));  // [sy, sz] = 2i sx
    buf = sz1 * sz1;
    REQUIRE(MatrixType(sz1*sz1 + T(-1.0)*buf) == zero);                        // [sz, sz] = 0

    // With the 2i normalization above, the operators are Pauli matrices on
    // their site, so each one squares to the identity of the full space.
    buf = sx1 * sx1;
    REQUIRE(buf == id);
    buf = sy1 * sy1;
    REQUIRE(buf == id);
    buf = sz1 * sz1;
    REQUIRE(buf == id);
    buf = sx2 * sx2;
    REQUIRE(buf == id);
    buf = sy2 * sy2;
    REQUIRE(buf == id);
    buf = sz2 * sz2;
    REQUIRE(buf == id);
}

TEST_CASE("Test gen_polarized_state")
{
    // Expectation values <psi|O|psi> of the single-site operators on sites 0
    // and 1 (written X[1], X[2], ... below) of a 3-site spin-1/2 lattice, for
    // the states produced by gen_polarized_state along each axis. The operator
    // along the polarization axis is expected to give 1 on both sites, the
    // other two operators 0, and every expectation value must be real.
    typedef typename Lattice::Spin::MatrixType MatrixType;

    Lattice::Spin lattice(0.5, 3);

    MatrixType sx1, sy1, sz1, sx2, sy2, sz2;
    lattice.gen_sx(0, sx1);
    lattice.gen_sy(0, sy1);
    lattice.gen_sz(0, sz1);
    lattice.gen_sx(1, sx2);
    lattice.gen_sy(1, sy2);
    lattice.gen_sz(1, sz2);

    // Tolerance shared by every comparison below.
    const double tol = 1e-15;

    // Requires <state|op|state> == expected + 0i within `tol`; `label` is
    // reported on failure to identify the offending expectation value.
    // Approx's epsilon is relative; `.scale(1.0)` keeps the check meaningful
    // when the expected value is 0 (with Catch2's default scale of 0 an
    // epsilon-only comparison against 0 would demand an exact match).
    auto require_expectation = [&](const char* label, MatrixType& state, MatrixType& op, double expected)
    {
        INFO(label);
        const auto value = lattice.scalar_product(state, op, state);
        REQUIRE(value.real() == Approx(expected).epsilon(tol).scale(1.0));
        REQUIRE(value.imag() == Approx(0.0).epsilon(tol).scale(1.0));
    };

    MatrixType zx_pol;
    lattice.gen_polarized_state(Lattice::SpinBasis::X, zx_pol);
    require_expectation("<x=1,1|X[1]|x=1,1>", zx_pol, sx1, 1.0);
    require_expectation("<x=1,1|X[2]|x=1,1>", zx_pol, sx2, 1.0);
    require_expectation("<x=1,1|Y[1]|x=1,1>", zx_pol, sy1, 0.0);
    require_expectation("<x=1,1|Y[2]|x=1,1>", zx_pol, sy2, 0.0);
    require_expectation("<x=1,1|Z[1]|x=1,1>", zx_pol, sz1, 0.0);
    require_expectation("<x=1,1|Z[2]|x=1,1>", zx_pol, sz2, 0.0);

    MatrixType zy_pol;
    lattice.gen_polarized_state(Lattice::SpinBasis::Y, zy_pol);
    require_expectation("<y=1,1|X[1]|y=1,1>", zy_pol, sx1, 0.0);
    require_expectation("<y=1,1|X[2]|y=1,1>", zy_pol, sx2, 0.0);
    require_expectation("<y=1,1|Y[1]|y=1,1>", zy_pol, sy1, 1.0);
    require_expectation("<y=1,1|Y[2]|y=1,1>", zy_pol, sy2, 1.0);
    require_expectation("<y=1,1|Z[1]|y=1,1>", zy_pol, sz1, 0.0);
    require_expectation("<y=1,1|Z[2]|y=1,1>", zy_pol, sz2, 0.0);

    MatrixType zz_pol;
    lattice.gen_polarized_state(Lattice::SpinBasis::Z, zz_pol);
    require_expectation("<z=1,1|X[1]|z=1,1>", zz_pol, sx1, 0.0);
    require_expectation("<z=1,1|X[2]|z=1,1>", zz_pol, sx2, 0.0);
    require_expectation("<z=1,1|Y[1]|z=1,1>", zz_pol, sy1, 0.0);
    require_expectation("<z=1,1|Y[2]|z=1,1>", zz_pol, sy2, 0.0);
    require_expectation("<z=1,1|Z[1]|z=1,1>", zz_pol, sz1, 1.0);
    require_expectation("<z=1,1|Z[2]|z=1,1>", zz_pol, sz2, 1.0);
}

TEST_CASE("Test gen_neel_state")
{
    // Expectation values <psi|O|psi> of the single-site operators on sites 0
    // and 1 (written X[1], X[2], ... below) of a 3-site spin-1/2 lattice, for
    // the Neel states produced by gen_neel_state along each axis. The operator
    // along the chosen axis is expected to give +1 on site 0 and -1 on site 1
    // (antiparallel neighbours), the other two operators 0, and every
    // expectation value must be real.
    typedef typename Lattice::Spin::MatrixType MatrixType;

    Lattice::Spin lattice(0.5, 3);

    MatrixType sx1, sy1, sz1, sx2, sy2, sz2;
    lattice.gen_sx(0, sx1);
    lattice.gen_sy(0, sy1);
    lattice.gen_sz(0, sz1);
    lattice.gen_sx(1, sx2);
    lattice.gen_sy(1, sy2);
    lattice.gen_sz(1, sz2);

    // Tolerance shared by every comparison below.
    const double tol = 1e-15;

    // Requires <state|op|state> == expected + 0i within `tol`; `label` is
    // reported on failure to identify the offending expectation value.
    // Approx's epsilon is relative; `.scale(1.0)` keeps the check meaningful
    // when the expected value is 0 (with Catch2's default scale of 0 an
    // epsilon-only comparison against 0 would demand an exact match).
    auto require_expectation = [&](const char* label, MatrixType& state, MatrixType& op, double expected)
    {
        INFO(label);
        const auto value = lattice.scalar_product(state, op, state);
        REQUIRE(value.real() == Approx(expected).epsilon(tol).scale(1.0));
        REQUIRE(value.imag() == Approx(0.0).epsilon(tol).scale(1.0));
    };

    MatrixType x_neel;
    lattice.gen_neel_state(Lattice::SpinBasis::X, x_neel);
    require_expectation("<x=1,-1|X[1]|x=1,-1>", x_neel, sx1, 1.0);
    require_expectation("<x=1,-1|X[2]|x=1,-1>", x_neel, sx2, -1.0);
    require_expectation("<x=1,-1|Y[1]|x=1,-1>", x_neel, sy1, 0.0);
    require_expectation("<x=1,-1|Y[2]|x=1,-1>", x_neel, sy2, 0.0);
    require_expectation("<x=1,-1|Z[1]|x=1,-1>", x_neel, sz1, 0.0);
    require_expectation("<x=1,-1|Z[2]|x=1,-1>", x_neel, sz2, 0.0);

    MatrixType y_neel;
    lattice.gen_neel_state(Lattice::SpinBasis::Y, y_neel);
    require_expectation("<y=1,-1|X[1]|y=1,-1>", y_neel, sx1, 0.0);
    require_expectation("<y=1,-1|X[2]|y=1,-1>", y_neel, sx2, 0.0);
    require_expectation("<y=1,-1|Y[1]|y=1,-1>", y_neel, sy1, 1.0);
    require_expectation("<y=1,-1|Y[2]|y=1,-1>", y_neel, sy2, -1.0);
    require_expectation("<y=1,-1|Z[1]|y=1,-1>", y_neel, sz1, 0.0);
    require_expectation("<y=1,-1|Z[2]|y=1,-1>", y_neel, sz2, 0.0);

    MatrixType z_neel;
    lattice.gen_neel_state(Lattice::SpinBasis::Z, z_neel);
    require_expectation("<z=1,-1|X[1]|z=1,-1>", z_neel, sx1, 0.0);
    require_expectation("<z=1,-1|X[2]|z=1,-1>", z_neel, sx2, 0.0);
    require_expectation("<z=1,-1|Y[1]|z=1,-1>", z_neel, sy1, 0.0);
    require_expectation("<z=1,-1|Y[2]|z=1,-1>", z_neel, sy2, 0.0);
    require_expectation("<z=1,-1|Z[1]|z=1,-1>", z_neel, sz1, 1.0);
    require_expectation("<z=1,-1|Z[2]|z=1,-1>", z_neel, sz2, -1.0);
}