#include <inc/la_tools.h>
#include <ed/models.h>

/**
 * @brief Assemble the Hamiltonian of a spin chain with nearest-neighbour ZZ
 *        coupling and on-site X / Z field terms, and store it in _H.
 *
 * The terms accumulated are
 *   - (-_J) * sz (x) sz  on every pair of adjacent sites (i, i+1), and
 *   - (-_h) * sx + (-_hz) * sz  on every site i.
 * Bonds are only generated for i < sites()-1, i.e. no term linking the last
 * site back to the first one is added here.
 *
 * To keep the operator construction cheap, the chain is cut into a left part
 * of floor(sites/2) sites and a right part holding the remaining sites. Each
 * part's Hamiltonian is built on its own smaller lattice, lifted to the full
 * space via a Kronecker product with the identity of the other part, and the
 * single bond crossing the cut is added on the full lattice at the end.
 *
 * @param _lattice  Lattice providing spin, basis, number of sites and dimension.
 * @param _J        Coupling; enters as the prefactor -_J of the ZZ terms.
 * @param _h        Field; enters as the prefactor -_h of the sx terms.
 * @param _hz       Field; enters as the prefactor -_hz of the on-site sz terms.
 * @param _H        Output matrix; resized to get_dimension() x get_dimension()
 *                  and overwritten.
 *
 * NOTE(review): the accumulators below are created with the (rows, cols)
 * constructor and then only added to, so this presumably relies on that
 * constructor zero-initialising the matrix -- TODO confirm.
 * NOTE(review): gen_tensor_product_operator(i, {a, b}, out) is assumed to
 * place `a` on site i and `b` on site i+1 -- verify against Lattice::Spin.
 */
void models::gen_tfim(const Lattice::Spin& _lattice, const double& _J, const double& _h, const double& _hz, Lattice::Spin::MatrixType& _H)
{
    using MatrixType = Lattice::Spin::MatrixType;
    using Scalar     = typename MatrixType::NumberType;

    const auto n_sites = _lattice.sites();

    _H.resize(_lattice.get_dimension(), _lattice.get_dimension());

    // Auxiliary lattices: a single site (source of the elementary operators)
    // and the two parts of the chain on either side of the cut.
    Lattice::Spin one_site(_lattice.get_spin(), 1, _lattice.get_basis());
    Lattice::Spin left_part(_lattice.get_spin(), std::floor(n_sites/2), _lattice.get_basis());
    const auto n_left = left_part.sites();
    Lattice::Spin right_part(_lattice.get_spin(), n_sites - n_left, _lattice.get_basis());

    // Elementary single-site operators.
    MatrixType sz_site;
    MatrixType sz_coupled;
    MatrixType sx_site;
    one_site.gen_sz(0, sz_site);
    one_site.gen_sz(0, sz_coupled);
    one_site.gen_sx(0, sx_site);

    // Fold the coupling into one factor of each ZZ bond ...
    sz_coupled *= Scalar(-1.0 * _J);

    // ... and merge both field contributions into one on-site operator.
    MatrixType onsite_term;
    onsite_term = Scalar(-1.0 * _h)*sx_site + Scalar(-1.0 * _hz)*sz_site;

    // Scratch matrix receiving each freshly generated operator.
    MatrixType scratch;

    // Left part: sites 0 .. n_left-1.
    MatrixType h_left(left_part.get_dimension(), left_part.get_dimension());
    for (unsigned int site = 0; site < n_left; ++site) {

        // The bond leaving the last left site crosses the cut; it is added
        // on the full lattice at the very end.
        const bool has_inner_bond = site < n_left - 1;
        if (has_inner_bond) {
            left_part.gen_tensor_product_operator(site, {sz_site, sz_coupled}, scratch);
            h_left += scratch;
        }

        left_part.gen_tensor_product_operator(site, onsite_term, scratch);
        h_left += scratch;
    }

    // Right part: global sites n_left .. n_sites-1, addressed by their index
    // relative to the start of the right part.
    MatrixType h_right(right_part.get_dimension(), right_part.get_dimension());
    for (unsigned int site = n_left; site < n_sites; ++site) {

        unsigned int local_site(site - n_left);

        // The last site of the chain has no neighbour to its right.
        const bool has_inner_bond = site < n_sites - 1;
        if (has_inner_bond) {
            right_part.gen_tensor_product_operator(local_site, {sz_site, sz_coupled}, scratch);
            h_right += scratch;
        }

        right_part.gen_tensor_product_operator(local_site, onsite_term, scratch);
        h_right += scratch;
    }

    // Lift both parts to the full space: H = h_left (x) 1 + 1 (x) h_right.
    MatrixType id_left;
    MatrixType id_right;
    left_part.gen_identity(id_left);
    right_part.gen_identity(id_right);

    la_operations::kronecker(h_left, id_right, _H);
    la_operations::kronecker(id_left, h_right, scratch);
    _H += scratch;

    // Finally add the bond across the cut (only exists for more than one site).
    if (n_sites > 1) {
        _lattice.gen_tensor_product_operator(n_left - 1, {sz_site, sz_coupled}, scratch);
        _H += scratch;
    }
}