#include <ed/lattice.h>
#include <inc/la_tools.h>

/**
 * Constructs a lattice of `_sites` spins, each of spin `_s`.
 *
 * The local Hilbert-space dimension handed to the base class is 2*_s + 1,
 * where 2*_s is truncated to an unsigned integer before adding one.
 * `_basis` selects the single-site basis used by the operator/state generators.
 */
Lattice::Spin::Spin(const double& _s, const unsigned int _sites, SpinBasis _basis)
    : Lattice::Base<std::complex<double>>(_sites, static_cast<unsigned int>(2 * _s) + 1),
      spin(_s),
      basis(_basis)
{
}

/**
 * Builds the sx operator acting on site `_site` of the lattice.
 *
 * The local representation is the 2x2 Pauli matrix sigma^x (eigenvalues -1 and
 * +1, no factor 1/2), written in the lattice's single-site basis:
 *   - Z basis: [[0, 1], [1, 0]]
 *   - X basis: diag(-1, 1)
 * It is lifted to the full lattice by gen_tensor_product_operator.
 *
 * @param _site index of the site the operator acts on
 * @param _op   output: the resulting lattice operator
 * @throws std::runtime_error if the lattice basis is Y (not implemented yet), or
 *         if local_dim() != 2, because the hard-coded 2x2 representation only
 *         describes spin-1/2
 */
void Lattice::Spin::gen_sx(const unsigned int _site, MatrixType& _op) const
{
    typedef typename MatrixType::NumberType T;
    std::vector<T> rep;

    switch(this->basis) {
        case SpinBasis::X:
        {
            rep = {T(-1.0), T(0.0), T(0.0), T(1.0)};
            break;
        }
        case SpinBasis::Y:
        {
            throw std::runtime_error("sx rep in y-basis not implemented yet");
        }
        case SpinBasis::Z:
        {
            rep = {T(0.0), T(1.0), T(1.0), T(0.0)};
            break;
        }
    }

    // rep holds exactly 4 entries; reading it as a local_dim() x local_dim()
    // matrix is only consistent for spin-1/2 (local dimension 2).
    if (this->local_dim() != 2u) {
        throw std::runtime_error("sx rep only implemented for spin-1/2 (local dimension 2)");
    }

    this->gen_tensor_product_operator(_site, MatrixType(this->local_dim(), this->local_dim(), rep), _op);
}

/**
 * Builds the sy operator acting on site `_site` of the lattice.
 *
 * The local representation is the 2x2 Pauli matrix sigma^y (eigenvalues -1 and
 * +1, no factor 1/2), written in the lattice's single-site basis:
 *   - Y basis: diag(-1, 1)
 *   - Z basis: entries {0, -i, i, 0}
 * It is lifted to the full lattice by gen_tensor_product_operator.
 *
 * @param _site index of the site the operator acts on
 * @param _op   output: the resulting lattice operator
 * @throws std::runtime_error if the lattice basis is X (not implemented yet), or
 *         if local_dim() != 2, because the hard-coded 2x2 representation only
 *         describes spin-1/2
 */
void Lattice::Spin::gen_sy(const unsigned int _site, MatrixType& _op) const
{
    typedef typename MatrixType::NumberType T;
    std::vector<T> rep;

    switch(this->basis) {
        case SpinBasis::X:
        {
            throw std::runtime_error("sy rep in x-basis not implemented yet");
        }
        case SpinBasis::Y:
        {
            rep = {T(-1.0), T(0.0, 0.0), T(0.0, 0.0), T(1.0)};
            break;
        }
        case SpinBasis::Z:
        {
            // With sz = diag(-1, 1), this is the right-handed choice ([sx, sy] = 2i sz)
            // only if MatrixType fills `rep` column-major (-i below, +i above the
            // diagonal). NOTE(review): confirm the storage order in la_tools.
            rep = {T(0.0), T(0.0, -1.0), T(0.0, 1.0), T(0.0)};
            break;
        }
    }

    // rep holds exactly 4 entries; reading it as a local_dim() x local_dim()
    // matrix is only consistent for spin-1/2 (local dimension 2).
    if (this->local_dim() != 2u) {
        throw std::runtime_error("sy rep only implemented for spin-1/2 (local dimension 2)");
    }

    this->gen_tensor_product_operator(_site, MatrixType(this->local_dim(), this->local_dim(), rep), _op);
}

/**
 * Builds the sz operator acting on site `_site` of the lattice.
 *
 * The local representation is the 2x2 Pauli matrix sigma^z = diag(-1, 1)
 * (no factor 1/2): local basis index 0 has eigenvalue -1, index 1 has +1.
 * It is lifted to the full lattice by gen_tensor_product_operator.
 *
 * @param _site index of the site the operator acts on
 * @param _op   output: the resulting lattice operator
 * @throws std::runtime_error if the lattice basis is X or Y (not implemented
 *         yet), or if local_dim() != 2, because the hard-coded 2x2
 *         representation only describes spin-1/2
 */
void Lattice::Spin::gen_sz(const unsigned int _site, MatrixType& _op) const
{
    typedef typename MatrixType::NumberType T;
    std::vector<T> rep;

    switch(this->basis) {
        case SpinBasis::X:
        {
            throw std::runtime_error("sz rep in x-basis not implemented yet");
        }
        case SpinBasis::Y:
        {
            throw std::runtime_error("sz rep in y-basis not implemented yet");
        }
        case SpinBasis::Z:
        {
            rep = {T(-1.0), T(0.0), T(0.0), T(1.0)};
            break;
        }
    }

    // rep holds exactly 4 entries; reading it as a local_dim() x local_dim()
    // matrix is only consistent for spin-1/2 (local dimension 2).
    if (this->local_dim() != 2u) {
        throw std::runtime_error("sz rep only implemented for spin-1/2 (local dimension 2)");
    }

    this->gen_tensor_product_operator(_site, MatrixType(this->local_dim(), this->local_dim(), rep), _op);
}

void Lattice::Spin::gen_polarized_state(const SpinBasis _direction, MatrixType& _state, bool _anti_parallel)
{
    std::vector<unsigned int> rep;

    switch(this->basis) {
        case SpinBasis::X:
        {
            throw std::runtime_error("polarized rep in x-basis not implemented yet");
            break;
        }
        case SpinBasis::Y:
        {
            throw std::runtime_error("polarized rep in y-basis not implemented yet");
            break;
        }
        case SpinBasis::Z:
        {
            rep = {_anti_parallel? 0u: 1u};
            switch(_direction) {
                case SpinBasis::X:
                {
                    // start from a polarized state along z-direction and then rotate each site
                    this->gen_tensor_product_state(rep, _state);

                    // need some buffers
                    MatrixType sx, v, V, state_buf; RealMatrixType d;

                    // create a single-site lattice so that we can get basis transformations cheaper
                    Lattice::Spin reduced_lattice(this->spin, 1, SpinBasis::Z);

                    // from eigenvectors of sx we get rotation matrix
                    reduced_lattice.gen_sx(0, sx);
                    la_operations::evd(sx, d, v);

                    for (unsigned int i = 0; i < this->sites(); i++) {    
                        this->gen_tensor_product_operator(i, v, V);
                        state_buf = _state;
                        _state = this->apply(V, state_buf);
                    }

                    break;
                }
                case SpinBasis::Y:
                {
                    // start from a polarized state along z-direction and then rotate each site
                    this->gen_tensor_product_state(rep, _state);

                    // need some buffers
                    MatrixType sy, v, V, state_buf; RealMatrixType d;

                    // create a single-site lattice so that we can get basis transformations cheaper
                    Lattice::Spin reduced_lattice(this->spin, 1, SpinBasis::Z);

                    // from eigenvectors of sy we get rotation matrix
                    reduced_lattice.gen_sy(0, sy);
                    la_operations::evd(sy, d, v);

                    for (unsigned int i = 0; i < this->sites(); i++) {    
                        this->gen_tensor_product_operator(i, v, V);
                        state_buf = _state;
                        _state = this->apply(V, state_buf);
                    }

                    break;
                }
                case SpinBasis::Z:
                {
                    this->gen_tensor_product_state(rep, _state);
                    break;
                }
            }
            break;
        }
    }
}

void Lattice::Spin::gen_neel_state(const SpinBasis _direction, MatrixType& _state, bool _anti_parallel)
{
    std::vector<unsigned int> rep;

    switch(this->basis) {
        case SpinBasis::X:
        {
            throw std::runtime_error("polarized rep in x-basis not implemented yet");
            break;
        }
        case SpinBasis::Y:
        {
            throw std::runtime_error("polarized rep in y-basis not implemented yet");
            break;
        }
        case SpinBasis::Z:
        {
            rep = {_anti_parallel? 0u, 1u: 1u, 0u};
            switch(_direction) {
                case SpinBasis::X:
                {
                    // start from a polarized state along z-direction and then rotate each site
                    this->gen_tensor_product_state(rep, _state);

                    // need some buffers
                    MatrixType sx, v, V, state_buf; RealMatrixType d;

                    // create a single-site lattice so that we can get basis transformations cheaper
                    Lattice::Spin reduced_lattice(this->spin, 1, SpinBasis::Z);

                    // from eigenvectors of sx we get rotation matrix
                    reduced_lattice.gen_sx(0, sx);
                    la_operations::evd(sx, d, v);

                    for (unsigned int i = 0; i < this->sites(); i++) {    
                        this->gen_tensor_product_operator(i, v, V);
                        state_buf = _state;
                        _state = this->apply(V, state_buf);
                    }

                    break;
                }
                case SpinBasis::Y:
                {
                    // start from a polarized state along z-direction and then rotate each site
                    this->gen_tensor_product_state(rep, _state);

                    // need some buffers
                    MatrixType sy, v, V, state_buf; RealMatrixType d;

                    // create a single-site lattice so that we can get basis transformations cheaper
                    Lattice::Spin reduced_lattice(this->spin, 1, SpinBasis::Z);

                    // from eigenvectors of sy we get rotation matrix
                    reduced_lattice.gen_sy(0, sy);
                    la_operations::evd(sy, d, v);

                    for (unsigned int i = 0; i < this->sites(); i++) {    
                        this->gen_tensor_product_operator(i, v, V);
                        state_buf = _state;
                        _state = this->apply(V, state_buf);
                    }

                    break;
                }
                case SpinBasis::Z:
                {
                    this->gen_tensor_product_state(rep, _state);               
                    break;
                }
            }
            break;
        }
    }
}