blob: f1034655eb7cda8797ecc6e603c5a725218569d7 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_SELFADJOINT_PRODUCT_H
#define EIGEN_SELFADJOINT_PRODUCT_H
/**********************************************************************
* This file implements a self-adjoint product: C += A A^T, updating only
* half of the self-adjoint matrix C.
* It corresponds to the level 3 SYRK and level 2 SYR BLAS routines.
**********************************************************************/
// IWYU pragma: private
#include "../InternalHeaderCheck.h"
namespace Eigen {
template <typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar, Index, ColMajor, UpLo, ConjLhs, ConjRhs> {
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) {
internal::conj_if<ConjRhs> cj;
typedef Map<const Matrix<Scalar, Dynamic, 1> > OtherMap;
typedef std::conditional_t<ConjLhs, typename OtherMap::ConjugateReturnType, const OtherMap&> ConjLhsType;
for (Index i = 0; i < size; ++i) {
Map<Matrix<Scalar, Dynamic, 1> >(mat + stride * i + (UpLo == Lower ? i : 0),
(UpLo == Lower ? size - i : (i + 1))) +=
(alpha * cj(vecY[i])) *
ConjLhsType(OtherMap(vecX + (UpLo == Lower ? i : 0), UpLo == Lower ? size - i : (i + 1)));
}
}
};
template <typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar, Index, RowMajor, UpLo, ConjLhs, ConjRhs> {
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) {
selfadjoint_rank1_update<Scalar, Index, ColMajor, UpLo == Lower ? Upper : Lower, ConjRhs, ConjLhs>::run(
size, mat, stride, vecY, vecX, alpha);
}
};
// Compile-time dispatcher used by SelfAdjointView::rankUpdate().
// The last parameter defaults to OtherType::IsVectorAtCompileTime and selects
// between the two specializations in this file: `true` for a vector operand
// and `false` for a general matrix operand.
template <typename MatrixType, typename OtherType, int UpLo, bool OtherIsVector = OtherType::IsVectorAtCompileTime>
struct selfadjoint_product_selector;
template <typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType, OtherType, UpLo, true> {
static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) {
typedef typename MatrixType::Scalar Scalar;
typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
enum {
StorageOrder = (internal::traits<MatrixType>::Flags & RowMajorBit) ? RowMajor : ColMajor,
UseOtherDirectly = ActualOtherType_::InnerStrideAtCompileTime == 1
};
internal::gemv_static_vector_if<Scalar, OtherType::SizeAtCompileTime, OtherType::MaxSizeAtCompileTime,
!UseOtherDirectly>
static_other;
ei_declare_aligned_stack_constructed_variable(
Scalar, actualOtherPtr, other.size(),
(UseOtherDirectly ? const_cast<Scalar*>(actualOther.data()) : static_other.data()));
if (!UseOtherDirectly)
Map<typename ActualOtherType_::PlainObject>(actualOtherPtr, actualOther.size()) = actualOther;
selfadjoint_rank1_update<
Scalar, Index, StorageOrder, UpLo, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>::run(other.size(), mat.data(),
mat.outerStride(), actualOtherPtr,
actualOtherPtr, actualAlpha);
}
};
template <typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType, OtherType, UpLo, false> {
static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) {
typedef typename MatrixType::Scalar Scalar;
typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
enum {
IsRowMajor = (internal::traits<MatrixType>::Flags & RowMajorBit) ? 1 : 0,
OtherIsRowMajor = ActualOtherType_::Flags & RowMajorBit ? 1 : 0
};
Index size = mat.cols();
Index depth = actualOther.cols();
typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor, Scalar, Scalar,
MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
ActualOtherType_::MaxColsAtCompileTime>
BlockingType;
BlockingType blocking(size, size, depth, 1, false);
internal::general_matrix_matrix_triangular_product<
Index, Scalar, OtherIsRowMajor ? RowMajor : ColMajor,
OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? ColMajor : RowMajor,
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, IsRowMajor ? RowMajor : ColMajor,
MatrixType::InnerStrideAtCompileTime, UpLo>::run(size, depth, actualOther.data(), actualOther.outerStride(),
actualOther.data(), actualOther.outerStride(), mat.data(),
mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
}
};
// high level API
// Applies a rank update to the expression wrapped by this view and returns
// the view itself. The work is delegated to selfadjoint_product_selector,
// which receives the underlying expression (with constness cast away),
// u.derived(), and alpha.
//
// \param u     operand of the update
// \param alpha scaling factor forwarded to the selector
// \returns a reference to \c *this
template <typename MatrixType, unsigned int UpLo>
template <typename DerivedU>
EIGEN_DEVICE_FUNC SelfAdjointView<MatrixType, UpLo>& SelfAdjointView<MatrixType, UpLo>::rankUpdate(
    const MatrixBase<DerivedU>& u, const Scalar& alpha) {
  typedef selfadjoint_product_selector<MatrixType, DerivedU, UpLo> UpdateSelector;
  UpdateSelector::run(_expression().const_cast_derived(), u.derived(), alpha);
  return *this;
}
} // end namespace Eigen
#endif // EIGEN_SELFADJOINT_PRODUCT_H