#ifndef LA_OPERATIONS
#define LA_OPERATIONS

#include <sstream>
#include <stdexcept>
#include <vector>

#include <la/la_wrapper.h>
#include <la/la_objects.h>

namespace la_operations
{

/**
 * @brief Copies every element of _src into _dest, first resizing _dest to the shape of _src.
 *
 * The element count is n_rows * n_cols and both buffers are walked with unit stride,
 * i.e. the data is transferred as one contiguous run.
 *
 * @param _src  object to read from
 * @param _dest object to write to; resized to _src.n_rows() x _src.n_cols()
 */
template <typename T>
void copy_data(const la_objects::LABaseObject<T>& _src, la_objects::LABaseObject<T>& _dest)
{
    _dest.resize(_src.n_rows(), _src.n_cols());

    const int unit_stride = 1;
    const auto total_elems = _src.n_rows() * _src.n_cols();
    blas_wrapper::copy(total_elems, _src.get_data_ptr(), unit_stride, _dest.get_data_ptr(), unit_stride);
}

/**
 * @brief Multiplies all n_rows * n_cols elements of _dest by _scale, in place.
 *
 * @param _scale factor applied to every element
 * @param _dest  object updated in place (visited with unit stride)
 */
template <typename T>
void scale(const T& _scale, la_objects::LABaseObject<T>& _dest)
{
    int count = _dest.n_rows() * _dest.n_cols();
    int stride = 1;
    blas_wrapper::scal(count, _scale, _dest.get_data_ptr(), stride);
}

/**
 * @brief Computes _dest = _lscal * _larg + _rscal * _rarg.
 *
 * _dest ends up with the dimensions of the operands. _dest may be the same object as
 * either operand (or both); those cases are handled explicitly so that an operand is
 * never overwritten before it has been read.
 *
 * @param _lscal scaling factor for _larg
 * @param _larg  left operand
 * @param _rscal scaling factor for _rarg
 * @param _rarg  right operand; must have the same dimensions as _larg
 * @param _dest  result
 * @throws std::runtime_error if _larg and _rarg differ in dimension
 */
template <typename T>
void add(const T& _lscal,
         const la_objects::LAMatrix<T>& _larg,
         const T& _rscal,
         const la_objects::LAMatrix<T>& _rarg,
         la_objects::LAMatrix<T>& _dest)
{
    // axpy below walks n_rows * n_cols elements of both buffers, so the shapes must agree
    if ((_larg.n_rows() != _rarg.n_rows()) || (_larg.n_cols() != _rarg.n_cols())) {
        std::stringstream msg("");
        msg << "Failure adding α*A+β*B, dimensional mismatch between A with dim: " << _larg.n_rows() << "x" << _larg.n_cols() <<
               " and B with dim: " << _rarg.n_rows() << "x" << _rarg.n_cols() << std::endl;
        throw std::runtime_error(msg.str());
    }

    const bool dest_is_larg = (&_dest == &_larg);
    const bool dest_is_rarg = (&_dest == &_rarg);

    if (dest_is_larg && dest_is_rarg) {
        // all three arguments are one object: α*A + β*A == (α+β)*A
        scale(T(_lscal + _rscal), _dest);
        return;
    }

    if (dest_is_larg) {
        // _dest already holds A: copying B into it first would destroy A, so
        // scale A in place and accumulate β*B on top
        scale(_lscal, _dest);
        blas_wrapper::axpy(_rarg.n_rows() * _rarg.n_cols(),
                           _rscal,
                           _rarg.get_data_ptr(),
                           1,
                           _dest.get_data_ptr(),
                           1);
        return;
    }

    // general case: _dest = β*B, then _dest += α*A
    // (when _dest already is B the copy is unnecessary)
    if (!dest_is_rarg) {
        copy_data(_rarg, _dest);
    }
    scale(_rscal, _dest);

    blas_wrapper::axpy(_larg.n_rows() * _larg.n_cols(),
                       _lscal,
                       _larg.get_data_ptr(),
                       1,
                       _dest.get_data_ptr(),
                       1);
}

/**
 * @brief In-place accumulation _dest += _arg.
 *
 * @param _arg  matrix added to _dest; must have the same dimensions as _dest
 * @param _dest matrix updated in place
 * @throws std::runtime_error if _arg and _dest differ in dimension
 */
template <typename T>
void add(const la_objects::LAMatrix<T>& _arg,
         la_objects::LAMatrix<T>& _dest)
{
    // the element count is taken from _dest, so a smaller _arg would be read out of bounds
    if ((_arg.n_rows() != _dest.n_rows()) || (_arg.n_cols() != _dest.n_cols())) {
        std::stringstream msg("");
        msg << "Failure adding dest+=arg, dimensional mismatch between arg with dim: " << _arg.n_rows() << "x" << _arg.n_cols() <<
               " and dest with dim: " << _dest.n_rows() << "x" << _dest.n_cols() << std::endl;
        throw std::runtime_error(msg.str());
    }

    blas_wrapper::axpy(_dest.n_rows() * _dest.n_cols(),
                       T(1), // spelled as T (not a double literal) so it matches the element type
                       _arg.get_data_ptr(),
                       1,
                       _dest.get_data_ptr(),
                       1);
}

/**
 * @brief Computes _dest = _alpha * op(_larg) * op(_rarg) + _beta * _carg via blas_wrapper::gemm.
 *
 * For dimension bookkeeping, 't','T','c','C' mean op(X) is (conjugate-)transposed, following
 * the BLAS gemm convention; any other character is treated as "no transpose". The op
 * characters are forwarded unchanged to gemm.
 *
 * _dest may be the same object as any input. If it is one of the factors, the result is
 * formed in a temporary first: otherwise copying _carg into _dest (and gemm writing to it)
 * would clobber a factor that gemm still reads.
 *
 * @throws std::runtime_error if op(_larg)*op(_rarg) does not have the shape of _carg, or if
 *         the inner dimensions of op(_larg) and op(_rarg) differ
 */
template <typename T>
void contract(const la_objects::LAMatrix<T>& _larg,
              char _op_larg,
              const la_objects::LAMatrix<T>& _rarg,
              char _op_rarg,
              const T& _alpha,
              const la_objects::LAMatrix<T>& _carg,
              const T& _beta,
              la_objects::LAMatrix<T>& _dest)
{
    if ((&_dest == &_larg) || (&_dest == &_rarg)) {
        la_objects::LAMatrix<T> result;
        contract(_larg, _op_larg, _rarg, _op_rarg, _alpha, _carg, _beta, result);
        copy_data(result, _dest);
        return;
    }

    bool transpose_larg = ((_op_larg == 't') || (_op_larg == 'T') || (_op_larg == 'c') || (_op_larg == 'C'));
    bool transpose_rarg = ((_op_rarg == 't') || (_op_rarg == 'T') || (_op_rarg == 'c') || (_op_rarg == 'C'));

    unsigned int dest_row_dim = transpose_larg? _larg.n_cols(): _larg.n_rows();
    unsigned int dest_col_dim = transpose_rarg? _rarg.n_rows(): _rarg.n_cols();
    unsigned int inter_dim_left = transpose_larg? _larg.n_rows(): _larg.n_cols();
    unsigned int inter_dim_right = transpose_rarg? _rarg.n_cols(): _rarg.n_rows();

    if ((dest_row_dim != _carg.n_rows()) || (dest_col_dim != _carg.n_cols())) {

        std::stringstream msg("");
        msg << "Failure contracting α*op(A)*op(B)+β*C, dimensional mismatch between op(A)*op(B) with dim: " << dest_row_dim << "x" << dest_col_dim << 
               " and C with dim: " << _carg.n_rows() << "x" << _carg.n_cols() << std::endl;
        throw std::runtime_error(msg.str());
    }

    if ((inter_dim_left != inter_dim_right)) {
        std::stringstream msg("");
        msg << "Failure contracting α*op(A)*op(B)+β*C, dimensional mismatch between op(A) with dim: " << dest_row_dim << "x" << inter_dim_left << 
               " and op(B) with dim: " << inter_dim_right << "x" << dest_col_dim << std::endl;
        throw std::runtime_error(msg.str());
    }

    // gemm accumulates into _dest, so it must start out holding C;
    // when _dest already is _carg there is nothing to copy
    if (&_dest != &_carg) {
        copy_data(_carg, _dest);
    }

    blas_wrapper::gemm(_op_larg,
                       _op_rarg,
                       dest_row_dim,
                       dest_col_dim,
                       inter_dim_left,
                       _alpha,
                       _larg.get_data_ptr(),
                       _larg.leading_dim(),
                       _rarg.get_data_ptr(),
                       _rarg.leading_dim(),
                       _beta,
                       _dest.get_data_ptr(),
                       _dest.leading_dim());
}

/**
 * @brief Computes _dest = op(_larg) * op(_rarg) via blas_wrapper::gemm (alpha = 1, beta = 0).
 *
 * For dimension bookkeeping, 't','T','c','C' mean op(X) is (conjugate-)transposed, following
 * the BLAS gemm convention; any other character is treated as "no transpose". The op
 * characters are forwarded unchanged to gemm.
 *
 * _dest is resized to the shape of the product. If _dest is one of the factors, the product is
 * formed in a temporary first, because resizing/writing _dest would destroy a gemm input.
 *
 * @throws std::runtime_error if the inner dimensions of op(_larg) and op(_rarg) differ
 */
template <typename T>
void contract(const la_objects::LAMatrix<T>& _larg, char _op_larg,
              const la_objects::LAMatrix<T>& _rarg, char _op_rarg,
              la_objects::LAMatrix<T>& _dest)
{
    if ((&_dest == &_larg) || (&_dest == &_rarg)) {
        la_objects::LAMatrix<T> result;
        contract(_larg, _op_larg, _rarg, _op_rarg, result);
        copy_data(result, _dest);
        return;
    }

    bool transpose_larg = ((_op_larg == 't') || (_op_larg == 'T') || (_op_larg == 'c') || (_op_larg == 'C'));
    bool transpose_rarg = ((_op_rarg == 't') || (_op_rarg == 'T') || (_op_rarg == 'c') || (_op_rarg == 'C'));

    unsigned int dest_row_dim = transpose_larg? _larg.n_cols(): _larg.n_rows();
    unsigned int dest_col_dim = transpose_rarg? _rarg.n_rows(): _rarg.n_cols();
    unsigned int inter_dim_left = transpose_larg? _larg.n_rows(): _larg.n_cols();
    unsigned int inter_dim_right = transpose_rarg? _rarg.n_cols(): _rarg.n_rows();

    if (inter_dim_left != inter_dim_right) {
        std::stringstream msg("");
        msg << "Failure contracting α*op(A)*op(B)+β*C, dimensional mismatch between op(A) with dim: " << dest_row_dim << "x" << inter_dim_left << 
               " and op(B) with dim: " << inter_dim_right << "x" << dest_col_dim << std::endl;
        throw std::runtime_error(msg.str());
    }

    _dest.resize(dest_row_dim, dest_col_dim);

    // beta = 0: gemm overwrites _dest, so its prior contents are irrelevant.
    // Scalars are spelled as T (not double literals) so they match the element type.
    blas_wrapper::gemm(_op_larg,
                       _op_rarg,
                       dest_row_dim,
                       dest_col_dim,
                       inter_dim_left,
                       T(1),
                       _larg.get_data_ptr(),
                       _larg.leading_dim(),
                       _rarg.get_data_ptr(),
                       _rarg.leading_dim(),
                       T(0),
                       _dest.get_data_ptr(),
                       _dest.leading_dim());
}

/**
 * @brief Plain matrix product _dest = _larg * _rarg (neither factor transposed).
 *
 * Convenience overload: forwards to the five-argument contract with op 'n' for both factors.
 */
template <typename T>
void contract(const la_objects::LAMatrix<T>& _larg, const la_objects::LAMatrix<T>& _rarg, la_objects::LAMatrix<T>& _dest)
{
    const char no_op = 'n';
    contract(_larg, no_op, _rarg, no_op, _dest);
}

/**
 * @brief Computes the Kronecker product _dest = _larg ⊗ _rarg.
 *
 * For _larg of size m x n and _rarg of size p x q, _dest is resized to (m*p) x (n*q) and the
 * p x q block whose top-left element is (p*r, q*c) is set to _larg(r,c) * _rarg.
 *
 * Every element of _dest is written explicitly (copy + scale per sub-column), so the result
 * does not depend on whether resize() zero-initialises storage. If _dest is the same object
 * as either factor, the product is built in a temporary and copied over afterwards.
 */
template <typename T>
void kronecker(const la_objects::LAMatrix<T>& _larg, const la_objects::LAMatrix<T>& _rarg, la_objects::LAMatrix<T>& _dest)
{
    // resizing _dest would destroy a factor it aliases, so build the result separately
    if ((&_dest == &_larg) || (&_dest == &_rarg)) {
        la_objects::LAMatrix<T> result;
        kronecker(_larg, _rarg, result);
        copy_data(result, _dest);
        return;
    }

    const unsigned int sub_rows = _rarg.n_rows();
    const unsigned int sub_cols = _rarg.n_cols();

    // result has n_rows(_larg)*n_rows(_rarg) rows and n_cols(_larg)*n_cols(_rarg) columns
    _dest.resize(_larg.n_rows() * sub_rows, _larg.n_cols() * sub_cols);

    // empty right factor -> empty result; also avoids addressing _rarg(0, cc) out of range
    if ((sub_rows == 0) || (sub_cols == 0)) {
        return;
    }

    int n_elems = static_cast<int>(sub_rows);
    int unit_stride = 1;

    // for each element _larg(r,c) and each column cc of _rarg, the scaled column
    // _larg(r,c) * _rarg(:,cc) is one contiguous sub-column of _dest (col-major, unit stride)
    for (unsigned int r = 0; r < _larg.n_rows(); r++) {
        for (unsigned int c = 0; c < _larg.n_cols(); c++) {
            for (unsigned int cc = 0; cc < sub_cols; cc++) {
                auto* dest_col = &_dest(sub_rows * r, sub_cols * c + cc);

                blas_wrapper::copy(n_elems, &_rarg(0, cc), unit_stride, dest_col, unit_stride);
                blas_wrapper::scal(n_elems, _larg(r, c), dest_col, unit_stride);
            }
        }
    }
}

/**
 * @brief Eigen-decomposition of the square matrix _A via blas_wrapper::xxev.
 *
 * @param _A input matrix; must be square and is left unmodified (a copy is decomposed)
 * @param _D resized to n x 1 and handed to xxev as the eigenvalue buffer
 * @param _V receives a copy of _A, which xxev then overwrites in place; with job 'V' this
 *           presumably leaves the eigenvectors in _V (LAPACK jobz convention — confirm
 *           against blas_wrapper::xxev)
 * @throws std::runtime_error if _A is not square
 */
template <typename T>
void evd(const la_objects::LAMatrix<T>& _A, la_objects::LAMatrix<typename la_objects::LAMatrix<T>::BaseType >& _D, la_objects::LAMatrix<T>& _V)
{
    const bool is_square = (_A.n_rows() == _A.n_cols());
    if (!is_square) {
        std::stringstream err("");
        err << "Error in evd, n_rows " << _A.n_rows() << " != " << _A.n_cols() << std::endl;
        throw std::runtime_error(err.str());
    }

    // one slot per eigenvalue
    _D.resize(_A.n_rows(), 1);

    // xxev consumes its matrix argument in place -> hand it a copy held in _V
    copy_data(_A, _V);

    const char jobz = 'V';
    blas_wrapper::xxev(jobz, _V.n_rows(), _V.get_data_ptr(), _V.leading_dim(), _D.get_data_ptr());
}

/**
 * @brief Partial eigen-decomposition of the square matrix _A via blas_wrapper::xxevr.
 *
 * Requests the eigenpairs with indices 1.._n_evecs (range 'I'). Assuming xxevr follows the
 * LAPACK ?syevr/?heevr convention, these are the _n_evecs smallest eigenvalues.
 *
 * @param _A       input matrix; must be square and is left unmodified (a copy is decomposed)
 * @param _D       resized to m x 1 and filled with the m eigenvalues xxevr reports
 * @param _V       resized to n x m and filled with the corresponding eigenvectors
 * @param _n_evecs number of eigenpairs requested; must satisfy 1 <= _n_evecs <= n
 *                 (0 is accepted only for an empty 0 x 0 matrix)
 * @throws std::runtime_error if _A is not square or _n_evecs is out of range
 */
template <typename T>
void evd(const la_objects::LAMatrix<T>& _A, la_objects::LAMatrix<typename la_objects::LAMatrix<T>::BaseType >& _D, la_objects::LAMatrix<T>& _V, const unsigned int _n_evecs)
{
    // check for square matrix
    if (_A.n_rows() != _A.n_cols()) {
        std::stringstream msg("");
        msg << "Error in evd, n_rows " << _A.n_rows() << " != " << _A.n_cols() << std::endl;
        throw std::runtime_error(msg.str());
    }

    const auto n = _A.n_rows();

    // range 'I' requires 1 <= il <= iu <= n when n > 0 (il = 1, iu = 0 only when n == 0)
    if ((_n_evecs > n) || ((_n_evecs == 0) && (n != 0))) {
        std::stringstream msg("");
        msg << "Error in evd, requested " << _n_evecs << " eigenvectors of a " << n << "x" << n << " matrix" << std::endl;
        throw std::runtime_error(msg.str());
    }

    // xxevr overwrites its input matrix -> decompose a copy of _A
    la_objects::LAMatrix<T> Abuf;
    copy_data(_A, Abuf);

    // full-size output buffers; xxevr reports how many eigenpairs (m) it actually computed
    la_objects::LAMatrix<T> VBuf;
    la_objects::LAMatrix<typename la_objects::LAMatrix<T>::BaseType> DBuf;
    DBuf.resize(n, 1);
    VBuf.resize(n, n);

    lapack_int m = 0;
    // LAPACK documents ISUPPZ as having 2*max(1, M) entries, and M <= N
    std::vector<lapack_int> suppz(2 * ((n > 0) ? n : 1));

    blas_wrapper::xxevr('V',
                        'I',
                        Abuf.n_rows(),
                        Abuf.get_data_ptr(),
                        Abuf.leading_dim(),
                        0.0,
                        0.0,
                        1,
                        _n_evecs,
                        0.0,
                        &m,
                        DBuf.get_data_ptr(),
                        VBuf.get_data_ptr(),
                        VBuf.leading_dim(),
                        suppz.data());

    // keep only the m computed eigenpairs
    _D.resize(m, 1);
    _V.resize(n, m);

    blas_wrapper::copy(_D.n_rows() * _D.n_cols(),
                       DBuf.get_data_ptr(),
                       1,
                       _D.get_data_ptr(),
                       1);

    // copy column by column so the result does not rely on VBuf's leading dimension being n
    for (unsigned int j = 0; j < _V.n_cols(); j++) {
        blas_wrapper::copy(_V.n_rows(),
                           &VBuf(0, j),
                           1,
                           &_V(0, j),
                           1);
    }
}

} // END NAMESPACE la_operations

#endif // LA_OPERATIONS

