#include <la/la_operations.h>
#include <la/la_expressions.h>

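// This file implements the la_operations::apply overloads, which resolve expression objects
// into calls to BLAS/LAPACK-backed routines, and explicitly instantiates each overload for
// float, double, std::complex<float> and std::complex<double>.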

// Scalar times matrix (c*A).
// The matrix operand is first copied into _dest, and _dest is then scaled in
// place by the scalar operand. In a CM node the scalar is the left argument
// and the matrix is the right argument.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, CM<T>& _src_expr)
{
    auto& scalar = _src_expr.larg;
    auto& matrix = _src_expr.rarg;

    la_operations::copy_data(matrix, _dest);
    _dest *= scalar;
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, CM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, CM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, CM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, CM<std::complex<double>>&);

// Matrix sum (A + B).
// Forwarded to the weighted add() (scalar, matrix, scalar, matrix, dest) with
// both weights equal to one.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MpM<T>& _src_expr)
{
    auto& lhs = _src_expr.larg;
    auto& rhs = _src_expr.rarg;

    la_operations::add(T(1.0), lhs, T(1.0), rhs, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, MpM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, MpM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, MpM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, MpM<std::complex<double>>&);

// Weighted matrix sum (a*A + b*B).
// Each side of the sum is a scalar-times-matrix (CM) node. Its scalar (larg)
// and matrix (rarg) are unpacked and handed to add() pairwise.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, aMpbM<T>& _src_expr)
{
    auto& alpha      = _src_expr.larg.larg;
    auto& lhs_matrix = _src_expr.larg.rarg;
    auto& beta       = _src_expr.rarg.larg;
    auto& rhs_matrix = _src_expr.rarg.rarg;

    la_operations::add(alpha, lhs_matrix, beta, rhs_matrix, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, aMpbM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, aMpbM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, aMpbM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, aMpbM<std::complex<double>>&);

// Plain matrix product (A*B).
// Neither operand is transposed, so the two-operand contract() overload is used.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MM<T>& _src_expr)
{
    auto& lhs = _src_expr.larg;
    auto& rhs = _src_expr.rarg;

    la_operations::contract(lhs, rhs, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, MM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, MM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, MM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, MM<std::complex<double>>&);

// Matrix times transposed matrix (A * transpose(B)).
// The transpose wrapper is unpacked (.arg), and the transposition is requested
// from contract() via the 't' flag instead of being materialized.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MtM<T>& _src_expr)
{
    auto& lhs        = _src_expr.larg;
    auto& transposed = _src_expr.rarg.arg;

    la_operations::contract(lhs, 'n', transposed, 't', _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, MtM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, MtM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, MtM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, MtM<std::complex<double>>&);

// Transposed matrix times matrix (transpose(A) * B).
// The transpose wrapper on the left is unpacked (.arg), and the transposition
// is requested from contract() via the 't' flag.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, tMM<T>& _src_expr)
{
    auto& transposed = _src_expr.larg.arg;
    auto& rhs        = _src_expr.rarg;

    la_operations::contract(transposed, 't', rhs, 'n', _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, tMM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, tMM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, tMM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, tMM<std::complex<double>>&);

// Matrix times adjoint (A * adjoint(B)).
// The adjoint wrapper is unpacked (.arg), and the conjugate transpose is
// requested from contract() via the 'c' flag.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MaM<T>& _src_expr)
{
    auto& lhs     = _src_expr.larg;
    auto& adjoint = _src_expr.rarg.arg;

    la_operations::contract(lhs, 'n', adjoint, 'c', _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, MaM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, MaM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, MaM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, MaM<std::complex<double>>&);

// Adjoint times matrix (adjoint(A) * B).
// The adjoint wrapper on the left is unpacked (.arg), and the conjugate
// transpose is requested from contract() via the 'c' flag.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, aMM<T>& _src_expr)
{
    auto& adjoint = _src_expr.larg.arg;
    auto& rhs     = _src_expr.rarg;

    la_operations::contract(adjoint, 'c', rhs, 'n', _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>&, aMM<float>&);
template void la_operations::apply<double>(la_objects::LAMatrix<double>&, aMM<double>&);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>&, aMM<std::complex<float>>&);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>&, aMM<std::complex<double>>&);

// Matrix product plus scaled matrix (A*B + beta*C), evaluated with a single
// contract() call.
//
// The sub-expressions are bound by reference, as in the a*Matrix+b*Matrix
// overload, rather than copied into locals. Copying the MM/CM nodes is
// unnecessary and may copy whole matrices on every evaluation if the
// expression types store their operands by value.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MMpCM<T>& _src_expr)
{
    auto& ab    = _src_expr.larg;  // A*B
    auto& betac = _src_expr.rarg;  // beta*C: larg is the scalar, rarg the matrix

    // NOTE(review): argument order appears to be (A, opA, B, opB, alpha, C, beta, dest),
    // with alpha fixed to one; confirm against la/la_operations.h.
    la_operations::contract(ab.larg, 'n', ab.rarg, 'n', T(1.0), betac.rarg, betac.larg, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MMpCM<float>& _src_expr);
template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MMpCM<double>& _src_expr);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MMpCM<std::complex<float>>& _src_expr);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MMpCM<std::complex<double>>& _src_expr);

// Transposed product plus scaled matrix (transpose(A)*B + beta*C), evaluated
// with a single contract() call. The transposition is requested via the 't'
// flag on the unpacked left operand.
//
// The sub-expressions are bound by reference, as in the a*Matrix+b*Matrix
// overload, rather than copied into locals. Copying the tMM/CM nodes is
// unnecessary and may copy whole matrices if the expression types store their
// operands by value.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, tMMpCM<T>& _src_expr)
{
    auto& ab    = _src_expr.larg;  // transpose(A)*B
    auto& betac = _src_expr.rarg;  // beta*C: larg is the scalar, rarg the matrix

    // NOTE(review): argument order appears to be (A, opA, B, opB, alpha, C, beta, dest),
    // with alpha fixed to one; confirm against la/la_operations.h.
    la_operations::contract(ab.larg.arg, 't', ab.rarg, 'n', T(1.0), betac.rarg, betac.larg, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, tMMpCM<float>& _src_expr);
template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, tMMpCM<double>& _src_expr);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, tMMpCM<std::complex<float>>& _src_expr);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, tMMpCM<std::complex<double>>& _src_expr);

// Product with transposed right operand plus scaled matrix
// (A*transpose(B) + beta*C), evaluated with a single contract() call. The
// transposition is requested via the 't' flag on the unpacked right operand.
//
// The sub-expressions are bound by reference, as in the a*Matrix+b*Matrix
// overload, rather than copied into locals. Copying the MtM/CM nodes is
// unnecessary and may copy whole matrices if the expression types store their
// operands by value.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MtMpCM<T>& _src_expr)
{
    auto& ab    = _src_expr.larg;  // A*transpose(B)
    auto& betac = _src_expr.rarg;  // beta*C: larg is the scalar, rarg the matrix

    // NOTE(review): argument order appears to be (A, opA, B, opB, alpha, C, beta, dest),
    // with alpha fixed to one; confirm against la/la_operations.h.
    la_operations::contract(ab.larg, 'n', ab.rarg.arg, 't', T(1.0), betac.rarg, betac.larg, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MtMpCM<float>& _src_expr);
template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MtMpCM<double>& _src_expr);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MtMpCM<std::complex<float>>& _src_expr);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MtMpCM<std::complex<double>>& _src_expr);

// Adjoint product plus scaled matrix (adjoint(A)*B + beta*C), evaluated with a
// single contract() call. The conjugate transpose is requested via the 'c'
// flag on the unpacked left operand.
//
// The sub-expressions are bound by reference, as in the a*Matrix+b*Matrix
// overload, rather than copied into locals. Copying the aMM/CM nodes is
// unnecessary and may copy whole matrices if the expression types store their
// operands by value.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, aMMpCM<T>& _src_expr)
{
    auto& ab    = _src_expr.larg;  // adjoint(A)*B
    auto& betac = _src_expr.rarg;  // beta*C: larg is the scalar, rarg the matrix

    // NOTE(review): argument order appears to be (A, opA, B, opB, alpha, C, beta, dest),
    // with alpha fixed to one; confirm against la/la_operations.h.
    la_operations::contract(ab.larg.arg, 'c', ab.rarg, 'n', T(1.0), betac.rarg, betac.larg, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, aMMpCM<float>& _src_expr);
template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, aMMpCM<double>& _src_expr);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, aMMpCM<std::complex<float>>& _src_expr);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, aMMpCM<std::complex<double>>& _src_expr);

// Product with adjoint right operand plus scaled matrix
// (A*adjoint(B) + beta*C), evaluated with a single contract() call. The
// conjugate transpose is requested via the 'c' flag on the unpacked right
// operand.
//
// The sub-expressions are bound by reference, as in the a*Matrix+b*Matrix
// overload, rather than copied into locals. Copying the MaM/CM nodes is
// unnecessary and may copy whole matrices if the expression types store their
// operands by value.
template <typename T>
void la_operations::apply(la_objects::LAMatrix<T>& _dest, MaMpCM<T>& _src_expr)
{
    auto& ab    = _src_expr.larg;  // A*adjoint(B)
    auto& betac = _src_expr.rarg;  // beta*C: larg is the scalar, rarg the matrix

    // NOTE(review): argument order appears to be (A, opA, B, opB, alpha, C, beta, dest),
    // with alpha fixed to one; confirm against la/la_operations.h.
    la_operations::contract(ab.larg, 'n', ab.rarg.arg, 'c', T(1.0), betac.rarg, betac.larg, _dest);
}

template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MaMpCM<float>& _src_expr);
template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MaMpCM<double>& _src_expr);
template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MaMpCM<std::complex<float>>& _src_expr);
template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MaMpCM<std::complex<double>>& _src_expr);

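// NOTE: The explicit instantiation declarations (`extern template`) below are disabled and kept
// for reference only. If they are re-enabled, they should go in la/la_operations.h so that other
// translation units skip implicit instantiation. They must not be uncommented here: the standard
// requires an explicit instantiation declaration to precede the matching explicit instantiation
// definition in the same translation unit, and the definitions are above.
// // here we have the apply function to resolve expressions and call blas/lapack routines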

// // Matrix+Matrix
// extern template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MpM<float>& _src_expr);

// extern template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MpM<double>& _src_expr);

// extern template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MpM<std::complex<float>>& _src_expr);

// extern template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MpM<std::complex<double>>& _src_expr);

// // Matrix*Matrix
// extern template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MM<float>& _src_expr);

// extern template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MM<double>& _src_expr);

// extern template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MM<std::complex<float>>& _src_expr);

// extern template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MM<std::complex<double>>& _src_expr);

// // Matrix*transpose(Matrix)
// extern template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MtM<float>& _src_expr);

// extern template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MtM<double>& _src_expr);

// extern template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MtM<std::complex<float>>& _src_expr);

// extern template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MtM<std::complex<double>>& _src_expr);

// // Matrix*adjoint(Matrix)
// extern template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MaM<float>& _src_expr);

// extern template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MaM<double>& _src_expr);

// extern template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MaM<std::complex<float>>& _src_expr);

// extern template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MaM<std::complex<double>>& _src_expr);

// // Matrix*Matrix + beta*Matrix
// extern template void la_operations::apply<float>(la_objects::LAMatrix<float>& _dest, MMpCM<float>& _src_expr);

// extern template void la_operations::apply<double>(la_objects::LAMatrix<double>& _dest, MMpCM<double>& _src_expr);

// extern template void la_operations::apply<std::complex<float>>(la_objects::LAMatrix<std::complex<float>>& _dest, MMpCM<std::complex<float>>& _src_expr);

// extern template void la_operations::apply<std::complex<double>>(la_objects::LAMatrix<std::complex<double>>& _dest, MMpCM<std::complex<double>>& _src_expr);