#ifndef LATTICE
#define LATTICE

#include <cmath>
#include <complex>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <vector>

#include <la/la_operations.h>

namespace Lattice {

template <typename T>
class Base {

private:

    unsigned int sites_;
    unsigned int local_dim_;

public:

    typedef la_objects::LAMatrix<T> MatrixType;
    typedef la_objects::LAMatrix<typename MatrixType::BaseType> RealMatrixType;

    Base<T>(const unsigned int _sites, const unsigned int _local_dim)
        : sites_(_sites), local_dim_(_local_dim)
    {}

    const unsigned int& sites() const { return this->sites_; }
    const unsigned int& local_dim() const { return this->local_dim_; }
    const unsigned int get_dimension() const { return (int) pow((double) this->local_dim(), (double) this->sites()); }

    void gen_identity(MatrixType& _op) const {

        unsigned int dimension = this->get_dimension();
        _op.resize(dimension, dimension);
        for (unsigned int i = 0; i < dimension; i++) { _op(i,i) = 1.0; }
        
    };
    void gen_tensor_product_operator(const unsigned int _site, const MatrixType& _local_operator, MatrixType& _op) const {

        Base<T> llat(_site , this->local_dim());
        Base<T> rlat(this->sites() - _site - 1, this->local_dim());

        MatrixType buf, lid, rid;
        llat.gen_identity(lid);
        rlat.gen_identity(rid);

        la_operations::kronecker(lid, _local_operator, buf);
        la_operations::kronecker(buf, rid, _op);

    }
    void gen_tensor_product_operator(const unsigned int _site, const std::vector<MatrixType>& _local_operators, MatrixType& _op) const {

        Base<T> llat(_site , this->local_dim());
        Base<T> rlat(this->sites() - _site - _local_operators.size(), this->local_dim());

        MatrixType buf, tensor_buf, lid, rid;
        llat.gen_identity(lid);
        rlat.gen_identity(rid);

        tensor_buf = _local_operators[0];
        for (unsigned int i = 1; i < _local_operators.size(); i++){
            buf = tensor_buf;
            la_operations::kronecker(buf, _local_operators[i], tensor_buf);
        }
        la_operations::kronecker(lid, tensor_buf, buf);
        
        la_operations::kronecker(buf, rid, _op);

    }
    void gen_tensor_product_state(const std::vector<unsigned int>& _site_indices, MatrixType& _state) const {

        _state.resize(this->get_dimension(), 1);
        unsigned int state_index = 0;

        for (unsigned int i = 0; i < this->sites(); i++) {
            state_index = state_index * this->local_dim() + _site_indices[i % _site_indices.size()];
        }
        _state(state_index, 0) = 1.0;
    }

    MatrixType apply(const MatrixType& _operator, const MatrixType& _state) const {
        return MatrixType(_operator * _state);
    }

    MatrixType apply_adj(const MatrixType& _operator, const MatrixType& _state) const {
        return MatrixType(la_operations::adjoint(_operator) * _state);
    }

    typename MatrixType::NumberType scalar_product(const MatrixType& _larg, const MatrixType& _rarg) const {

        if ((_larg.n_cols() > 1) || (_rarg.n_cols() > 1)) {
            throw std::runtime_error("Failure calculating scalar product, invalid argument dimensions");
        }

        return MatrixType(la_operations::adjoint(_larg) * _rarg)(0,0);

    }
    typename MatrixType::NumberType scalar_product(const MatrixType& _larg, const MatrixType& _operator, const MatrixType& _rarg) const {

        if ((_larg.n_cols() > 1) || (_rarg.n_cols() > 1)) {
            throw std::runtime_error("Failure calculating scalar product, invalid argument dimensions");
        }

        MatrixType buffer = this->apply(_operator, _rarg);
        return MatrixType(la_operations::adjoint(_larg) * buffer)(0,0);

    }

};

/// Spin-axis label. Spin uses it for its stored basis and for the direction
/// argument of its state generators. Ordinals are spelled out explicitly
/// (Z first) because the enum is unscoped and converts to int.
enum SpinBasis {
    Z = 0,
    X = 1,
    Y = 2
};

/// Spin lattice with complex matrix elements (a Base<std::complex<double>>).
/// Apart from the two inline accessors, all members are defined out of line.
class Spin: public Base<std::complex<double>> {

public:

    /// @param _s     spin value, kept as a double (presumably the spin quantum
    ///               number, e.g. 0.5 — confirm in the constructor definition)
    /// @param _sites number of lattice sites
    /// @param _basis basis label for this lattice (defaults to Z)
    Spin(const double& _s, const unsigned int _sites, SpinBasis _basis = SpinBasis::Z);

    /// Spin value held by this lattice.
    double get_spin() const { return spin; }

    /// Basis label held by this lattice.
    SpinBasis get_basis() const { return basis; }

    // Single-site spin operators for `_site`, written into `_op`
    // (presumably embedded into the full space — confirm in the definitions).
    void gen_sx(const unsigned int _site, MatrixType& _op) const;
    void gen_sy(const unsigned int _site, MatrixType& _op) const;
    void gen_sz(const unsigned int _site, MatrixType& _op) const;

    // State generators along `_direction`, written into `_state`.
    // NOTE(review): `_anti_parallel` presumably flips the orientation; these are
    // non-const even though the name suggests no lattice mutation — confirm.
    void gen_polarized_state(const SpinBasis _direction, MatrixType& _state, bool _anti_parallel = false);
    void gen_neel_state(const SpinBasis _direction, MatrixType& _state, bool _anti_parallel = false);

private:

    double spin;      // returned by get_spin()
    SpinBasis basis;  // returned by get_basis()

};

} // END NAMESPACE LATTICE

#endif // LATTICE