Update Eigen to commit: b02c384ef4e8eba7b8bdef16f9dc6f8f4d6a6b2b
CHANGELOG
=========
b02c384ef - Add fused multiply functions for PowerPC - pmsub, pnmadd and pnmsub
3de96caea - Fix HouseholderSequence.h
f845a8bb1 - Fix cwise NaN propagation for scalar input.
a4bb513b9 - Update HouseholderSequence.h
fc1d88841 - Remove AVX512VL dependency in trsm
07db964bd - Restrict new AVX512 trsm to AVX512VL, rename files for consistency.
67eeba6e7 - Avoidable heap allocation in applyHouseholderToTheLeft
3342fc7e4 - Allow all tests to pass with `EIGEN_TEST_NO_EXPLICIT_VECTORIZATION`
efb08e0bb - Revert "Fix ambiguous DiagonalMatrix constructors."
53eec53d2 - Fix Power GEMV order of operations in predux for MMA.
a81bba962 - Fix ambiguous DiagonalMatrix constructors.
f7b31f864 - Revert "Replace call to FixedDimensions() with a singleton instance of"
f3ba220c5 - Remove EIGEN_EMPTY_STRUCT_CTOR
5ed7a86ae - Fix MSVC+CUDA issues.
734ed1efa - Fix ODR issues in lapacke_helpers.
2c45a3846 - Fix some max size expressions.
edc822666 - Fix navbar scroll with toc.
df87d40e3 - constexpr reshape helper
403fa3340 - Performance improvements in GEMM for Power
e1df3636b - More constexpr helpers
64909b82b - static const class members turned into constexpr
2c0ef43b4 - Added Scaling function overload for vector rvalue reference
ba2cb835a - Add back std::remove* aliases - third-party libraries rely on these.
0c859cf35 - Consider inf/nan in scalar test_isApprox.
1ddd3e29c - fixed order of arguments in blas syrk
2c5644280 - Don't include .cpp in lapack.
73b2c13bf - Disable f16c scalar conversions for MSVC.
9bc9992dd - Eliminate trace unused warning.
e22d58e81 - Add is_constant_evaluated, update alignment checks
f0a91838a - Enable Aarch64 CI
b9d2900e8 - added a missing typename and fixed an unused typedef warning
0611f7fff - Add missing explicit reinterprets
cd3c81c3b - Add a NNLS solver to unsupported - issue #655
0699fa06f - Split general_matrix_vector_product interface for Power into two macros - one ColMajor and RowMajor.
19a6a827c - Optimize visitor traversal in case of RowMajor.
f2a3e03e9 - Fix usages of wrong namespace
4451823fb - Fix ODR violation in trsm.
9a14d91a9 - Fix AVX512 builds with MSVC.
7b10795e3 - Change EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH and EIGEN_ALTIVEC_DISABLE_MMA flags to be like TensorFlow's...
3ca1228d4 - Work around MSVC compiler bug dropping `const`.
40eb34bc5 - Fix RowMajorBit <-> RowMajor mixup.
c06298346 - Completed a missing parenthesis in tutorial.
9deaa1912 - Work around g++-10 docker issue for geo_orthomethods_4.
e34db1239 - Fix missing pound
591906477 - Fix up PowerPC MMA flags so it builds by default.
518fc321c - AVX512 Optimizations for Triangular Solve
01b5bc48c - Disable schur non-convergence test.
421cbf086 - Replace Eigen type metaprogramming with corresponding std types and make use of alias templates
514f90c9f - Remove workarounds for bad GCC-4 warnings
9ad566148 - Revert "Fix up PowerPC MMA flags so it builds by default."
65eeedf96 - Fix up PowerPC MMA flags so it builds by default.
cb1e8228e - Convert bit calculation to constexpr, avoid casts.
baf9a985e - Fix swap test for size 1 inputs.
788240885 - Temporarily disable aarch64 CI.
2a6be5492 - Fix construct_at compilation breakage on ROCm.
a3b64625e - Remove ComputeCpp-specific code from SYCL Vptr
9296bb4b9 - Fix edge-case in zeta for large inputs.
cd2ba9d03 - Add construct_at, destroy_at wrappers. Use throughout.
dfa517678 - make SparseSolverBase and IterativeSolverBase move constructable
9883108f3 - Remove copy_bool workaround for gcc 4.3
3a9d404d7 - Add support for Apple's Accelerate sparse matrix solvers
PiperOrigin-RevId: 445164768
Change-Id: I9e4ce57b25b58eeee180e795f774129ae409141e
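
Note for reviewers: most hunks in this diff follow one mechanical pattern: replacing Eigen's internal metaprogramming helpers (internal::conditional<...>::type, internal::enable_if<...>::type, internal::remove_reference<...>::type, internal::remove_all<...>::type) with the corresponding std alias templates or new internal aliases (std::conditional_t, std::enable_if_t, std::remove_reference_t, internal::remove_all_t), which drops the leading `typename` and the trailing `::type`. The snippet below is a minimal, self-contained sketch of that old-vs-new spelling; it is illustrative only, uses no Eigen code, and is not part of the patch.

    // Sketch only (not part of this patch): old nested-::type spelling vs.
    // the C++14 alias-template spelling that this update adopts throughout.
    #include <type_traits>
    #include <iostream>

    // Old style: class templates expose a nested ::type and need `typename`.
    template <bool UseWide>
    struct old_style {
      typedef typename std::conditional<UseWide, double, float>::type Scalar;
    };

    // New style: alias templates, no `typename`, no `::type`.
    template <bool UseWide>
    struct new_style {
      typedef std::conditional_t<UseWide, double, float> Scalar;
    };

    // enable_if-based overload selection, old vs. new spelling.
    template <typename T>
    typename std::enable_if<std::is_integral<T>::value, T>::type
    twice_old(T x) { return T(2) * x; }

    template <typename T>
    std::enable_if_t<std::is_floating_point<T>::value, T>
    twice_new(T x) { return T(2) * x; }

    int main() {
      static_assert(std::is_same<old_style<true>::Scalar,
                                 new_style<true>::Scalar>::value,
                    "both spellings name the same type");
      std::cout << twice_old(21) << " " << twice_new(1.5) << "\n";  // prints: 42 3
      return 0;
    }
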
diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h
index 112ca98..81005c5 100644
--- a/Eigen/src/Core/ArithmeticSequence.h
+++ b/Eigen/src/Core/ArithmeticSequence.h
@@ -200,7 +200,7 @@
// Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
template<typename T>
struct make_size_type {
- typedef typename internal::conditional<symbolic::is_symbolic<T>::value, Index, T>::type type;
+ typedef std::conditional_t<symbolic::is_symbolic<T>::value, Index, T> type;
};
template<typename FirstType,typename SizeType,typename IncrType,int XprSize>
diff --git a/Eigen/src/Core/Array.h b/Eigen/src/Core/Array.h
index b652852..7be8971 100644
--- a/Eigen/src/Core/Array.h
+++ b/Eigen/src/Core/Array.h
@@ -274,8 +274,8 @@
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Array(const EigenBase<OtherDerived> &other,
- typename internal::enable_if<internal::is_convertible<typename OtherDerived::Scalar,Scalar>::value,
- PrivateType>::type = PrivateType())
+ std::enable_if_t<internal::is_convertible<typename OtherDerived::Scalar,Scalar>::value,
+ PrivateType> = PrivateType())
: Base(other.derived())
{ }
diff --git a/Eigen/src/Core/ArrayWrapper.h b/Eigen/src/Core/ArrayWrapper.h
index 8a34ca5..e65b8fb 100644
--- a/Eigen/src/Core/ArrayWrapper.h
+++ b/Eigen/src/Core/ArrayWrapper.h
@@ -28,12 +28,12 @@
namespace internal {
template<typename ExpressionType>
struct traits<ArrayWrapper<ExpressionType> >
- : public traits<typename remove_all<typename ExpressionType::Nested>::type >
+ : public traits<remove_all_t<typename ExpressionType::Nested> >
{
typedef ArrayXpr XprKind;
// Let's remove NestByRefBit
enum {
- Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
+ Flags0 = traits<remove_all_t<typename ExpressionType::Nested> >::Flags,
LvalueBitFlag = is_lvalue<ExpressionType>::value ? LvalueBit : 0,
Flags = (Flags0 & ~(NestByRefBit | LvalueBit)) | LvalueBitFlag
};
@@ -47,13 +47,13 @@
typedef ArrayBase<ArrayWrapper> Base;
EIGEN_DENSE_PUBLIC_INTERFACE(ArrayWrapper)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ArrayWrapper)
- typedef typename internal::remove_all<ExpressionType>::type NestedExpression;
+ typedef internal::remove_all_t<ExpressionType> NestedExpression;
- typedef typename internal::conditional<
+ typedef std::conditional_t<
internal::is_lvalue<ExpressionType>::value,
Scalar,
const Scalar
- >::type ScalarWithConstIfNotLvalue;
+ > ScalarWithConstIfNotLvalue;
typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType;
@@ -93,7 +93,7 @@
inline void evalTo(Dest& dst) const { dst = m_expression; }
EIGEN_DEVICE_FUNC
- const typename internal::remove_all<NestedExpressionType>::type&
+ const internal::remove_all_t<NestedExpressionType>&
nestedExpression() const
{
return m_expression;
@@ -126,12 +126,12 @@
namespace internal {
template<typename ExpressionType>
struct traits<MatrixWrapper<ExpressionType> >
- : public traits<typename remove_all<typename ExpressionType::Nested>::type >
+ : public traits<remove_all_t<typename ExpressionType::Nested> >
{
typedef MatrixXpr XprKind;
// Let's remove NestByRefBit
enum {
- Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
+ Flags0 = traits<remove_all_t<typename ExpressionType::Nested> >::Flags,
LvalueBitFlag = is_lvalue<ExpressionType>::value ? LvalueBit : 0,
Flags = (Flags0 & ~(NestByRefBit | LvalueBit)) | LvalueBitFlag
};
@@ -145,13 +145,13 @@
typedef MatrixBase<MatrixWrapper<ExpressionType> > Base;
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixWrapper)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(MatrixWrapper)
- typedef typename internal::remove_all<ExpressionType>::type NestedExpression;
+ typedef internal::remove_all_t<ExpressionType> NestedExpression;
- typedef typename internal::conditional<
- internal::is_lvalue<ExpressionType>::value,
- Scalar,
- const Scalar
- >::type ScalarWithConstIfNotLvalue;
+ typedef std::conditional_t<
+ internal::is_lvalue<ExpressionType>::value,
+ Scalar,
+ const Scalar
+ > ScalarWithConstIfNotLvalue;
typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType;
@@ -187,7 +187,7 @@
}
EIGEN_DEVICE_FUNC
- const typename internal::remove_all<NestedExpressionType>::type&
+ const internal::remove_all_t<NestedExpressionType>&
nestedExpression() const
{
return m_expression;
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h
index 2c00387..f9dc7a1 100644
--- a/Eigen/src/Core/AssignEvaluator.h
+++ b/Eigen/src/Core/AssignEvaluator.h
@@ -113,7 +113,7 @@
|| int(Traversal) == SliceVectorizedTraversal
};
- typedef typename conditional<int(Traversal)==LinearVectorizedTraversal, LinearPacketType, InnerPacketType>::type PacketType;
+ typedef std::conditional_t<int(Traversal)==LinearVectorizedTraversal, LinearPacketType, InnerPacketType> PacketType;
private:
enum {
@@ -846,7 +846,7 @@
// Deal with "assume-aliasing"
template<typename Dst, typename Src, typename Func>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-void call_assignment(Dst& dst, const Src& src, const Func& func, typename enable_if< evaluator_assume_aliasing<Src>::value, void*>::type = 0)
+void call_assignment(Dst& dst, const Src& src, const Func& func, std::enable_if_t< evaluator_assume_aliasing<Src>::value, void*> = 0)
{
typename plain_matrix_type<Src>::type tmp(src);
call_assignment_no_alias(dst, tmp, func);
@@ -854,7 +854,7 @@
template<typename Dst, typename Src, typename Func>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-void call_assignment(Dst& dst, const Src& src, const Func& func, typename enable_if<!evaluator_assume_aliasing<Src>::value, void*>::type = 0)
+void call_assignment(Dst& dst, const Src& src, const Func& func, std::enable_if_t<!evaluator_assume_aliasing<Src>::value, void*> = 0)
{
call_assignment_no_alias(dst, src, func);
}
@@ -879,8 +879,8 @@
) && int(Dst::SizeAtCompileTime) != 1
};
- typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst>::type ActualDstTypeCleaned;
- typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst&>::type ActualDstType;
+ typedef std::conditional_t<NeedToTranspose, Transpose<Dst>, Dst> ActualDstTypeCleaned;
+ typedef std::conditional_t<NeedToTranspose, Transpose<Dst>, Dst&> ActualDstType;
ActualDstType actualDst(dst);
// TODO check whether this is the right place to perform these checks:
diff --git a/Eigen/src/Core/Assign_MKL.h b/Eigen/src/Core/Assign_MKL.h
index c2afebd..f9b86c8 100755
--- a/Eigen/src/Core/Assign_MKL.h
+++ b/Eigen/src/Core/Assign_MKL.h
@@ -84,7 +84,7 @@
#define EIGEN_MKL_VML_DECLARE_UNARY_CALL(EIGENOP, VMLOP, EIGENTYPE, VMLTYPE, VMLMODE) \
template< typename DstXprType, typename SrcXprNested> \
struct Assignment<DstXprType, CwiseUnaryOp<scalar_##EIGENOP##_op<EIGENTYPE>, SrcXprNested>, assign_op<EIGENTYPE,EIGENTYPE>, \
- Dense2Dense, typename enable_if<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>::type> { \
+ Dense2Dense, std::enable_if_t<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>> { \
typedef CwiseUnaryOp<scalar_##EIGENOP##_op<EIGENTYPE>, SrcXprNested> SrcXprType; \
static void run(DstXprType &dst, const SrcXprType &src, const assign_op<EIGENTYPE,EIGENTYPE> &func) { \
resize_if_allowed(dst, src, func); \
@@ -144,7 +144,7 @@
template< typename DstXprType, typename SrcXprNested, typename Plain> \
struct Assignment<DstXprType, CwiseBinaryOp<scalar_##EIGENOP##_op<EIGENTYPE,EIGENTYPE>, SrcXprNested, \
const CwiseNullaryOp<internal::scalar_constant_op<EIGENTYPE>,Plain> >, assign_op<EIGENTYPE,EIGENTYPE>, \
- Dense2Dense, typename enable_if<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>::type> { \
+ Dense2Dense, std::enable_if_t<vml_assign_traits<DstXprType,SrcXprNested>::EnableVml>> { \
typedef CwiseBinaryOp<scalar_##EIGENOP##_op<EIGENTYPE,EIGENTYPE>, SrcXprNested, \
const CwiseNullaryOp<internal::scalar_constant_op<EIGENTYPE>,Plain> > SrcXprType; \
static void run(DstXprType &dst, const SrcXprType &src, const assign_op<EIGENTYPE,EIGENTYPE> &func) { \
diff --git a/Eigen/src/Core/BandMatrix.h b/Eigen/src/Core/BandMatrix.h
index 913a967..c2d943c 100644
--- a/Eigen/src/Core/BandMatrix.h
+++ b/Eigen/src/Core/BandMatrix.h
@@ -102,9 +102,9 @@
: min_size_prefer_dynamic(RowsAtCompileTime, ColsAtCompileTime - ActualIndex))
};
typedef Block<CoefficientsType,1, DiagonalSize> BuildType;
- typedef typename internal::conditional<Conjugate,
+ typedef std::conditional_t<Conjugate,
CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>,BuildType >,
- BuildType>::type Type;
+ BuildType> Type;
};
/** \returns a vector expression of the \a N -th sub or super diagonal */
diff --git a/Eigen/src/Core/Block.h b/Eigen/src/Core/Block.h
index afbea86..19c4b68 100644
--- a/Eigen/src/Core/Block.h
+++ b/Eigen/src/Core/Block.h
@@ -23,7 +23,7 @@
typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::XprKind XprKind;
typedef typename ref_selector<XprType>::type XprTypeNested;
- typedef typename remove_reference<XprTypeNested>::type XprTypeNested_;
+ typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
enum{
MatrixRows = traits<XprType>::RowsAtCompileTime,
MatrixCols = traits<XprType>::ColsAtCompileTime,
@@ -112,7 +112,7 @@
EIGEN_GENERIC_PUBLIC_INTERFACE(Block)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Block)
- typedef typename internal::remove_all<XprType>::type NestedExpression;
+ typedef internal::remove_all_t<XprType> NestedExpression;
/** Column or Row constructor
*/
@@ -297,7 +297,7 @@
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
+ const internal::remove_all_t<XprTypeNested>& nestedExpression() const
{
return m_xpr;
}
@@ -380,7 +380,7 @@
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const EIGEN_NOEXCEPT
+ const internal::remove_all_t<XprTypeNested>& nestedExpression() const EIGEN_NOEXCEPT
{
return m_xpr;
}
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index ac8c8e5..1729507 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -500,7 +500,7 @@
: evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> >
{
typedef CwiseNullaryOp<NullaryOp,PlainObjectType> XprType;
- typedef typename internal::remove_all<PlainObjectType>::type PlainObjectTypeCleaned;
+ typedef internal::remove_all_t<PlainObjectType> PlainObjectTypeCleaned;
enum {
CoeffReadCost = internal::functor_traits<NullaryOp>::Cost,
@@ -1225,8 +1225,8 @@
explicit block_evaluator(const XprType& block)
: mapbase_evaluator<XprType, typename XprType::PlainObject>(block)
{
- // TODO: for the 3.3 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
- eigen_assert(((internal::UIntPtr(block.data()) % plain_enum_max(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
+ eigen_internal_assert((internal::is_constant_evaluated() || (internal::UIntPtr(block.data()) % plain_enum_max(1,evaluator<XprType>::Alignment)) == 0) \
+ && "data is not aligned");
}
};
@@ -1298,7 +1298,7 @@
Factor = (RowFactor==Dynamic || ColFactor==Dynamic) ? Dynamic : RowFactor*ColFactor
};
typedef typename internal::nested_eval<ArgType,Factor>::type ArgTypeNested;
- typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
+ typedef internal::remove_all_t<ArgTypeNested> ArgTypeNestedCleaned;
enum {
CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
@@ -1382,7 +1382,7 @@
struct evaluator_wrapper_base
: evaluator_base<XprType>
{
- typedef typename remove_all<typename XprType::NestedExpressionType>::type ArgType;
+ typedef remove_all_t<typename XprType::NestedExpressionType> ArgType;
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
Flags = evaluator<ArgType>::Flags,
@@ -1723,14 +1723,14 @@
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr)
: m_result(xpr.arg())
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
}
// This constructor is used when nesting an EvalTo evaluator in another evaluator
EIGEN_DEVICE_FUNC evaluator(const ArgType& arg)
: m_result(arg)
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
}
protected:
diff --git a/Eigen/src/Core/CwiseBinaryOp.h b/Eigen/src/Core/CwiseBinaryOp.h
index ea491c6..21a061a 100644
--- a/Eigen/src/Core/CwiseBinaryOp.h
+++ b/Eigen/src/Core/CwiseBinaryOp.h
@@ -21,7 +21,7 @@
{
// we must not inherit from traits<Lhs> since it has
// the potential to cause problems with MSVC
- typedef typename remove_all<Lhs>::type Ancestor;
+ typedef remove_all_t<Lhs> Ancestor;
typedef typename traits<Ancestor>::XprKind XprKind;
enum {
RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
@@ -45,8 +45,8 @@
typename traits<Rhs>::StorageIndex>::type StorageIndex;
typedef typename Lhs::Nested LhsNested;
typedef typename Rhs::Nested RhsNested;
- typedef typename remove_reference<LhsNested>::type LhsNested_;
- typedef typename remove_reference<RhsNested>::type RhsNested_;
+ typedef std::remove_reference_t<LhsNested> LhsNested_;
+ typedef std::remove_reference_t<RhsNested> RhsNested_;
enum {
Flags = cwise_promote_storage_order<typename traits<Lhs>::StorageKind,typename traits<Rhs>::StorageKind,LhsNested_::Flags & RowMajorBit,RhsNested_::Flags & RowMajorBit>::value
};
@@ -86,9 +86,9 @@
{
public:
- typedef typename internal::remove_all<BinaryOp>::type Functor;
- typedef typename internal::remove_all<LhsType>::type Lhs;
- typedef typename internal::remove_all<RhsType>::type Rhs;
+ typedef internal::remove_all_t<BinaryOp> Functor;
+ typedef internal::remove_all_t<LhsType> Lhs;
+ typedef internal::remove_all_t<RhsType> Rhs;
typedef typename CwiseBinaryOpImpl<
BinaryOp, LhsType, RhsType,
@@ -102,8 +102,8 @@
typedef typename internal::ref_selector<LhsType>::type LhsNested;
typedef typename internal::ref_selector<RhsType>::type RhsNested;
- typedef typename internal::remove_reference<LhsNested>::type LhsNested_;
- typedef typename internal::remove_reference<RhsNested>::type RhsNested_;
+ typedef std::remove_reference_t<LhsNested> LhsNested_;
+ typedef std::remove_reference_t<RhsNested> RhsNested_;
#if EIGEN_COMP_MSVC
//Required for Visual Studio or the Copy constructor will probably not get inlined!
@@ -121,12 +121,12 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
Index rows() const EIGEN_NOEXCEPT {
// return the fixed size type if available to enable compile time optimizations
- return internal::traits<typename internal::remove_all<LhsNested>::type>::RowsAtCompileTime==Dynamic ? m_rhs.rows() : m_lhs.rows();
+ return internal::traits<internal::remove_all_t<LhsNested>>::RowsAtCompileTime==Dynamic ? m_rhs.rows() : m_lhs.rows();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
Index cols() const EIGEN_NOEXCEPT {
// return the fixed size type if available to enable compile time optimizations
- return internal::traits<typename internal::remove_all<LhsNested>::type>::ColsAtCompileTime==Dynamic ? m_rhs.cols() : m_lhs.cols();
+ return internal::traits<internal::remove_all_t<LhsNested>>::ColsAtCompileTime==Dynamic ? m_rhs.cols() : m_lhs.cols();
}
/** \returns the left hand side nested expression */
diff --git a/Eigen/src/Core/CwiseTernaryOp.h b/Eigen/src/Core/CwiseTernaryOp.h
index 393279bc..8d24a48 100644
--- a/Eigen/src/Core/CwiseTernaryOp.h
+++ b/Eigen/src/Core/CwiseTernaryOp.h
@@ -21,7 +21,7 @@
struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
// we must not inherit from traits<Arg1> since it has
// the potential to cause problems with MSVC
- typedef typename remove_all<Arg1>::type Ancestor;
+ typedef remove_all_t<Arg1> Ancestor;
typedef typename traits<Ancestor>::XprKind XprKind;
enum {
RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
@@ -43,9 +43,9 @@
typedef typename Arg1::Nested Arg1Nested;
typedef typename Arg2::Nested Arg2Nested;
typedef typename Arg3::Nested Arg3Nested;
- typedef typename remove_reference<Arg1Nested>::type Arg1Nested_;
- typedef typename remove_reference<Arg2Nested>::type Arg2Nested_;
- typedef typename remove_reference<Arg3Nested>::type Arg3Nested_;
+ typedef std::remove_reference_t<Arg1Nested> Arg1Nested_;
+ typedef std::remove_reference_t<Arg2Nested> Arg2Nested_;
+ typedef std::remove_reference_t<Arg3Nested> Arg3Nested_;
enum { Flags = Arg1Nested_::Flags & RowMajorBit };
};
} // end namespace internal
@@ -89,9 +89,9 @@
internal::no_assignment_operator
{
public:
- typedef typename internal::remove_all<Arg1Type>::type Arg1;
- typedef typename internal::remove_all<Arg2Type>::type Arg2;
- typedef typename internal::remove_all<Arg3Type>::type Arg3;
+ typedef internal::remove_all_t<Arg1Type> Arg1;
+ typedef internal::remove_all_t<Arg2Type> Arg2;
+ typedef internal::remove_all_t<Arg3Type> Arg3;
// require the sizes to match
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
@@ -115,9 +115,9 @@
typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
- typedef typename internal::remove_reference<Arg1Nested>::type Arg1Nested_;
- typedef typename internal::remove_reference<Arg2Nested>::type Arg2Nested_;
- typedef typename internal::remove_reference<Arg3Nested>::type Arg3Nested_;
+ typedef std::remove_reference_t<Arg1Nested> Arg1Nested_;
+ typedef std::remove_reference_t<Arg2Nested> Arg2Nested_;
+ typedef std::remove_reference_t<Arg3Nested> Arg3Nested_;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
@@ -132,14 +132,14 @@
EIGEN_STRONG_INLINE Index rows() const {
// return the fixed size type if available to enable compile time
// optimizations
- if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ if (internal::traits<internal::remove_all_t<Arg1Nested>>::
RowsAtCompileTime == Dynamic &&
- internal::traits<typename internal::remove_all<Arg2Nested>::type>::
+ internal::traits<internal::remove_all_t<Arg2Nested>>::
RowsAtCompileTime == Dynamic)
return m_arg3.rows();
- else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ else if (internal::traits<internal::remove_all_t<Arg1Nested>>::
RowsAtCompileTime == Dynamic &&
- internal::traits<typename internal::remove_all<Arg3Nested>::type>::
+ internal::traits<internal::remove_all_t<Arg3Nested>>::
RowsAtCompileTime == Dynamic)
return m_arg2.rows();
else
@@ -149,14 +149,14 @@
EIGEN_STRONG_INLINE Index cols() const {
// return the fixed size type if available to enable compile time
// optimizations
- if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ if (internal::traits<internal::remove_all_t<Arg1Nested>>::
ColsAtCompileTime == Dynamic &&
- internal::traits<typename internal::remove_all<Arg2Nested>::type>::
+ internal::traits<internal::remove_all_t<Arg2Nested>>::
ColsAtCompileTime == Dynamic)
return m_arg3.cols();
- else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ else if (internal::traits<internal::remove_all_t<Arg1Nested>>::
ColsAtCompileTime == Dynamic &&
- internal::traits<typename internal::remove_all<Arg3Nested>::type>::
+ internal::traits<internal::remove_all_t<Arg3Nested>>::
ColsAtCompileTime == Dynamic)
return m_arg2.cols();
else
diff --git a/Eigen/src/Core/CwiseUnaryOp.h b/Eigen/src/Core/CwiseUnaryOp.h
index d9985c0..ff7d0b9 100644
--- a/Eigen/src/Core/CwiseUnaryOp.h
+++ b/Eigen/src/Core/CwiseUnaryOp.h
@@ -24,7 +24,7 @@
UnaryOp(const typename XprType::Scalar&)
>::type Scalar;
typedef typename XprType::Nested XprTypeNested;
- typedef typename remove_reference<XprTypeNested>::type XprTypeNested_;
+ typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
enum {
Flags = XprTypeNested_::Flags & RowMajorBit
};
@@ -61,7 +61,7 @@
typedef typename CwiseUnaryOpImpl<UnaryOp, XprType,typename internal::traits<XprType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryOp)
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
- typedef typename internal::remove_all<XprType>::type NestedExpression;
+ typedef internal::remove_all_t<XprType> NestedExpression;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
explicit CwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
@@ -78,12 +78,12 @@
/** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const typename internal::remove_all<XprTypeNested>::type&
+ const internal::remove_all_t<XprTypeNested>&
nestedExpression() const { return m_xpr; }
/** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- typename internal::remove_all<XprTypeNested>::type&
+ internal::remove_all_t<XprTypeNested>&
nestedExpression() { return m_xpr; }
protected:
diff --git a/Eigen/src/Core/CwiseUnaryView.h b/Eigen/src/Core/CwiseUnaryView.h
index fabb3f8..b4539a6 100644
--- a/Eigen/src/Core/CwiseUnaryView.h
+++ b/Eigen/src/Core/CwiseUnaryView.h
@@ -23,7 +23,7 @@
ViewOp(const typename traits<MatrixType>::Scalar&)
>::type Scalar;
typedef typename MatrixType::Nested MatrixTypeNested;
- typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNested_;
+ typedef remove_all_t<MatrixTypeNested> MatrixTypeNested_;
enum {
FlagsLvalueBit = is_lvalue<MatrixType>::value ? LvalueBit : 0,
Flags = traits<MatrixTypeNested_>::Flags & (RowMajorBit | FlagsLvalueBit | DirectAccessBit), // FIXME DirectAccessBit should not be handled by expressions
@@ -69,7 +69,7 @@
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
- typedef typename internal::remove_all<MatrixType>::type NestedExpression;
+ typedef internal::remove_all_t<MatrixType> NestedExpression;
explicit EIGEN_DEVICE_FUNC inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp())
: m_matrix(mat), m_functor(func) {}
@@ -85,11 +85,11 @@
EIGEN_DEVICE_FUNC const ViewOp& functor() const { return m_functor; }
/** \returns the nested expression */
- EIGEN_DEVICE_FUNC const typename internal::remove_all<MatrixTypeNested>::type&
+ EIGEN_DEVICE_FUNC const internal::remove_all_t<MatrixTypeNested>&
nestedExpression() const { return m_matrix; }
/** \returns the nested expression */
- EIGEN_DEVICE_FUNC typename internal::remove_reference<MatrixTypeNested>::type&
+ EIGEN_DEVICE_FUNC std::remove_reference_t<MatrixTypeNested>&
nestedExpression() { return m_matrix; }
protected:
diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h
index d62c851..6e17779 100644
--- a/Eigen/src/Core/DenseBase.h
+++ b/Eigen/src/Core/DenseBase.h
@@ -105,8 +105,7 @@
* \sa MatrixBase::rows(), MatrixBase::cols(), RowsAtCompileTime, SizeAtCompileTime */
- SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
- internal::traits<Derived>::ColsAtCompileTime>::ret),
+ SizeAtCompileTime = (internal::size_of_xpr_at_compile_time<Derived>::ret),
/**< This is equal to the number of coefficients, i.e. the number of
* rows times the number of columns, or to \a Dynamic if this is not
* known at compile-time. \sa RowsAtCompileTime, ColsAtCompileTime */
@@ -133,8 +132,8 @@
* \sa ColsAtCompileTime, MaxRowsAtCompileTime, MaxSizeAtCompileTime
*/
- MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
- internal::traits<Derived>::MaxColsAtCompileTime>::ret),
+ MaxSizeAtCompileTime = internal::size_at_compile_time(internal::traits<Derived>::MaxRowsAtCompileTime,
+ internal::traits<Derived>::MaxColsAtCompileTime),
/**< This value is equal to the maximum possible number of coefficients that this expression
* might have. If this expression might have an arbitrarily high number of coefficients,
* this value is set to \a Dynamic.
@@ -201,8 +200,8 @@
* the return type of eval() is a const reference to a matrix, not a matrix! It is however guaranteed
* that the return type of eval() is either PlainObject or const PlainObject&.
*/
- typedef typename internal::conditional<internal::is_same<typename internal::traits<Derived>::XprKind,MatrixXpr >::value,
- PlainMatrix, PlainArray>::type PlainObject;
+ typedef std::conditional_t<internal::is_same<typename internal::traits<Derived>::XprKind,MatrixXpr >::value,
+ PlainMatrix, PlainArray> PlainObject;
/** \returns the outer size.
*
@@ -314,9 +313,9 @@
typedef Transpose<Derived> TransposeReturnType;
EIGEN_DEVICE_FUNC
TransposeReturnType transpose();
- typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
+ typedef Transpose<const Derived> ConstTransposeReturnType;
EIGEN_DEVICE_FUNC
- ConstTransposeReturnType transpose() const;
+ const ConstTransposeReturnType transpose() const;
EIGEN_DEVICE_FUNC
void transposeInPlace();
@@ -385,7 +384,7 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Derived& operator/=(const Scalar& other);
- typedef typename internal::add_const_on_value_type<typename internal::eval<Derived>::type>::type EvalReturnType;
+ typedef internal::add_const_on_value_type_t<typename internal::eval<Derived>::type> EvalReturnType;
/** \returns the matrix or vector obtained by evaluating this expression.
*
* Notice that in the case of a plain matrix or vector (not an expression) this function just returns
@@ -429,9 +428,9 @@
EIGEN_DEVICE_FUNC inline const ForceAlignedAccess<Derived> forceAlignedAccess() const;
EIGEN_DEVICE_FUNC inline ForceAlignedAccess<Derived> forceAlignedAccess();
template<bool Enable> EIGEN_DEVICE_FUNC
- inline const typename internal::conditional<Enable,ForceAlignedAccess<Derived>,Derived&>::type forceAlignedAccessIf() const;
+ inline const std::conditional_t<Enable,ForceAlignedAccess<Derived>,Derived&> forceAlignedAccessIf() const;
template<bool Enable> EIGEN_DEVICE_FUNC
- inline typename internal::conditional<Enable,ForceAlignedAccess<Derived>,Derived&>::type forceAlignedAccessIf();
+ inline std::conditional_t<Enable,ForceAlignedAccess<Derived>,Derived&> forceAlignedAccessIf();
EIGEN_DEVICE_FUNC Scalar sum() const;
EIGEN_DEVICE_FUNC Scalar mean() const;
@@ -611,27 +610,21 @@
/** This is the const version of iterator (aka read-only) */
typedef random_access_iterator_type const_iterator;
#else
- typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
- internal::pointer_based_stl_iterator<Derived>,
- internal::generic_randaccess_stl_iterator<Derived>
- >::type iterator_type;
+ typedef std::conditional_t< (Flags&DirectAccessBit)==DirectAccessBit,
+ internal::pointer_based_stl_iterator<Derived>,
+ internal::generic_randaccess_stl_iterator<Derived>
+ > iterator_type;
- typedef typename internal::conditional< (Flags&DirectAccessBit)==DirectAccessBit,
- internal::pointer_based_stl_iterator<const Derived>,
- internal::generic_randaccess_stl_iterator<const Derived>
- >::type const_iterator_type;
+ typedef std::conditional_t< (Flags&DirectAccessBit)==DirectAccessBit,
+ internal::pointer_based_stl_iterator<const Derived>,
+ internal::generic_randaccess_stl_iterator<const Derived>
+ > const_iterator_type;
// Stl-style iterators are supported only for vectors.
- typedef typename internal::conditional< IsVectorAtCompileTime,
- iterator_type,
- void
- >::type iterator;
+ typedef std::conditional_t<IsVectorAtCompileTime, iterator_type, void> iterator;
- typedef typename internal::conditional< IsVectorAtCompileTime,
- const_iterator_type,
- void
- >::type const_iterator;
+ typedef std::conditional_t<IsVectorAtCompileTime, const_iterator_type, void> const_iterator;
#endif
inline iterator begin();
diff --git a/Eigen/src/Core/DenseCoeffsBase.h b/Eigen/src/Core/DenseCoeffsBase.h
index 46d8730..7f0bcf4 100644
--- a/Eigen/src/Core/DenseCoeffsBase.h
+++ b/Eigen/src/Core/DenseCoeffsBase.h
@@ -17,7 +17,7 @@
namespace internal {
template<typename T> struct add_const_on_value_type_if_arithmetic
{
- typedef typename conditional<is_arithmetic<T>::value, T, typename add_const_on_value_type<T>::type>::type type;
+ typedef std::conditional_t<is_arithmetic<T>::value, T, add_const_on_value_type_t<T>> type;
};
}
@@ -48,10 +48,10 @@
// - The is_arithmetic check is required since "const int", "const double", etc. will cause warnings on some systems
// while the declaration of "const T", where T is a non arithmetic type does not. Always returning "const Scalar&" is
// not possible, since the underlying expressions might not offer a valid address the reference could be referring to.
- typedef typename internal::conditional<bool(internal::traits<Derived>::Flags&LvalueBit),
- const Scalar&,
- typename internal::conditional<internal::is_arithmetic<Scalar>::value, Scalar, const Scalar>::type
- >::type CoeffReturnType;
+ typedef std::conditional_t<bool(internal::traits<Derived>::Flags&LvalueBit),
+ const Scalar&,
+ std::conditional_t<internal::is_arithmetic<Scalar>::value, Scalar, const Scalar>
+ > CoeffReturnType;
typedef typename internal::add_const_on_value_type_if_arithmetic<
typename internal::packet_traits<Scalar>::type
diff --git a/Eigen/src/Core/DenseStorage.h b/Eigen/src/Core/DenseStorage.h
index 371da3c..5e2763e 100644
--- a/Eigen/src/Core/DenseStorage.h
+++ b/Eigen/src/Core/DenseStorage.h
@@ -69,13 +69,14 @@
template<typename PtrType>
EIGEN_ALWAYS_INLINE PtrType eigen_unaligned_array_assert_workaround_gcc47(PtrType array) { return array; }
#define EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(sizemask) \
- eigen_assert((internal::UIntPtr(eigen_unaligned_array_assert_workaround_gcc47(array)) & (sizemask)) == 0 \
+ eigen_assert((internal::is_constant_evaluated() \
+ || (internal::UIntPtr(eigen_unaligned_array_assert_workaround_gcc47(array)) & (sizemask)) == 0) \
&& "this assertion is explained here: " \
"http://eigen.tuxfamily.org/dox-devel/group__TopicUnalignedArrayAssert.html" \
" **** READ THIS WEB PAGE !!! ****");
#else
#define EIGEN_MAKE_UNALIGNED_ARRAY_ASSERT(sizemask) \
- eigen_assert((internal::UIntPtr(array) & (sizemask)) == 0 \
+ eigen_assert((internal::is_constant_evaluated() || (internal::UIntPtr(array) & (sizemask)) == 0) \
&& "this assertion is explained here: " \
"http://eigen.tuxfamily.org/dox-devel/group__TopicUnalignedArrayAssert.html" \
" **** READ THIS WEB PAGE !!! ****");
diff --git a/Eigen/src/Core/Diagonal.h b/Eigen/src/Core/Diagonal.h
index 2bf4527..4af17dd 100644
--- a/Eigen/src/Core/Diagonal.h
+++ b/Eigen/src/Core/Diagonal.h
@@ -40,7 +40,7 @@
: traits<MatrixType>
{
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
- typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNested_;
+ typedef std::remove_reference_t<MatrixTypeNested> MatrixTypeNested_;
typedef typename MatrixType::StorageKind StorageKind;
enum {
RowsAtCompileTime = (int(DiagIndex) == DynamicIndex || int(MatrixType::SizeAtCompileTime) == Dynamic) ? Dynamic
@@ -97,11 +97,11 @@
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
inline Index outerStride() const EIGEN_NOEXCEPT { return 0; }
- typedef typename internal::conditional<
- internal::is_lvalue<MatrixType>::value,
- Scalar,
- const Scalar
- >::type ScalarWithConstIfNotLvalue;
+ typedef std::conditional_t<
+ internal::is_lvalue<MatrixType>::value,
+ Scalar,
+ const Scalar
+ > ScalarWithConstIfNotLvalue;
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return &(m_matrix.coeffRef(rowOffset(), colOffset())); }
@@ -147,7 +147,7 @@
}
EIGEN_DEVICE_FUNC
- inline const typename internal::remove_all<typename MatrixType::Nested>::type&
+ inline const internal::remove_all_t<typename MatrixType::Nested>&
nestedExpression() const
{
return m_matrix;
@@ -193,7 +193,8 @@
/** This is the const version of diagonal(). */
template<typename Derived>
-EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::ConstDiagonalReturnType
+EIGEN_DEVICE_FUNC inline
+const typename MatrixBase<Derived>::ConstDiagonalReturnType
MatrixBase<Derived>::diagonal() const
{
return ConstDiagonalReturnType(derived());
@@ -211,18 +212,18 @@
*
* \sa MatrixBase::diagonal(), class Diagonal */
template<typename Derived>
-EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::DiagonalDynamicIndexReturnType
+EIGEN_DEVICE_FUNC inline Diagonal<Derived, DynamicIndex>
MatrixBase<Derived>::diagonal(Index index)
{
- return DiagonalDynamicIndexReturnType(derived(), index);
+ return Diagonal<Derived, DynamicIndex>(derived(), index);
}
/** This is the const version of diagonal(Index). */
template<typename Derived>
-EIGEN_DEVICE_FUNC inline typename MatrixBase<Derived>::ConstDiagonalDynamicIndexReturnType
+EIGEN_DEVICE_FUNC inline const Diagonal<const Derived, DynamicIndex>
MatrixBase<Derived>::diagonal(Index index) const
{
- return ConstDiagonalDynamicIndexReturnType(derived(), index);
+ return Diagonal<const Derived, DynamicIndex>(derived(), index);
}
/** \returns an expression of the \a DiagIndex-th sub or super diagonal of the matrix \c *this
@@ -239,20 +240,20 @@
template<typename Derived>
template<int Index_>
EIGEN_DEVICE_FUNC
-inline typename MatrixBase<Derived>::template DiagonalIndexReturnType<Index_>::Type
+inline Diagonal<Derived, Index_>
MatrixBase<Derived>::diagonal()
{
- return typename DiagonalIndexReturnType<Index_>::Type(derived());
+ return Diagonal<Derived, Index_>(derived());
}
/** This is the const version of diagonal<int>(). */
template<typename Derived>
template<int Index_>
EIGEN_DEVICE_FUNC
-inline typename MatrixBase<Derived>::template ConstDiagonalIndexReturnType<Index_>::Type
+inline const Diagonal<const Derived, Index_>
MatrixBase<Derived>::diagonal() const
{
- return typename ConstDiagonalIndexReturnType<Index_>::Type(derived());
+ return Diagonal<const Derived, Index_>(derived());
}
} // end namespace Eigen
diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h
index 45375d5..06cfdc1 100644
--- a/Eigen/src/Core/DiagonalMatrix.h
+++ b/Eigen/src/Core/DiagonalMatrix.h
@@ -200,6 +200,10 @@
explicit EIGEN_STRONG_INLINE DiagonalMatrix(const std::initializer_list<std::initializer_list<Scalar>>& list)
: m_diagonal(list) {}
+ /** \brief Constructs a DiagonalMatrix from an r-value diagonal vector type */
+ EIGEN_DEVICE_FUNC
+ explicit inline DiagonalMatrix(DiagonalVectorType&& diag) : m_diagonal(std::move(diag)) {}
+
/** Copy constructor. */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h
index a7c20f5..0c13192 100644
--- a/Eigen/src/Core/Dot.h
+++ b/Eigen/src/Core/Dot.h
@@ -20,14 +20,9 @@
// with mismatched types, the compiler emits errors about failing to instantiate cwiseProduct BEFORE
// looking at the static assertions. Thus this is a trick to get better compile errors.
template<typename T, typename U,
-// the NeedToTranspose condition here is taken straight from Assign.h
- bool NeedToTranspose = T::IsVectorAtCompileTime
- && U::IsVectorAtCompileTime
- && ((int(T::RowsAtCompileTime) == 1 && int(U::ColsAtCompileTime) == 1)
- | // FIXME | instead of || to please GCC 4.4.0 stupid warning "suggest parentheses around &&".
- // revert to || as soon as not needed anymore.
- (int(T::ColsAtCompileTime) == 1 && int(U::RowsAtCompileTime) == 1))
->
+ bool NeedToTranspose = T::IsVectorAtCompileTime && U::IsVectorAtCompileTime &&
+ ((int(T::RowsAtCompileTime) == 1 && int(U::ColsAtCompileTime) == 1) ||
+ (int(T::ColsAtCompileTime) == 1 && int(U::RowsAtCompileTime) == 1))>
struct dot_nocheck
{
typedef scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> conj_prod;
diff --git a/Eigen/src/Core/ForceAlignedAccess.h b/Eigen/src/Core/ForceAlignedAccess.h
index 7c46573..b00785e 100644
--- a/Eigen/src/Core/ForceAlignedAccess.h
+++ b/Eigen/src/Core/ForceAlignedAccess.h
@@ -130,7 +130,7 @@
*/
template<typename Derived>
template<bool Enable>
-inline typename internal::add_const_on_value_type<typename internal::conditional<Enable,ForceAlignedAccess<Derived>,Derived&>::type>::type
+inline add_const_on_value_type_t<std::conditional_t<Enable,ForceAlignedAccess<Derived>,Derived&>>
MatrixBase<Derived>::forceAlignedAccessIf() const
{
return derived(); // FIXME This should not work but apparently is never used
@@ -141,7 +141,7 @@
*/
template<typename Derived>
template<bool Enable>
-inline typename internal::conditional<Enable,ForceAlignedAccess<Derived>,Derived&>::type
+inline std::conditional_t<Enable,ForceAlignedAccess<Derived>,Derived&>
MatrixBase<Derived>::forceAlignedAccessIf()
{
return derived(); // FIXME This should not work but apparently is never used
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 783a3b6..661a3c4 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -52,8 +52,8 @@
template<typename Lhs, typename Rhs> struct product_type
{
- typedef typename remove_all<Lhs>::type Lhs_;
- typedef typename remove_all<Rhs>::type Rhs_;
+ typedef remove_all_t<Lhs> Lhs_;
+ typedef remove_all_t<Rhs> Rhs_;
enum {
MaxRows = traits<Lhs_>::MaxRowsAtCompileTime,
Rows = traits<Lhs_>::RowsAtCompileTime,
@@ -233,7 +233,7 @@
ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// make sure Dest is a compile-time vector type (bug 1166)
- typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
+ typedef std::conditional_t<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr> ActualDest;
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
@@ -316,10 +316,10 @@
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
- typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
- typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
+ std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
+ std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index f60724a..3ea6855 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -220,6 +220,15 @@
template<> EIGEN_DEVICE_FUNC inline bool
padd(const bool& a, const bool& b) { return a || b; }
+/** \internal \returns a packet version of \a *from, (un-aligned masked add)
+ * There is no generic implementation. We only have implementations for specialized
+ * cases. Generic case should not be called.
+ */
+template<typename Packet> EIGEN_DEVICE_FUNC inline
+std::enable_if_t<unpacket_traits<Packet>::masked_fpops_available, Packet>
+padd(const Packet& a, const Packet& b, typename unpacket_traits<Packet>::mask_t umask);
+
+
/** \internal \returns a - b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psub(const Packet& a, const Packet& b) { return a-b; }
@@ -262,7 +271,7 @@
// have another option, since the scalar type requires initialization.
template<typename T>
struct ptrue_impl<T,
- typename internal::enable_if<is_scalar<T>::value && NumTraits<T>::RequireInitialization>::type > {
+ std::enable_if_t<is_scalar<T>::value && NumTraits<T>::RequireInitialization> > {
static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/){
return T(1);
}
@@ -288,7 +297,7 @@
// for zero may not consist of all-zero bits.
template<typename T>
struct pzero_impl<T,
- typename internal::enable_if<is_scalar<T>::value>::type> {
+ std::enable_if_t<is_scalar<T>::value>> {
static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/) {
return T(0);
}
@@ -359,16 +368,16 @@
EIGEN_DEVICE_FUNC static inline T bitwise_and(const T& a, const T& b) {
return binary(a, b, bit_and<unsigned char>());
}
- EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) {
+ EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) {
return binary(a, b, bit_or<unsigned char>());
}
EIGEN_DEVICE_FUNC static inline T bitwise_xor(const T& a, const T& b) {
return binary(a, b, bit_xor<unsigned char>());
}
- EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) {
+ EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) {
return unary(a,bit_not<unsigned char>());
}
-
+
private:
template<typename Op>
EIGEN_DEVICE_FUNC static inline T unary(const T& a, Op op) {
@@ -401,8 +410,8 @@
// For integers or non-trivial scalars, use binary operators.
template<typename T>
struct bitwise_helper<T,
- typename internal::enable_if<
- is_scalar<T>::value && (NumTraits<T>::IsInteger || NumTraits<T>::RequireInitialization)>::type
+ typename std::enable_if_t<
+ is_scalar<T>::value && (NumTraits<T>::IsInteger || NumTraits<T>::RequireInitialization)>
> : public operator_bitwise_helper<T> {};
/** \internal \returns the bitwise and of \a a and \a b */
@@ -444,7 +453,7 @@
// For scalars, use ternary select.
template<typename Packet>
struct pselect_impl<Packet,
- typename internal::enable_if<is_scalar<Packet>::value>::type > {
+ std::enable_if_t<is_scalar<Packet>::value> > {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) {
return numext::equal_strict(mask, Packet(0)) ? b : a;
}
@@ -610,7 +619,7 @@
* cases. Generic case should not be called.
*/
template<typename Packet> EIGEN_DEVICE_FUNC inline
-typename enable_if<unpacket_traits<Packet>::masked_load_available, Packet>::type
+std::enable_if_t<unpacket_traits<Packet>::masked_load_available, Packet>
ploadu(const typename unpacket_traits<Packet>::type* from, typename unpacket_traits<Packet>::mask_t umask);
/** \internal \returns a packet with constant coefficients \a a, e.g.: (a,a,a,a) */
@@ -709,7 +718,7 @@
*/
template<typename Scalar, typename Packet>
EIGEN_DEVICE_FUNC inline
-typename enable_if<unpacket_traits<Packet>::masked_store_available, void>::type
+std::enable_if_t<unpacket_traits<Packet>::masked_store_available, void>
pstoreu(Scalar* to, const Packet& from, typename unpacket_traits<Packet>::mask_t umask);
template<typename Scalar, typename Packet> EIGEN_DEVICE_FUNC inline Packet pgather(const Scalar* from, Index /*stride*/)
@@ -810,7 +819,7 @@
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet plog2(const Packet& a) {
typedef typename internal::unpacket_traits<Packet>::type Scalar;
- return pmul(pset1<Packet>(Scalar(EIGEN_LOG2E)), plog(a));
+ return pmul(pset1<Packet>(Scalar(EIGEN_LOG2E)), plog(a));
}
/** \internal \returns the square-root of \a a (coeff-wise) */
@@ -845,7 +854,7 @@
* For packet-size smaller or equal to 4, this boils down to a noop.
*/
template<typename Packet>
-EIGEN_DEVICE_FUNC inline typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type
+EIGEN_DEVICE_FUNC inline std::conditional_t<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>
predux_half_dowto4(const Packet& a)
{ return a; }
@@ -877,7 +886,7 @@
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(
const Packet& a) {
- typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmul<Scalar>)));
}
@@ -885,14 +894,14 @@
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(
const Packet &a) {
- typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<PropagateFast, Scalar>)));
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(
const Packet& a) {
- typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
}
@@ -900,14 +909,14 @@
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(
const Packet &a) {
- typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<PropagateFast, Scalar>)));
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(
const Packet& a) {
- typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
}
diff --git a/Eigen/src/Core/IO.h b/Eigen/src/Core/IO.h
index efc2c43..897d7b0 100644
--- a/Eigen/src/Core/IO.h
+++ b/Eigen/src/Core/IO.h
@@ -133,7 +133,6 @@
std::ostream & print_matrix(std::ostream & s, const Derived& _m, const IOFormat& fmt)
{
using internal::is_same;
- using internal::conditional;
if(_m.size() == 0)
{
@@ -143,22 +142,21 @@
typename Derived::Nested m = _m;
typedef typename Derived::Scalar Scalar;
- typedef typename
- conditional<
+ typedef std::conditional_t<
is_same<Scalar, char>::value ||
is_same<Scalar, unsigned char>::value ||
is_same<Scalar, numext::int8_t>::value ||
is_same<Scalar, numext::uint8_t>::value,
int,
- typename conditional<
+ std::conditional_t<
is_same<Scalar, std::complex<char> >::value ||
is_same<Scalar, std::complex<unsigned char> >::value ||
is_same<Scalar, std::complex<numext::int8_t> >::value ||
is_same<Scalar, std::complex<numext::uint8_t> >::value,
std::complex<int>,
const Scalar&
- >::type
- >::type PrintType;
+ >
+ > PrintType;
Index width = 0;
diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h
index e7ca88b..c0907bf 100644
--- a/Eigen/src/Core/IndexedView.h
+++ b/Eigen/src/Core/IndexedView.h
@@ -42,7 +42,7 @@
InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
IsBlockAlike = InnerIncr==1 && OuterIncr==1,
- IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,typename conditional<XprTypeIsRowMajor,ColIndices,RowIndices>::type>::value,
+ IsInnerPannel = HasSameStorageOrderAsXprType && is_same<AllRange<InnerSize>,std::conditional_t<XprTypeIsRowMajor,ColIndices,RowIndices>>::value,
InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic || InnerIncr==UndefinedIncr ? Dynamic : XprInnerStride * InnerIncr,
OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic || OuterIncr==UndefinedIncr ? Dynamic : XprOuterstride * OuterIncr,
@@ -116,7 +116,7 @@
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
- typedef typename internal::remove_all<XprType>::type NestedExpression;
+ typedef internal::remove_all_t<XprType> NestedExpression;
template<typename T0, typename T1>
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
@@ -130,11 +130,11 @@
Index cols() const { return internal::index_list_size(m_colIndices); }
/** \returns the nested expression */
- const typename internal::remove_all<XprType>::type&
+ const internal::remove_all_t<XprType>&
nestedExpression() const { return m_xpr; }
/** \returns the nested expression */
- typename internal::remove_reference<XprType>::type&
+ std::remove_reference_t<XprType>&
nestedExpression() { return m_xpr; }
/** \returns a const reference to the object storing/generating the row indices */
diff --git a/Eigen/src/Core/Inverse.h b/Eigen/src/Core/Inverse.h
index 268f8d4..9c70733 100644
--- a/Eigen/src/Core/Inverse.h
+++ b/Eigen/src/Core/Inverse.h
@@ -48,9 +48,9 @@
typedef typename XprType::StorageIndex StorageIndex;
typedef typename XprType::Scalar Scalar;
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
- typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned;
+ typedef internal::remove_all_t<XprTypeNested> XprTypeNestedCleaned;
typedef typename internal::ref_selector<Inverse>::type Nested;
- typedef typename internal::remove_all<XprType>::type NestedExpression;
+ typedef internal::remove_all_t<XprType> NestedExpression;
explicit EIGEN_DEVICE_FUNC Inverse(const XprType &xpr)
: m_xpr(xpr)
@@ -104,7 +104,7 @@
unary_evaluator(const InverseType& inv_xpr)
: m_result(inv_xpr.rows(), inv_xpr.cols())
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
internal::call_assignment_no_alias(m_result, inv_xpr);
}
diff --git a/Eigen/src/Core/MapBase.h b/Eigen/src/Core/MapBase.h
index 89192c3..bf8c163 100644
--- a/Eigen/src/Core/MapBase.h
+++ b/Eigen/src/Core/MapBase.h
@@ -53,11 +53,11 @@
typedef typename internal::traits<Derived>::Scalar Scalar;
typedef typename internal::packet_traits<Scalar>::type PacketScalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
- typedef typename internal::conditional<
- bool(internal::is_lvalue<Derived>::value),
- Scalar *,
- const Scalar *>::type
- PointerType;
+ typedef std::conditional_t<
+ bool(internal::is_lvalue<Derived>::value),
+ Scalar *,
+ const Scalar *>
+ PointerType;
using Base::derived;
// using Base::RowsAtCompileTime;
@@ -191,7 +191,7 @@
template<typename T>
EIGEN_DEVICE_FUNC
- void checkSanity(typename internal::enable_if<(internal::traits<T>::Alignment>0),void*>::type = 0) const
+ void checkSanity(std::enable_if_t<(internal::traits<T>::Alignment>0),void*> = 0) const
{
#if EIGEN_MAX_ALIGN_BYTES>0
// innerStride() is not set yet when this function is called, so we optimistically assume the lowest plausible value:
@@ -204,7 +204,7 @@
template<typename T>
EIGEN_DEVICE_FUNC
- void checkSanity(typename internal::enable_if<internal::traits<T>::Alignment==0,void*>::type = 0) const
+ void checkSanity(std::enable_if_t<internal::traits<T>::Alignment==0,void*> = 0) const
{}
PointerType m_data;
@@ -247,11 +247,11 @@
using Base::rowStride;
using Base::colStride;
- typedef typename internal::conditional<
+ typedef std::conditional_t<
internal::is_lvalue<Derived>::value,
Scalar,
const Scalar
- >::type ScalarWithConstIfNotLvalue;
+ > ScalarWithConstIfNotLvalue;
EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return this->m_data; }
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 55e3159..59cc831 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -434,9 +434,9 @@
// generating warnings on clang. Here we explicitly cast the real component.
template<typename OldType, typename NewType>
struct cast_impl<OldType, NewType,
- typename internal::enable_if<
+ typename std::enable_if_t<
!NumTraits<OldType>::IsComplex && NumTraits<NewType>::IsComplex
- >::type>
+ >>
{
EIGEN_DEVICE_FUNC
static inline NewType run(const OldType& x)
@@ -891,7 +891,7 @@
// ScalarX is the widest of ScalarU and unsigned int.
// We'll deal only with ScalarX and unsigned int below thus avoiding signed
// types and arithmetic and signed overflows (which are undefined behavior).
- typedef typename conditional<(ScalarU(-1) > unsigned(-1)), ScalarU, unsigned>::type ScalarX;
+ typedef std::conditional_t<(ScalarU(-1) > unsigned(-1)), ScalarU, unsigned> ScalarX;
// The following difference doesn't overflow, provided our integer types are two's
// complement and have the same number of padding bits in signed and unsigned variants.
// This is the case in most modern implementations of C++.
@@ -962,22 +962,22 @@
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<internal::is_integral<T>::value,bool>::type
+std::enable_if_t<internal::is_integral<T>::value,bool>
isnan_impl(const T&) { return false; }
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<internal::is_integral<T>::value,bool>::type
+std::enable_if_t<internal::is_integral<T>::value,bool>
isinf_impl(const T&) { return false; }
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<internal::is_integral<T>::value,bool>::type
+std::enable_if_t<internal::is_integral<T>::value,bool>
isfinite_impl(const T&) { return true; }
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
+std::enable_if_t<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>
isfinite_impl(const T& x)
{
#if defined(EIGEN_GPU_COMPILE_PHASE)
@@ -992,7 +992,7 @@
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
+std::enable_if_t<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>
isinf_impl(const T& x)
{
#if defined(EIGEN_GPU_COMPILE_PHASE)
@@ -1007,7 +1007,7 @@
template<typename T>
EIGEN_DEVICE_FUNC
-typename internal::enable_if<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>::type
+std::enable_if_t<(!internal::is_integral<T>::value)&&(!NumTraits<T>::IsComplex),bool>
isnan_impl(const T& x)
{
#if defined(EIGEN_GPU_COMPILE_PHASE)
@@ -1232,7 +1232,7 @@
template<typename Scalar>
EIGEN_DEVICE_FUNC
-inline typename internal::add_const_on_value_type< EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) >::type real_ref(const Scalar& x)
+inline internal::add_const_on_value_type_t< EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) > real_ref(const Scalar& x)
{
return internal::real_ref_impl<Scalar>::run(x);
}
@@ -1260,7 +1260,7 @@
template<typename Scalar>
EIGEN_DEVICE_FUNC
-inline typename internal::add_const_on_value_type< EIGEN_MATHFUNC_RETVAL(imag_ref, Scalar) >::type imag_ref(const Scalar& x)
+inline internal::add_const_on_value_type_t< EIGEN_MATHFUNC_RETVAL(imag_ref, Scalar) > imag_ref(const Scalar& x)
{
return internal::imag_ref_impl<Scalar>::run(x);
}
@@ -1503,7 +1503,7 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-typename internal::enable_if<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex,typename NumTraits<T>::Real>::type
+std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex,typename NumTraits<T>::Real>
abs(const T &x) {
EIGEN_USING_STD(abs);
return abs(x);
@@ -1511,7 +1511,7 @@
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-typename internal::enable_if<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex),typename NumTraits<T>::Real>::type
+std::enable_if_t<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex),typename NumTraits<T>::Real>
abs(const T &x) {
return x;
}
diff --git a/Eigen/src/Core/Matrix.h b/Eigen/src/Core/Matrix.h
index 5145574..23acd8a 100644
--- a/Eigen/src/Core/Matrix.h
+++ b/Eigen/src/Core/Matrix.h
@@ -20,7 +20,7 @@
struct traits<Matrix<Scalar_, Rows_, Cols_, Options_, MaxRows_, MaxCols_> >
{
private:
- enum { size = internal::size_at_compile_time<Rows_,Cols_>::ret };
+ constexpr static int size = internal::size_at_compile_time(Rows_,Cols_);
typedef typename find_best_packet<Scalar_,size>::type PacketScalar;
enum {
row_major_bit = Options_&RowMajor ? RowMajorBit : 0,
@@ -42,7 +42,7 @@
ColsAtCompileTime = Cols_,
MaxRowsAtCompileTime = MaxRows_,
MaxColsAtCompileTime = MaxCols_,
- Flags = compute_matrix_flags<Scalar_, Rows_, Cols_, Options_, MaxRows_, MaxCols_>::ret,
+ Flags = compute_matrix_flags(Options_),
Options = Options_,
InnerStrideAtCompileTime = 1,
OuterStrideAtCompileTime = (Options&RowMajor) ? ColsAtCompileTime : RowsAtCompileTime,
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h
index 20c9138..a5463b3 100644
--- a/Eigen/src/Core/MatrixBase.h
+++ b/Eigen/src/Core/MatrixBase.h
@@ -109,10 +109,10 @@
/** \internal Represents a matrix with all coefficients equal to one another*/
typedef CwiseNullaryOp<internal::scalar_constant_op<Scalar>,PlainObject> ConstantReturnType;
/** \internal the return type of MatrixBase::adjoint() */
- typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
- CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
- ConstTransposeReturnType
- >::type AdjointReturnType;
+ typedef std::conditional_t<NumTraits<Scalar>::IsComplex,
+ CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
+ ConstTransposeReturnType
+ > AdjointReturnType;
/** \internal Return type of eigenvalues() */
typedef Matrix<std::complex<RealScalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor> EigenvaluesReturnType;
/** \internal the return type of identity */
@@ -208,28 +208,22 @@
EIGEN_DEVICE_FUNC
DiagonalReturnType diagonal();
- typedef typename internal::add_const<Diagonal<const Derived> >::type ConstDiagonalReturnType;
+ typedef Diagonal<const Derived> ConstDiagonalReturnType;
EIGEN_DEVICE_FUNC
- ConstDiagonalReturnType diagonal() const;
-
- template<int Index> struct DiagonalIndexReturnType { typedef Diagonal<Derived,Index> Type; };
- template<int Index> struct ConstDiagonalIndexReturnType { typedef const Diagonal<const Derived,Index> Type; };
+ const ConstDiagonalReturnType diagonal() const;
template<int Index>
EIGEN_DEVICE_FUNC
- typename DiagonalIndexReturnType<Index>::Type diagonal();
+ Diagonal<Derived, Index> diagonal();
template<int Index>
EIGEN_DEVICE_FUNC
- typename ConstDiagonalIndexReturnType<Index>::Type diagonal() const;
-
- typedef Diagonal<Derived,DynamicIndex> DiagonalDynamicIndexReturnType;
- typedef typename internal::add_const<Diagonal<const Derived,DynamicIndex> >::type ConstDiagonalDynamicIndexReturnType;
+ const Diagonal<const Derived, Index> diagonal() const;
EIGEN_DEVICE_FUNC
- DiagonalDynamicIndexReturnType diagonal(Index index);
+ Diagonal<Derived, DynamicIndex> diagonal(Index index);
EIGEN_DEVICE_FUNC
- ConstDiagonalDynamicIndexReturnType diagonal(Index index) const;
+ const Diagonal<const Derived, DynamicIndex> diagonal(Index index) const;
template<unsigned int Mode> struct TriangularViewReturnType { typedef TriangularView<Derived, Mode> Type; };
template<unsigned int Mode> struct ConstTriangularViewReturnType { typedef const TriangularView<const Derived, Mode> Type; };
diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h
index e484bb6..74edd2c 100644
--- a/Eigen/src/Core/NumTraits.h
+++ b/Eigen/src/Core/NumTraits.h
@@ -164,11 +164,7 @@
};
typedef T Real;
- typedef typename internal::conditional<
- IsInteger,
- typename internal::conditional<sizeof(T)<=2, float, double>::type,
- T
- >::type NonInteger;
+ typedef std::conditional_t<IsInteger, std::conditional_t<sizeof(T)<=2, float, double>, T> NonInteger;
typedef T Nested;
typedef T Literal;
diff --git a/Eigen/src/Core/PartialReduxEvaluator.h b/Eigen/src/Core/PartialReduxEvaluator.h
index b6f31f9..693fc35 100644
--- a/Eigen/src/Core/PartialReduxEvaluator.h
+++ b/Eigen/src/Core/PartialReduxEvaluator.h
@@ -141,8 +141,8 @@
{
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
- typedef typename internal::add_const_on_value_type<ArgTypeNested>::type ConstArgTypeNested;
- typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
+ typedef add_const_on_value_type_t<ArgTypeNested> ConstArgTypeNested;
+ typedef internal::remove_all_t<ArgTypeNested> ArgTypeNestedCleaned;
typedef typename ArgType::Scalar InputScalar;
typedef typename XprType::Scalar Scalar;
enum {
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h
index 1b4195a..73a7300 100644
--- a/Eigen/src/Core/PermutationMatrix.h
+++ b/Eigen/src/Core/PermutationMatrix.h
@@ -500,7 +500,7 @@
{}
/** const version of indices(). */
- const typename internal::remove_all<typename IndicesType::Nested>::type&
+ const internal::remove_all_t<typename IndicesType::Nested>&
indices() const { return m_indices; }
protected:
diff --git a/Eigen/src/Core/PlainObjectBase.h b/Eigen/src/Core/PlainObjectBase.h
index fb7cd0b..e0bde54 100644
--- a/Eigen/src/Core/PlainObjectBase.h
+++ b/Eigen/src/Core/PlainObjectBase.h
@@ -798,7 +798,7 @@
template<typename T0, typename T1>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE void _init2(Index rows, Index cols, typename internal::enable_if<Base::SizeAtCompileTime!=2,T0>::type* = 0)
+ EIGEN_STRONG_INLINE void _init2(Index rows, Index cols, std::enable_if_t<Base::SizeAtCompileTime!=2,T0>* = 0)
{
const bool t0_is_integer_alike = internal::is_valid_index_type<T0>::value;
const bool t1_is_integer_alike = internal::is_valid_index_type<T1>::value;
@@ -810,7 +810,7 @@
template<typename T0, typename T1>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE void _init2(const T0& val0, const T1& val1, typename internal::enable_if<Base::SizeAtCompileTime==2,T0>::type* = 0)
+ EIGEN_STRONG_INLINE void _init2(const T0& val0, const T1& val1, std::enable_if_t<Base::SizeAtCompileTime==2,T0>* = 0)
{
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 2)
m_storage.data()[0] = Scalar(val0);
@@ -820,10 +820,10 @@
template<typename T0, typename T1>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init2(const Index& val0, const Index& val1,
- typename internal::enable_if< (!internal::is_same<Index,Scalar>::value)
- && (internal::is_same<T0,Index>::value)
- && (internal::is_same<T1,Index>::value)
- && Base::SizeAtCompileTime==2,T1>::type* = 0)
+ std::enable_if_t< (!internal::is_same<Index,Scalar>::value)
+ && (internal::is_same<T0,Index>::value)
+ && (internal::is_same<T1,Index>::value)
+ && Base::SizeAtCompileTime==2,T1>* = 0)
{
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 2)
m_storage.data()[0] = Scalar(val0);
@@ -834,8 +834,8 @@
// then the argument is meant to be the size of the object.
template<typename T>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE void _init1(Index size, typename internal::enable_if< (Base::SizeAtCompileTime!=1 || !internal::is_convertible<T, Scalar>::value)
- && ((!internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value || Base::SizeAtCompileTime==Dynamic)),T>::type* = 0)
+ EIGEN_STRONG_INLINE void _init1(Index size, std::enable_if_t< (Base::SizeAtCompileTime!=1 || !internal::is_convertible<T, Scalar>::value)
+ && ((!internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value || Base::SizeAtCompileTime==Dynamic)),T>* = 0)
{
// NOTE MSVC 2008 complains if we directly put bool(NumTraits<T>::IsInteger) as the EIGEN_STATIC_ASSERT argument.
const bool is_integer_alike = internal::is_valid_index_type<T>::value;
@@ -848,7 +848,7 @@
// We have a 1x1 matrix/array => the argument is interpreted as the value of the unique coefficient (case where scalar type can be implicitly converted)
template<typename T>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE void _init1(const Scalar& val0, typename internal::enable_if<Base::SizeAtCompileTime==1 && internal::is_convertible<T, Scalar>::value,T>::type* = 0)
+ EIGEN_STRONG_INLINE void _init1(const Scalar& val0, std::enable_if_t<Base::SizeAtCompileTime==1 && internal::is_convertible<T, Scalar>::value,T>* = 0)
{
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 1)
m_storage.data()[0] = val0;
@@ -858,10 +858,10 @@
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init1(const Index& val0,
- typename internal::enable_if< (!internal::is_same<Index,Scalar>::value)
- && (internal::is_same<Index,T>::value)
- && Base::SizeAtCompileTime==1
- && internal::is_convertible<T, Scalar>::value,T*>::type* = 0)
+ std::enable_if_t< (!internal::is_same<Index,Scalar>::value)
+ && (internal::is_same<Index,T>::value)
+ && Base::SizeAtCompileTime==1
+ && internal::is_convertible<T, Scalar>::value,T*>* = 0)
{
EIGEN_STATIC_ASSERT_VECTOR_SPECIFIC_SIZE(PlainObjectBase, 1)
m_storage.data()[0] = Scalar(val0);
@@ -914,10 +914,10 @@
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init1(const Scalar& val0,
- typename internal::enable_if< Base::SizeAtCompileTime!=Dynamic
- && Base::SizeAtCompileTime!=1
- && internal::is_convertible<T, Scalar>::value
- && internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value,T>::type* = 0)
+ std::enable_if_t< Base::SizeAtCompileTime!=Dynamic
+ && Base::SizeAtCompileTime!=1
+ && internal::is_convertible<T, Scalar>::value
+ && internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value,T>* = 0)
{
Base::setConstant(val0);
}
@@ -926,12 +926,12 @@
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void _init1(const Index& val0,
- typename internal::enable_if< (!internal::is_same<Index,Scalar>::value)
- && (internal::is_same<Index,T>::value)
- && Base::SizeAtCompileTime!=Dynamic
- && Base::SizeAtCompileTime!=1
- && internal::is_convertible<T, Scalar>::value
- && internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value,T*>::type* = 0)
+ std::enable_if_t< (!internal::is_same<Index,Scalar>::value)
+ && (internal::is_same<Index,T>::value)
+ && Base::SizeAtCompileTime!=Dynamic
+ && Base::SizeAtCompileTime!=1
+ && internal::is_convertible<T, Scalar>::value
+ && internal::is_same<typename internal::traits<Derived>::XprKind,ArrayXpr>::value,T*>* = 0)
{
Base::setConstant(val0);
}
@@ -982,7 +982,7 @@
template <typename Derived, typename OtherDerived, bool IsVector>
struct conservative_resize_like_impl
{
- static const bool IsRelocatable = std::is_trivially_copyable<typename Derived::Scalar>::value;
+ static constexpr bool IsRelocatable = std::is_trivially_copyable<typename Derived::Scalar>::value;
static void run(DenseBase<Derived>& _this, Index rows, Index cols)
{
if (_this.rows() == rows && _this.cols() == cols) return;
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 3b788b3..85842d1 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -21,8 +21,8 @@
template<typename Lhs, typename Rhs, int Option>
struct traits<Product<Lhs, Rhs, Option> >
{
- typedef typename remove_all<Lhs>::type LhsCleaned;
- typedef typename remove_all<Rhs>::type RhsCleaned;
+ typedef remove_all_t<Lhs> LhsCleaned;
+ typedef remove_all_t<Rhs> RhsCleaned;
typedef traits<LhsCleaned> LhsTraits;
typedef traits<RhsCleaned> RhsTraits;
@@ -89,8 +89,8 @@
typedef typename internal::ref_selector<Lhs>::type LhsNested;
typedef typename internal::ref_selector<Rhs>::type RhsNested;
- typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
- typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
+ typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
+ typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs)
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index a874ee3..05ffa25 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -109,14 +109,14 @@
explicit product_evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
// FIXME shall we handle nested_eval here?,
// if so, then we must take care at removing the call to nested_eval in the specializations (e.g., in permutation_matrix_product, transposition_matrix_product, etc.)
// typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
// typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
-// typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
-// typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
+// typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
+// typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
//
// const LhsNested lhs(xpr.lhs());
// const RhsNested rhs(xpr.rhs());
@@ -136,7 +136,7 @@
// Dense = Product
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<Scalar,Scalar>, Dense2Dense,
- typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
+ std::enable_if_t<(Options==DefaultProduct || Options==AliasFreeProduct)>>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -154,7 +154,7 @@
// Dense += Product
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<Scalar,Scalar>, Dense2Dense,
- typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
+ std::enable_if_t<(Options==DefaultProduct || Options==AliasFreeProduct)>>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -169,7 +169,7 @@
// Dense -= Product
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::sub_assign_op<Scalar,Scalar>, Dense2Dense,
- typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
+ std::enable_if_t<(Options==DefaultProduct || Options==AliasFreeProduct)>>
{
typedef Product<Lhs,Rhs,Options> SrcXprType;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -298,7 +298,7 @@
template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct>
{
- template<typename T> struct is_row_major : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
+ template<typename T> struct is_row_major : std::conditional_t<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type> {};
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
// TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
@@ -372,7 +372,7 @@
typedef typename nested_eval<Rhs,1>::type RhsNested;
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
- typedef typename internal::remove_all<typename internal::conditional<int(Side)==OnTheRight,LhsNested,RhsNested>::type>::type MatrixType;
+ typedef internal::remove_all_t<std::conditional_t<int(Side)==OnTheRight,LhsNested,RhsNested>> MatrixType;
template<typename Dest>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
@@ -450,7 +450,7 @@
blas_traits<Rhs>::extract(rhs).template conjugateIf<ConjRhs>(),
func,
actualAlpha,
- typename conditional<HasScalarFactor,true_type,false_type>::type());
+ std::conditional_t<HasScalarFactor,true_type,false_type>());
}
protected:
@@ -528,8 +528,8 @@
typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
- typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
- typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
+ typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
+ typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
typedef evaluator<LhsNestedCleaned> LhsEtorType;
typedef evaluator<RhsNestedCleaned> RhsEtorType;
@@ -642,8 +642,8 @@
}
protected:
- typename internal::add_const_on_value_type<LhsNested>::type m_lhs;
- typename internal::add_const_on_value_type<RhsNested>::type m_rhs;
+ add_const_on_value_type_t<LhsNested> m_lhs;
+ add_const_on_value_type_t<RhsNested> m_rhs;
LhsEtorType m_lhsImpl;
RhsEtorType m_rhsImpl;
@@ -934,7 +934,7 @@
// FIXME: NVCC used to complain about the template keyword, but we have to check whether this is still the case.
// See also similar calls below.
return this->template packet_impl<LoadMode,PacketType>(row,col, row,
- typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
+ std::conditional_t<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>());
}
template<int LoadMode,typename PacketType>
@@ -976,7 +976,7 @@
EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const
{
return this->template packet_impl<LoadMode,PacketType>(row,col, col,
- typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
+ std::conditional_t<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>());
}
template<int LoadMode,typename PacketType>
@@ -1003,7 +1003,7 @@
struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape>
{
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
- typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
+ typedef remove_all_t<MatrixType> MatrixTypeCleaned;
template<typename Dest, typename PermutationType>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr)
@@ -1111,7 +1111,7 @@
struct transposition_matrix_product
{
typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
- typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
+ typedef remove_all_t<MatrixType> MatrixTypeCleaned;
template<typename Dest, typename TranspositionType>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr)
diff --git a/Eigen/src/Core/Random.h b/Eigen/src/Core/Random.h
index 0b304e7..fab6889 100644
--- a/Eigen/src/Core/Random.h
+++ b/Eigen/src/Core/Random.h
@@ -17,7 +17,6 @@
namespace internal {
template<typename Scalar> struct scalar_random_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_random_op)
inline const Scalar operator() () const { return random<Scalar>(); }
};
diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h
index c31dfcc..81de5f9 100644
--- a/Eigen/src/Core/Ref.h
+++ b/Eigen/src/Core/Ref.h
@@ -48,7 +48,7 @@
ScalarTypeMatch = internal::is_same<typename PlainObjectType::Scalar, typename Derived::Scalar>::value,
MatchAtCompileTime = HasDirectAccess && StorageOrderMatch && InnerStrideMatch && OuterStrideMatch && AlignmentMatch && ScalarTypeMatch
};
- typedef typename internal::conditional<MatchAtCompileTime,internal::true_type,internal::false_type>::type type;
+ typedef std::conditional_t<MatchAtCompileTime,internal::true_type,internal::false_type> type;
};
};
@@ -199,8 +199,8 @@
return false;
}
- ::new (static_cast<Base*>(this)) Base(expr.data(), rows, cols);
- ::new (&m_stride) StrideBase(
+ internal::construct_at<Base>(this, expr.data(), rows, cols);
+ internal::construct_at(&m_stride,
(StrideType::OuterStrideAtCompileTime == 0) ? 0 : outer_stride,
(StrideType::InnerStrideAtCompileTime == 0) ? 0 : inner_stride );
return true;
@@ -287,7 +287,7 @@
typedef internal::traits<Ref> Traits;
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(const PlainObjectBase<Derived>& expr,
- typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0);
+ std::enable_if_t<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>* = 0);
public:
typedef RefBase<Ref> Base;
@@ -297,7 +297,7 @@
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(PlainObjectBase<Derived>& expr,
- typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0)
+ std::enable_if_t<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>* = 0)
{
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
// Construction must pass since we will not create temporary storage in the non-const case.
@@ -307,7 +307,7 @@
}
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr,
- typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0)
+ std::enable_if_t<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>* = 0)
#else
/** Implicit constructor from any dense expression */
template<typename Derived>
@@ -339,7 +339,7 @@
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr,
- typename internal::enable_if<bool(Traits::template match<Derived>::ScalarTypeMatch),Derived>::type* = 0)
+ std::enable_if_t<bool(Traits::template match<Derived>::ScalarTypeMatch),Derived>* = 0)
{
// std::cout << match_helper<Derived>::HasDirectAccess << "," << match_helper<Derived>::OuterStrideMatch << "," << match_helper<Derived>::InnerStrideMatch << "\n";
// std::cout << int(StrideType::OuterStrideAtCompileTime) << " - " << int(Derived::OuterStrideAtCompileTime) << "\n";
diff --git a/Eigen/src/Core/Replicate.h b/Eigen/src/Core/Replicate.h
index 6b5f9fe..4f91bbe 100644
--- a/Eigen/src/Core/Replicate.h
+++ b/Eigen/src/Core/Replicate.h
@@ -23,7 +23,7 @@
typedef typename traits<MatrixType>::StorageKind StorageKind;
typedef typename traits<MatrixType>::XprKind XprKind;
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
- typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNested_;
+ typedef std::remove_reference_t<MatrixTypeNested> MatrixTypeNested_;
enum {
RowsAtCompileTime = RowFactor==Dynamic || int(MatrixType::RowsAtCompileTime)==Dynamic
? Dynamic
@@ -69,14 +69,14 @@
typedef typename internal::dense_xpr_base<Replicate>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Replicate)
- typedef typename internal::remove_all<MatrixType>::type NestedExpression;
+ typedef internal::remove_all_t<MatrixType> NestedExpression;
template<typename OriginalMatrixType>
EIGEN_DEVICE_FUNC
inline explicit Replicate(const OriginalMatrixType& matrix)
: m_matrix(matrix), m_rowFactor(RowFactor), m_colFactor(ColFactor)
{
- EIGEN_STATIC_ASSERT((internal::is_same<typename internal::remove_const<MatrixType>::type,OriginalMatrixType>::value),
+ EIGEN_STATIC_ASSERT((internal::is_same<std::remove_const_t<MatrixType>,OriginalMatrixType>::value),
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE)
eigen_assert(RowFactor!=Dynamic && ColFactor!=Dynamic);
}
@@ -86,7 +86,7 @@
inline Replicate(const OriginalMatrixType& matrix, Index rowFactor, Index colFactor)
: m_matrix(matrix), m_rowFactor(rowFactor), m_colFactor(colFactor)
{
- EIGEN_STATIC_ASSERT((internal::is_same<typename internal::remove_const<MatrixType>::type,OriginalMatrixType>::value),
+ EIGEN_STATIC_ASSERT((internal::is_same<std::remove_const_t<MatrixType>,OriginalMatrixType>::value),
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE)
}
diff --git a/Eigen/src/Core/Reshaped.h b/Eigen/src/Core/Reshaped.h
index baa550e..53889fe 100644
--- a/Eigen/src/Core/Reshaped.h
+++ b/Eigen/src/Core/Reshaped.h
@@ -157,7 +157,7 @@
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
- typedef typename internal::remove_all<XprType>::type NestedExpression;
+ typedef internal::remove_all_t<XprType> NestedExpression;
class InnerIterator;
@@ -187,12 +187,12 @@
/** \returns the nested expression */
EIGEN_DEVICE_FUNC
- const typename internal::remove_all<XprType>::type&
+ const internal::remove_all_t<XprType>&
nestedExpression() const { return m_xpr; }
/** \returns the nested expression */
EIGEN_DEVICE_FUNC
- typename internal::remove_reference<XprType>::type&
+ std::remove_reference_t<XprType>&
nestedExpression() { return m_xpr; }
protected:
@@ -232,7 +232,7 @@
{}
EIGEN_DEVICE_FUNC
- const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
+ const internal::remove_all_t<XprTypeNested>& nestedExpression() const
{
return m_xpr;
}
diff --git a/Eigen/src/Core/ReturnByValue.h b/Eigen/src/Core/ReturnByValue.h
index d2dd349..9025282 100644
--- a/Eigen/src/Core/ReturnByValue.h
+++ b/Eigen/src/Core/ReturnByValue.h
@@ -106,7 +106,7 @@
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr)
: m_result(xpr.rows(), xpr.cols())
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
xpr.evalTo(m_result);
}
diff --git a/Eigen/src/Core/Reverse.h b/Eigen/src/Core/Reverse.h
index 6c66922..97e1d68 100644
--- a/Eigen/src/Core/Reverse.h
+++ b/Eigen/src/Core/Reverse.h
@@ -26,7 +26,7 @@
typedef typename traits<MatrixType>::StorageKind StorageKind;
typedef typename traits<MatrixType>::XprKind XprKind;
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
- typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNested_;
+ typedef std::remove_reference_t<MatrixTypeNested> MatrixTypeNested_;
enum {
RowsAtCompileTime = MatrixType::RowsAtCompileTime,
ColsAtCompileTime = MatrixType::ColsAtCompileTime,
@@ -69,7 +69,7 @@
typedef typename internal::dense_xpr_base<Reverse>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Reverse)
- typedef typename internal::remove_all<MatrixType>::type NestedExpression;
+ typedef internal::remove_all_t<MatrixType> NestedExpression;
using Base::IsRowMajor;
protected:
@@ -101,7 +101,7 @@
return -m_matrix.innerStride();
}
- EIGEN_DEVICE_FUNC const typename internal::remove_all<typename MatrixType::Nested>::type&
+ EIGEN_DEVICE_FUNC const internal::remove_all_t<typename MatrixType::Nested>&
nestedExpression() const
{
return m_matrix;
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h
index 7096058..7a930db 100644
--- a/Eigen/src/Core/SelfAdjointView.h
+++ b/Eigen/src/Core/SelfAdjointView.h
@@ -35,7 +35,7 @@
struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType>
{
typedef typename ref_selector<MatrixType>::non_const_type MatrixTypeNested;
- typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
+ typedef remove_all_t<MatrixTypeNested> MatrixTypeNestedCleaned;
typedef MatrixType ExpressionType;
typedef typename MatrixType::PlainObject FullMatrixType;
enum {
@@ -63,8 +63,8 @@
/** \brief The type of coefficients in this matrix */
typedef typename internal::traits<SelfAdjointView>::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex;
- typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType;
- typedef SelfAdjointView<typename internal::add_const<MatrixType>::type, UpLo> ConstSelfAdjointView;
+ typedef internal::remove_all_t<typename MatrixType::ConjugateReturnType> MatrixConjugateReturnType;
+ typedef SelfAdjointView<std::add_const_t<MatrixType>, UpLo> ConstSelfAdjointView;
enum {
Mode = internal::traits<SelfAdjointView>::Mode,
@@ -180,16 +180,16 @@
*/
template<unsigned int TriMode>
EIGEN_DEVICE_FUNC
- typename internal::conditional<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)),
- TriangularView<MatrixType,TriMode>,
- TriangularView<typename MatrixType::AdjointReturnType,TriMode> >::type
+ std::conditional_t<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)),
+ TriangularView<MatrixType,TriMode>,
+ TriangularView<typename MatrixType::AdjointReturnType,TriMode> >
triangularView() const
{
- typename internal::conditional<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)), MatrixType&, typename MatrixType::ConstTransposeReturnType>::type tmp1(m_matrix);
- typename internal::conditional<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)), MatrixType&, typename MatrixType::AdjointReturnType>::type tmp2(tmp1);
- return typename internal::conditional<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)),
- TriangularView<MatrixType,TriMode>,
- TriangularView<typename MatrixType::AdjointReturnType,TriMode> >::type(tmp2);
+ std::conditional_t<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)), MatrixType&, typename MatrixType::ConstTransposeReturnType> tmp1(m_matrix);
+ std::conditional_t<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)), MatrixType&, typename MatrixType::AdjointReturnType> tmp2(tmp1);
+ return std::conditional_t<(TriMode&(Upper|Lower))==(UpLo&(Upper|Lower)),
+ TriangularView<MatrixType,TriMode>,
+ TriangularView<typename MatrixType::AdjointReturnType,TriMode> >(tmp2);
}
typedef SelfAdjointView<const MatrixConjugateReturnType,UpLo> ConjugateReturnType;
@@ -203,10 +203,10 @@
*/
template<bool Cond>
EIGEN_DEVICE_FUNC
- inline typename internal::conditional<Cond,ConjugateReturnType,ConstSelfAdjointView>::type
+ inline std::conditional_t<Cond,ConjugateReturnType,ConstSelfAdjointView>
conjugateIf() const
{
- typedef typename internal::conditional<Cond,ConjugateReturnType,ConstSelfAdjointView>::type ReturnType;
+ typedef std::conditional_t<Cond,ConjugateReturnType,ConstSelfAdjointView> ReturnType;
return ReturnType(m_matrix.template conjugateIf<Cond>());
}
@@ -220,7 +220,7 @@
/** \sa MatrixBase::transpose() */
template<class Dummy=int>
EIGEN_DEVICE_FUNC
- inline TransposeReturnType transpose(typename internal::enable_if<Eigen::internal::is_lvalue<MatrixType>::value, Dummy*>::type = nullptr)
+ inline TransposeReturnType transpose(std::enable_if_t<Eigen::internal::is_lvalue<MatrixType>::value, Dummy*> = nullptr)
{
typename MatrixType::TransposeReturnType tmp(m_matrix);
return TransposeReturnType(tmp);
diff --git a/Eigen/src/Core/Solve.h b/Eigen/src/Core/Solve.h
index 3d3a3c9..f77eac9 100644
--- a/Eigen/src/Core/Solve.h
+++ b/Eigen/src/Core/Solve.h
@@ -125,7 +125,7 @@
EIGEN_DEVICE_FUNC explicit evaluator(const SolveType& solve)
: m_result(solve.rows(), solve.cols())
{
- ::new (static_cast<Base*>(this)) Base(m_result);
+ internal::construct_at<Base>(this, m_result);
solve.dec()._solve_impl(solve.rhs(), m_result);
}
diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h
index 518a6c6..71d6f85 100644
--- a/Eigen/src/Core/SolveTriangular.h
+++ b/Eigen/src/Core/SolveTriangular.h
@@ -89,7 +89,7 @@
static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs)
{
- typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
+ add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
const Index size = lhs.rows();
const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows();
@@ -176,11 +176,11 @@
return;
enum { copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime!=1};
- typedef typename internal::conditional<copy,
- typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
+ typedef std::conditional_t<copy,
+ typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&> OtherCopy;
OtherCopy otherCopy(other);
- internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type,
+ internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>,
Side, Mode>::run(derived().nestedExpression(), otherCopy);
if (copy)
@@ -208,7 +208,7 @@
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval
: public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> >
{
- typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned;
+ typedef remove_all_t<typename Rhs::Nested> RhsNestedCleaned;
typedef ReturnByValue<triangular_solve_retval> Base;
triangular_solve_retval(const TriangularType& tri, const Rhs& rhs)
diff --git a/Eigen/src/Core/SolverBase.h b/Eigen/src/Core/SolverBase.h
index 2f238ac..7396e04 100644
--- a/Eigen/src/Core/SolverBase.h
+++ b/Eigen/src/Core/SolverBase.h
@@ -30,7 +30,7 @@
template<bool Transpose_, typename Rhs>
static void run(const type& transpose, const Rhs& b)
{
- internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b);
+ internal::solve_assertion<internal::remove_all_t<Derived>>::template run<true>(transpose.nestedExpression(), b);
}
};
@@ -42,7 +42,7 @@
template<bool Transpose_, typename Rhs>
static void run(const type& adjoint, const Rhs& b)
{
- internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b);
+ internal::solve_assertion<internal::remove_all_t<Transpose<Derived> >>::template run<true>(adjoint.nestedExpression(), b);
}
};
} // end namespace internal
@@ -81,12 +81,11 @@
enum {
RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
- SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
- internal::traits<Derived>::ColsAtCompileTime>::ret),
+ SizeAtCompileTime = (internal::size_of_xpr_at_compile_time<Derived>::ret),
MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
- MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
- internal::traits<Derived>::MaxColsAtCompileTime>::ret),
+ MaxSizeAtCompileTime = internal::size_at_compile_time(internal::traits<Derived>::MaxRowsAtCompileTime,
+ internal::traits<Derived>::MaxColsAtCompileTime),
IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
|| internal::traits<Derived>::MaxColsAtCompileTime == 1,
NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2
@@ -107,12 +106,12 @@
inline const Solve<Derived, Rhs>
solve(const MatrixBase<Rhs>& b) const
{
- internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b);
+ internal::solve_assertion<internal::remove_all_t<Derived>>::template run<false>(derived(), b);
return Solve<Derived, Rhs>(derived(), b.derived());
}
/** \internal the return type of transpose() */
- typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
+ typedef Transpose<const Derived> ConstTransposeReturnType;
/** \returns an expression of the transposed of the factored matrix.
*
* A typical usage is to solve for the transposed problem A^T x = b:
@@ -120,16 +119,16 @@
*
* \sa adjoint(), solve()
*/
- inline ConstTransposeReturnType transpose() const
+ inline const ConstTransposeReturnType transpose() const
{
return ConstTransposeReturnType(derived());
}
/** \internal the return type of adjoint() */
- typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
- CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
- ConstTransposeReturnType
- >::type AdjointReturnType;
+ typedef std::conditional_t<NumTraits<Scalar>::IsComplex,
+ CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const ConstTransposeReturnType>,
+ const ConstTransposeReturnType
+ > AdjointReturnType;
/** \returns an expression of the adjoint of the factored matrix
*
* A typical usage is to solve for the adjoint problem A' x = b:
@@ -139,7 +138,7 @@
*
* \sa transpose(), solve()
*/
- inline AdjointReturnType adjoint() const
+ inline const AdjointReturnType adjoint() const
{
return AdjointReturnType(derived().transpose());
}
diff --git a/Eigen/src/Core/StableNorm.h b/Eigen/src/Core/StableNorm.h
index c006c25..a3bc918 100644
--- a/Eigen/src/Core/StableNorm.h
+++ b/Eigen/src/Core/StableNorm.h
@@ -59,7 +59,7 @@
const Index blockSize = 4096;
typedef typename internal::nested_eval<VectorType,2>::type VectorTypeCopy;
- typedef typename internal::remove_all<VectorTypeCopy>::type VectorTypeCopyClean;
+ typedef internal::remove_all_t<VectorTypeCopy> VectorTypeCopyClean;
const VectorTypeCopy copy(vec);
enum {
@@ -68,8 +68,8 @@
) && (blockSize*sizeof(Scalar)*2<EIGEN_STACK_ALLOCATION_LIMIT)
&& (EIGEN_MAX_STATIC_ALIGN_BYTES>0) // if we cannot allocate on the stack, then let's not bother about this optimization
};
- typedef typename internal::conditional<CanAlign, Ref<const Matrix<Scalar,Dynamic,1,0,blockSize,1>, internal::evaluator<VectorTypeCopyClean>::Alignment>,
- typename VectorTypeCopyClean::ConstSegmentReturnType>::type SegmentWrapper;
+ typedef std::conditional_t<CanAlign, Ref<const Matrix<Scalar,Dynamic,1,0,blockSize,1>, internal::evaluator<VectorTypeCopyClean>::Alignment>,
+ typename VectorTypeCopyClean::ConstSegmentReturnType> SegmentWrapper;
Index n = vec.size();
Index bi = internal::first_default_aligned(copy);
@@ -81,7 +81,7 @@
template<typename VectorType>
typename VectorType::RealScalar
-stable_norm_impl(const VectorType &vec, typename enable_if<VectorType::IsVectorAtCompileTime>::type* = 0 )
+stable_norm_impl(const VectorType &vec, std::enable_if_t<VectorType::IsVectorAtCompileTime>* = 0 )
{
using std::sqrt;
using std::abs;
@@ -103,7 +103,7 @@
template<typename MatrixType>
typename MatrixType::RealScalar
-stable_norm_impl(const MatrixType &mat, typename enable_if<!MatrixType::IsVectorAtCompileTime>::type* = 0 )
+stable_norm_impl(const MatrixType &mat, std::enable_if_t<!MatrixType::IsVectorAtCompileTime>* = 0 )
{
using std::sqrt;
diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h
index d9529c0..d5d3971 100644
--- a/Eigen/src/Core/StlIterators.h
+++ b/Eigen/src/Core/StlIterators.h
@@ -27,7 +27,7 @@
typedef typename traits::XprType XprType;
typedef indexed_based_stl_iterator_base<typename traits::non_const_iterator> non_const_iterator;
typedef indexed_based_stl_iterator_base<typename traits::const_iterator> const_iterator;
- typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ typedef std::conditional_t<internal::is_const<XprType>::value,non_const_iterator,const_iterator> other_iterator;
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
friend class indexed_based_stl_iterator_base<typename traits::const_iterator>;
friend class indexed_based_stl_iterator_base<typename traits::non_const_iterator>;
@@ -106,7 +106,7 @@
typedef typename traits::XprType XprType;
typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator;
typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator;
- typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ typedef std::conditional_t<internal::is_const<XprType>::value,non_const_iterator,const_iterator> other_iterator;
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>;
friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>;
@@ -181,18 +181,18 @@
class pointer_based_stl_iterator
{
enum { is_lvalue = internal::is_lvalue<XprType>::value };
- typedef pointer_based_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
- typedef pointer_based_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
- typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
+ typedef pointer_based_stl_iterator<std::remove_const_t<XprType>> non_const_iterator;
+ typedef pointer_based_stl_iterator<std::add_const_t<XprType>> const_iterator;
+ typedef std::conditional_t<internal::is_const<XprType>::value,non_const_iterator,const_iterator> other_iterator;
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
- friend class pointer_based_stl_iterator<typename internal::add_const<XprType>::type>;
- friend class pointer_based_stl_iterator<typename internal::remove_const<XprType>::type>;
+ friend class pointer_based_stl_iterator<std::add_const_t<XprType>>;
+ friend class pointer_based_stl_iterator<std::remove_const_t<XprType>>;
public:
typedef Index difference_type;
typedef typename XprType::Scalar value_type;
typedef std::random_access_iterator_tag iterator_category;
- typedef typename internal::conditional<bool(is_lvalue), value_type*, const value_type*>::type pointer;
- typedef typename internal::conditional<bool(is_lvalue), value_type&, const value_type&>::type reference;
+ typedef std::conditional_t<bool(is_lvalue), value_type*, const value_type*> pointer;
+ typedef std::conditional_t<bool(is_lvalue), value_type&, const value_type&> reference;
pointer_based_stl_iterator() EIGEN_NO_THROW : m_ptr(0) {}
@@ -262,8 +262,8 @@
struct indexed_based_stl_iterator_traits<generic_randaccess_stl_iterator<XprType_> >
{
typedef XprType_ XprType;
- typedef generic_randaccess_stl_iterator<typename internal::remove_const<XprType>::type> non_const_iterator;
- typedef generic_randaccess_stl_iterator<typename internal::add_const<XprType>::type> const_iterator;
+ typedef generic_randaccess_stl_iterator<std::remove_const_t<XprType>> non_const_iterator;
+ typedef generic_randaccess_stl_iterator<std::add_const_t<XprType>> const_iterator;
};
template<typename XprType>
@@ -285,13 +285,13 @@
// TODO currently const Transpose/Reshape expressions never return const references,
// so let's return by value too.
- //typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
+ //typedef std::conditional_t<bool(has_direct_access), const value_type&, const value_type> read_only_ref_t;
typedef const value_type read_only_ref_t;
public:
- typedef typename internal::conditional<bool(is_lvalue), value_type *, const value_type *>::type pointer;
- typedef typename internal::conditional<bool(is_lvalue), value_type&, read_only_ref_t>::type reference;
+ typedef std::conditional_t<bool(is_lvalue), value_type *, const value_type *> pointer;
+ typedef std::conditional_t<bool(is_lvalue), value_type&, read_only_ref_t> reference;
generic_randaccess_stl_iterator() : Base() {}
generic_randaccess_stl_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
@@ -307,8 +307,8 @@
struct indexed_based_stl_iterator_traits<subvector_stl_iterator<XprType_,Direction> >
{
typedef XprType_ XprType;
- typedef subvector_stl_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
- typedef subvector_stl_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
+ typedef subvector_stl_iterator<std::remove_const_t<XprType>, Direction> non_const_iterator;
+ typedef subvector_stl_iterator<std::add_const_t<XprType>, Direction> const_iterator;
};
template<typename XprType, DirectionType Direction>
@@ -322,12 +322,12 @@
using Base::m_index;
using Base::mp_xpr;
- typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
- typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
+ typedef std::conditional_t<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr> SubVectorType;
+ typedef std::conditional_t<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr> ConstSubVectorType;
public:
- typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
+ typedef std::conditional_t<bool(is_lvalue), SubVectorType, ConstSubVectorType> reference;
typedef typename reference::PlainObject value_type;
private:
@@ -355,8 +355,8 @@
struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<XprType_,Direction> >
{
typedef XprType_ XprType;
- typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
- typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
+ typedef subvector_stl_reverse_iterator<std::remove_const_t<XprType>, Direction> non_const_iterator;
+ typedef subvector_stl_reverse_iterator<std::add_const_t<XprType>, Direction> const_iterator;
};
template<typename XprType, DirectionType Direction>
@@ -370,12 +370,12 @@
using Base::m_index;
using Base::mp_xpr;
- typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
- typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
+ typedef std::conditional_t<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr> SubVectorType;
+ typedef std::conditional_t<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr> ConstSubVectorType;
public:
- typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
+ typedef std::conditional_t<bool(is_lvalue), SubVectorType, ConstSubVectorType> reference;
typedef typename reference::PlainObject value_type;
private:
diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h
index d302766..74650ef 100644
--- a/Eigen/src/Core/Transpose.h
+++ b/Eigen/src/Core/Transpose.h
@@ -20,7 +20,7 @@
struct traits<Transpose<MatrixType> > : public traits<MatrixType>
{
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
- typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedPlain;
+ typedef std::remove_reference_t<MatrixTypeNested> MatrixTypeNestedPlain;
enum {
RowsAtCompileTime = MatrixType::ColsAtCompileTime,
ColsAtCompileTime = MatrixType::RowsAtCompileTime,
@@ -60,7 +60,7 @@
typedef typename TransposeImpl<MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(Transpose)
- typedef typename internal::remove_all<MatrixType>::type NestedExpression;
+ typedef internal::remove_all_t<MatrixType> NestedExpression;
EIGEN_DEVICE_FUNC
explicit EIGEN_STRONG_INLINE Transpose(MatrixType& matrix) : m_matrix(matrix) {}
@@ -74,12 +74,12 @@
/** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const typename internal::remove_all<MatrixTypeNested>::type&
+ const internal::remove_all_t<MatrixTypeNested>&
nestedExpression() const { return m_matrix; }
/** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- typename internal::remove_reference<MatrixTypeNested>::type&
+ std::remove_reference_t<MatrixTypeNested>&
nestedExpression() { return m_matrix; }
/** \internal */
@@ -132,11 +132,11 @@
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Index outerStride() const { return derived().nestedExpression().outerStride(); }
- typedef typename internal::conditional<
- internal::is_lvalue<MatrixType>::value,
- Scalar,
- const Scalar
- >::type ScalarWithConstIfNotLvalue;
+ typedef std::conditional_t<
+ internal::is_lvalue<MatrixType>::value,
+ Scalar,
+ const Scalar
+ > ScalarWithConstIfNotLvalue;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
ScalarWithConstIfNotLvalue* data() { return derived().nestedExpression().data(); }
@@ -180,7 +180,7 @@
* \sa transposeInPlace(), adjoint() */
template<typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-Transpose<Derived>
+typename DenseBase<Derived>::TransposeReturnType
DenseBase<Derived>::transpose()
{
return TransposeReturnType(derived());
@@ -193,7 +193,7 @@
* \sa transposeInPlace(), adjoint() */
template<typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-typename DenseBase<Derived>::ConstTransposeReturnType
+const typename DenseBase<Derived>::ConstTransposeReturnType
DenseBase<Derived>::transpose() const
{
return ConstTransposeReturnType(derived());
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h
index 28f515f..c1bd13a 100644
--- a/Eigen/src/Core/TriangularMatrix.h
+++ b/Eigen/src/Core/TriangularMatrix.h
@@ -37,14 +37,13 @@
MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
- SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
- internal::traits<Derived>::ColsAtCompileTime>::ret),
+ SizeAtCompileTime = (internal::size_of_xpr_at_compile_time<Derived>::ret),
/**< This is equal to the number of coefficients, i.e. the number of
* rows times the number of columns, or to \a Dynamic if this is not
* known at compile-time. \sa RowsAtCompileTime, ColsAtCompileTime */
- MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
- internal::traits<Derived>::MaxColsAtCompileTime>::ret)
+ MaxSizeAtCompileTime = internal::size_at_compile_time(internal::traits<Derived>::MaxRowsAtCompileTime,
+ internal::traits<Derived>::MaxColsAtCompileTime)
};
typedef typename internal::traits<Derived>::Scalar Scalar;
@@ -172,8 +171,8 @@
struct traits<TriangularView<MatrixType, Mode_> > : traits<MatrixType>
{
typedef typename ref_selector<MatrixType>::non_const_type MatrixTypeNested;
- typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
- typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
+ typedef std::remove_reference_t<MatrixTypeNested> MatrixTypeNestedNonRef;
+ typedef remove_all_t<MatrixTypeNested> MatrixTypeNestedCleaned;
typedef typename MatrixType::PlainObject FullMatrixType;
typedef MatrixType ExpressionType;
enum {
@@ -199,8 +198,8 @@
typedef typename internal::traits<TriangularView>::MatrixTypeNested MatrixTypeNested;
typedef typename internal::traits<TriangularView>::MatrixTypeNestedNonRef MatrixTypeNestedNonRef;
- typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType;
- typedef TriangularView<typename internal::add_const<MatrixType>::type, Mode_> ConstTriangularView;
+ typedef internal::remove_all_t<typename MatrixType::ConjugateReturnType> MatrixConjugateReturnType;
+ typedef TriangularView<std::add_const_t<MatrixType>, Mode_> ConstTriangularView;
public:
@@ -249,10 +248,10 @@
*/
template<bool Cond>
EIGEN_DEVICE_FUNC
- inline typename internal::conditional<Cond,ConjugateReturnType,ConstTriangularView>::type
+ inline std::conditional_t<Cond,ConjugateReturnType,ConstTriangularView>
conjugateIf() const
{
- typedef typename internal::conditional<Cond,ConjugateReturnType,ConstTriangularView>::type ReturnType;
+ typedef std::conditional_t<Cond,ConjugateReturnType,ConstTriangularView> ReturnType;
return ReturnType(m_matrix.template conjugateIf<Cond>());
}
@@ -266,7 +265,7 @@
/** \sa MatrixBase::transpose() */
template<class Dummy=int>
EIGEN_DEVICE_FUNC
- inline TransposeReturnType transpose(typename internal::enable_if<Eigen::internal::is_lvalue<MatrixType>::value, Dummy*>::type = nullptr)
+ inline TransposeReturnType transpose(std::enable_if_t<Eigen::internal::is_lvalue<MatrixType>::value, Dummy*> = nullptr)
{
typename MatrixType::TransposeReturnType tmp(m_matrix);
return TransposeReturnType(tmp);
@@ -731,10 +730,10 @@
template<typename MatrixType, unsigned int Mode>
struct unary_evaluator<TriangularView<MatrixType,Mode>, IndexBased>
- : evaluator<typename internal::remove_all<MatrixType>::type>
+ : evaluator<internal::remove_all_t<MatrixType>>
{
typedef TriangularView<MatrixType,Mode> XprType;
- typedef evaluator<typename internal::remove_all<MatrixType>::type> Base;
+ typedef evaluator<internal::remove_all_t<MatrixType>> Base;
EIGEN_DEVICE_FUNC
unary_evaluator(const XprType &xpr) : Base(xpr.nestedExpression()) {}
};
diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h
index 5d4c11f..b004f76 100644
--- a/Eigen/src/Core/VectorwiseOp.h
+++ b/Eigen/src/Core/VectorwiseOp.h
@@ -88,7 +88,6 @@
#define EIGEN_MAKE_PARTIAL_REDUX_FUNCTOR(MEMBER,COST,VECTORIZABLE,BINARYOP) \
template <typename ResultType,typename Scalar> \
struct member_##MEMBER { \
- EIGEN_EMPTY_STRUCT_CTOR(member_##MEMBER) \
typedef ResultType result_type; \
typedef BINARYOP<Scalar,Scalar> BinaryOp; \
template<int Size> struct Cost { enum { value = COST }; }; \
@@ -193,7 +192,7 @@
typedef typename ExpressionType::RealScalar RealScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef typename internal::ref_selector<ExpressionType>::non_const_type ExpressionTypeNested;
- typedef typename internal::remove_all<ExpressionTypeNested>::type ExpressionTypeNestedCleaned;
+ typedef internal::remove_all_t<ExpressionTypeNested> ExpressionTypeNestedCleaned;
template<template<typename OutScalar,typename InputScalar> class Functor,
typename ReturnScalar=Scalar> struct ReturnType
diff --git a/Eigen/src/Core/Visitor.h b/Eigen/src/Core/Visitor.h
index cb55245..e1c17fc 100644
--- a/Eigen/src/Core/Visitor.h
+++ b/Eigen/src/Core/Visitor.h
@@ -23,8 +23,10 @@
struct visitor_impl<Visitor, Derived, UnrollCount, false>
{
enum {
- col = (UnrollCount-1) / Derived::RowsAtCompileTime,
- row = (UnrollCount-1) % Derived::RowsAtCompileTime
+ col = Derived::IsRowMajor ? (UnrollCount-1) % Derived::ColsAtCompileTime
+ : (UnrollCount-1) / Derived::RowsAtCompileTime,
+ row = Derived::IsRowMajor ? (UnrollCount-1) / Derived::ColsAtCompileTime
+ : (UnrollCount-1) % Derived::RowsAtCompileTime
};
EIGEN_DEVICE_FUNC
@@ -60,11 +62,25 @@
static inline void run(const Derived& mat, Visitor& visitor)
{
visitor.init(mat.coeff(0,0), 0, 0);
- for(Index i = 1; i < mat.rows(); ++i)
- visitor(mat.coeff(i, 0), i, 0);
- for(Index j = 1; j < mat.cols(); ++j)
- for(Index i = 0; i < mat.rows(); ++i)
- visitor(mat.coeff(i, j), i, j);
+ if (Derived::IsRowMajor) {
+ for(Index i = 1; i < mat.cols(); ++i) {
+ visitor(mat.coeff(0, i), 0, i);
+ }
+ for(Index j = 1; j < mat.rows(); ++j) {
+ for(Index i = 0; i < mat.cols(); ++i) {
+ visitor(mat.coeff(j, i), j, i);
+ }
+ }
+ } else {
+ for(Index i = 1; i < mat.rows(); ++i) {
+ visitor(mat.coeff(i, 0), i, 0);
+ }
+ for(Index j = 1; j < mat.cols(); ++j) {
+ for(Index i = 0; i < mat.rows(); ++i) {
+ visitor(mat.coeff(i, j), i, j);
+ }
+ }
+ }
}
};
@@ -114,6 +130,7 @@
PacketAccess = Evaluator::Flags & PacketAccessBit,
IsRowMajor = XprType::IsRowMajor,
RowsAtCompileTime = XprType::RowsAtCompileTime,
+ ColsAtCompileTime = XprType::ColsAtCompileTime,
CoeffReadCost = Evaluator::CoeffReadCost
};
@@ -122,8 +139,8 @@
explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { }
typedef typename XprType::Scalar Scalar;
- typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
- typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
+ typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
+ typedef std::remove_const_t<typename XprType::PacketReturnType> PacketReturnType;
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index 7d74188..4a646c0 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -236,7 +236,11 @@
typedef Packet4f half;
typedef Packet8i integer_packet;
typedef uint8_t mask_t;
- enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true};
+ enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true
+#ifdef EIGEN_VECTORIZE_AVX512
+ , masked_fpops_available=true
+#endif
+ };
};
template<> struct unpacket_traits<Packet4d> {
typedef double type;
@@ -464,6 +468,16 @@
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }
template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); }
+#ifdef EIGEN_VECTORIZE_AVX512
+template <>
+EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b, uint8_t umask) {
+ __mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
+ return _mm512_castps512_ps256(_mm512_maskz_add_ps(
+ mask,
+ _mm512_castps256_ps512(a),
+ _mm512_castps256_ps512(b)));
+}
+#endif
template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) {
#ifdef EIGEN_VECTORIZE_AVX2
@@ -848,11 +862,16 @@
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
+#ifdef EIGEN_VECTORIZE_AVX512
+ __mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_castps512_ps256(_mm512_maskz_loadu_ps(mask, from));
+#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
+#endif
}
// Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3}
@@ -911,11 +930,16 @@
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
+#ifdef EIGEN_VECTORIZE_AVX512
+ __mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, _mm512_castps256_ps512(from));
+#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from);
+#endif
}
// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 2a98b78..337001b 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -180,13 +180,14 @@
typedef Packet8f half;
typedef Packet16i integer_packet;
typedef uint16_t mask_t;
- enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true };
+ enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet8d> {
typedef double type;
typedef Packet4d half;
- enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
+ typedef uint8_t mask_t;
+ enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet16i> {
@@ -244,11 +245,25 @@
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
+#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
+ // Inline asm here helps reduce some register spilling in TRSM kernels.
+ // See note in unrolls::gemm::microKernel in TrsmKernel.h
+ Packet16f ret;
+ __asm__ ("vbroadcastss %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
+ return ret;
+#else
return _mm512_broadcastss_ps(_mm_load_ps1(from));
+#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
+#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
+ Packet8d ret;
+ __asm__ ("vbroadcastsd %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
+ return ret;
+#else
return _mm512_set1_pd(*from);
+#endif
}
template <>
@@ -287,6 +302,21 @@
}
template <>
+EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a,
+ const Packet16f& b,
+ uint16_t umask) {
+ __mmask16 mask = static_cast<__mmask16>(umask);
+ return _mm512_maskz_add_ps(mask, a, b);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
+ const Packet8d& b,
+ uint8_t umask) {
+ __mmask8 mask = static_cast<__mmask8>(umask);
+ return _mm512_maskz_add_pd(mask, a, b);
+}
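
As a rough sketch of the masked-add semantics introduced here (inside Eigen::internal, assuming an AVX512 build): lanes whose bit in umask is clear are zeroed by the maskz intrinsic, so

    Packet16f a = pset1<Packet16f>(1.0f);
    Packet16f b = pset1<Packet16f>(2.0f);
    // umask = 0x00FF keeps lanes 0..7 and zeroes lanes 8..15:
    Packet16f c = padd<Packet16f>(a, b, static_cast<uint16_t>(0x00FF));
    // c == {3,3,3,3,3,3,3,3, 0,0,0,0,0,0,0,0}
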
+
+template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
const Packet16f& b) {
return _mm512_sub_ps(a, b);
@@ -755,7 +785,7 @@
template <>
EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(
- reinterpret_cast<const __m512i*>(from));
+ reinterpret_cast<const __m512i*>(from));
}
template <>
@@ -777,6 +807,11 @@
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
}
+template <>
+EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
+ __mmask8 mask = static_cast<__mmask8>(umask);
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
+}
// Loads 8 floats from memory a returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
@@ -886,6 +921,11 @@
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
+ __mmask8 mask = static_cast<__mmask8>(umask);
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
+}
template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
@@ -1017,7 +1057,7 @@
// Extract exponent without existence of Packet8l.
template<>
-EIGEN_STRONG_INLINE
+EIGEN_STRONG_INLINE
Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
#ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -1040,11 +1080,11 @@
// Clamp exponent to [-2099, 2099]
const Packet8d max_exponent = pset1<Packet8d>(2099.0);
const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
-
+
// Split 2^e into four factors and multiply.
const Packet8i bias = pset1<Packet8i>(1023);
Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4)
-
+
// 2^b
const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
@@ -1052,7 +1092,7 @@
hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
-
+
// 2^(e - 3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
@@ -1072,7 +1112,7 @@
// AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i
#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
__m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \
- __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
+ __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
#else
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm256_insertf128_ps( \
@@ -1392,6 +1432,48 @@
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
INPUT[2 * INDEX + STRIDE]);
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
+ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
+ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
+ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]);
+ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]);
+ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]);
+ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]);
+ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
+ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);
+
+ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
+ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
+
+ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
+ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
+ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
+ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
+ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
+ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
+ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
+ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
+ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
+ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
+ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
+ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
+ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
+ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
+ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
+ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
+
+ kernel.packet[0] = T0; kernel.packet[1] = T1;
+ kernel.packet[2] = T2; kernel.packet[3] = T3;
+ kernel.packet[4] = T4; kernel.packet[5] = T5;
+ kernel.packet[6] = T6; kernel.packet[7] = T7;
+}
+
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
@@ -1468,62 +1550,53 @@
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
- __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
- __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
- __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
- __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
- __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
- __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
- __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
- __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
+ __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]);
+ __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]);
+ __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]);
+ __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2],kernel.packet[3]);
+ __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4],kernel.packet[5]);
+ __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4],kernel.packet[5]);
+ __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6],kernel.packet[7]);
+ __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6],kernel.packet[7]);
- PacketBlock<Packet4d, 16> tmp;
+ kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
+ kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
+ kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
+ kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
+ kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
+ kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
+ kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
+ kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
+ kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
+ kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
+ kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
+ kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
+ kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
+ kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
+ kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
+ kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);
- tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
- _mm512_extractf64x4_pd(T2, 0), 0x20);
- tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
- _mm512_extractf64x4_pd(T3, 0), 0x20);
- tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
- _mm512_extractf64x4_pd(T2, 0), 0x31);
- tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
- _mm512_extractf64x4_pd(T3, 0), 0x31);
+ T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
+ T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
+ T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
+ T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
+ T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
+ T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
+ T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
+ T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
+ T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
+ T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
+ T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
+ T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
+ T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
+ T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
+ T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
+ T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);
- tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
- _mm512_extractf64x4_pd(T2, 1), 0x20);
- tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
- _mm512_extractf64x4_pd(T3, 1), 0x20);
- tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
- _mm512_extractf64x4_pd(T2, 1), 0x31);
- tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
- _mm512_extractf64x4_pd(T3, 1), 0x31);
-
- tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
- _mm512_extractf64x4_pd(T6, 0), 0x20);
- tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
- _mm512_extractf64x4_pd(T7, 0), 0x20);
- tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
- _mm512_extractf64x4_pd(T6, 0), 0x31);
- tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
- _mm512_extractf64x4_pd(T7, 0), 0x31);
-
- tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
- _mm512_extractf64x4_pd(T6, 1), 0x20);
- tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
- _mm512_extractf64x4_pd(T7, 1), 0x20);
- tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
- _mm512_extractf64x4_pd(T6, 1), 0x31);
- tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
- _mm512_extractf64x4_pd(T7, 1), 0x31);
-
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8);
-
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8);
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8);
+ kernel.packet[0] = T0; kernel.packet[1] = T1;
+ kernel.packet[2] = T2; kernel.packet[3] = T3;
+ kernel.packet[4] = T4; kernel.packet[5] = T5;
+ kernel.packet[6] = T6; kernel.packet[7] = T7;
}
#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
diff --git a/Eigen/src/Core/arch/AVX512/TrsmKernel.h b/Eigen/src/Core/arch/AVX512/TrsmKernel.h
new file mode 100644
index 0000000..4b81bf9
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/TrsmKernel.h
@@ -0,0 +1,1107 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2022 Intel Corporation
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TRSM_KERNEL_IMPL_H
+#define EIGEN_TRSM_KERNEL_IMPL_H
+
+#include "../../InternalHeaderCheck.h"
+
+#define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels.
+
+#if defined(EIGEN_HAS_CXX17_IFCONSTEXPR)
+#define EIGEN_IF_CONSTEXPR(X) if constexpr (X)
+#else
+#define EIGEN_IF_CONSTEXPR(X) if (X)
+#endif
+
+// Need this for some std::min calls.
+#ifdef min
+#undef min
+#endif
+
+namespace Eigen {
+namespace internal {
+
+#define EIGEN_AVX_MAX_NUM_ACC (24L)
+#define EIGEN_AVX_MAX_NUM_ROW (8L) // Denoted L in code.
+#define EIGEN_AVX_MAX_K_UNROL (4L)
+#define EIGEN_AVX_B_LOAD_SETS (2L)
+#define EIGEN_AVX_MAX_A_BCAST (2L)
+typedef Packet16f vecFullFloat;
+typedef Packet8d vecFullDouble;
+typedef Packet8f vecHalfFloat;
+typedef Packet4d vecHalfDouble;
+
+// Compile-time unrolls are implemented here.
+// Note: this depends on macros and typedefs above.
+#include "TrsmUnrolls.inc"
+
+
+#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
+/**
+ * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
+ * is faster than the packed versions in TriangularSolverMatrix.h.
+ *
+ * The current heuristic is based on having all arrays used in the largest gemm-update
+ * in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be
+ * larger for some trsm cases.
+ * The formula:
+ *
+ * (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
+ *
+ * L = number of rows to solve at a time
+ * N = number of rhs
+ * M = Dimension of triangular matrix
+ *
+ */
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch
+template <typename Scalar>
+int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap){
+ const int64_t U3 = 3*packet_traits<Scalar>::size;
+ const int64_t MaxNb = 5*U3;
+ int64_t Nb = std::min(MaxNb, N);
+ double cutoff_d = (((L2Size*L2Cap)/(sizeof(Scalar)))-(EIGEN_AVX_MAX_NUM_ROW)*Nb)/
+ ((EIGEN_AVX_MAX_NUM_ROW)+Nb);
+ int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
+ return (cutoff_l/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW;
+}
+#endif
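
To make the heuristic concrete, a worked example with illustrative (not benchmarked) inputs, Scalar = float (packet size 16), L2Size = 1 MiB, L2Cap = 0.5 and N = 1000:

    U3 = 3*16 = 48,  MaxNb = 5*48 = 240,  Nb = min(240, 1000) = 240
    cutoff ≈ ((1048576*0.5)/4 - 8*240) / (8 + 240) = (131072 - 1920) / 248 ≈ 520.8
    which, truncated and rounded down to a multiple of EIGEN_AVX_MAX_NUM_ROW (8), gives 520.
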
+
+
+/**
+ * Used by gemmKernel for the case A/B row-major and C col-major.
+ */
+template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
+static EIGEN_ALWAYS_INLINE
+void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
+ Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
+ EIGEN_UNUSED_VARIABLE(remN_);
+ EIGEN_UNUSED_VARIABLE(remM_);
+ using urolls = unrolls::trans<Scalar>;
+
+ constexpr int64_t U3 = urolls::PacketSize * 3;
+ constexpr int64_t U2 = urolls::PacketSize * 2;
+ constexpr int64_t U1 = urolls::PacketSize * 1;
+
+  static_assert( unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be one of U1, U2 or U3");
+ static_assert( unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
+
+ urolls::template transpose<unrollN, 0>(zmm);
+ EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
+ EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
+
+ static_assert( (remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
+ EIGEN_IF_CONSTEXPR(!remN) {
+ urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ EIGEN_IF_CONSTEXPR(unrollN > U1) {
+ constexpr int64_t unrollN_ = std::min(unrollN-U1, U1);
+ urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1*LDC, LDC, zmm, remM_);
+ }
+ EIGEN_IF_CONSTEXPR(unrollN > U2) {
+ constexpr int64_t unrollN_ = std::min(unrollN-U2, U1);
+ urolls:: template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2*LDC, LDC, zmm, remM_);
+ }
+ }
+ else {
+ EIGEN_IF_CONSTEXPR( (std::is_same<Scalar, float>::value) ) {
+ // Note: without "if constexpr" this section of code will also be
+      // parsed by the compiler, so each of the storeC calls will still be instantiated.
+ // We use enable_if in aux_storeC to set it to an empty function for
+ // these cases.
+ if(remN_ == 15)
+ urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 14)
+ urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 13)
+ urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 12)
+ urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 11)
+ urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 10)
+ urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 9)
+ urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 8)
+ urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 7)
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 6)
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 5)
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 4)
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 3)
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 2)
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 1)
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ }
+ else {
+ if(remN_ == 7)
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 6)
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 5)
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 4)
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 3)
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 2)
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ else if(remN_ == 1)
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
+ }
+ }
+}
+
+/**
+ * GEMM-like operation for trsm panel updates.
+ * Computes: C -= A*B
+ * K must be a multiple of 4.
+ *
+ * Unrolls used are {1,2,4,8}x{U1,U2,U3};
+ * For good performance we want K to be large with M/N relatively small, but also large enough
+ * to use the {8,U3} unroll block.
+ *
+ * isARowMajor: is A_arr row-major?
+ * isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
+ * isAdd: C += A*B or C -= A*B (used by trsm)
+ * handleKRem: Handle arbitrary K? This is not needed for trsm.
+ */
+template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
+void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
+ int64_t M, int64_t N, int64_t K,
+ int64_t LDA, int64_t LDB, int64_t LDC) {
+ using urolls = unrolls::gemm<Scalar, isAdd>;
+ constexpr int64_t U3 = urolls::PacketSize * 3;
+ constexpr int64_t U2 = urolls::PacketSize * 2;
+ constexpr int64_t U1 = urolls::PacketSize * 1;
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value,
+ vecFullFloat,
+ vecFullDouble>::type;
+ int64_t N_ = (N/U3)*U3;
+ int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW;
+ int64_t K_ = (K/EIGEN_AVX_MAX_K_UNROL)*EIGEN_AVX_MAX_K_UNROL;
+ int64_t j = 0;
+ for(; j < N_; j += U3) {
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*3;
+ int64_t i = 0;
+ for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)], *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<3,EIGEN_AVX_MAX_NUM_ROW>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,1,
+ EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC+ j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U3,false, false>(zmm, &C_arr[i + j*LDC], LDC);
+ }
+ }
+ if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<3,4>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,3,4,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,3,4,1,
+ EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<3,4>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<3,4>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U3,true, false>(zmm, &C_arr[i + j*LDC], LDC, 4);
+ }
+ i += 4;
+ }
+ if(M - i >= 2) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<3,2>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,3,2,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,3,2,1,
+ EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<3,2>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<3,2>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U3,true, false>(zmm, &C_arr[i + j*LDC], LDC, 2);
+ }
+ i += 2;
+ }
+ if(M - i > 0) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<3,1>(zmm);
+ {
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,3,1,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_B_LOAD_SETS*3,1>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,3,1,1,
+ EIGEN_AVX_B_LOAD_SETS*3,1>(B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<3,1>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<3,1>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U3,true, false>(zmm, &C_arr[i + j*LDC], LDC, 1);
+ }
+ }
+ }
+ }
+ if(N - j >= U2) {
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*2;
+ int64_t i = 0;
+ for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)], *B_t = &B_arr[0*LDB + j];
+ EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<2,EIGEN_AVX_MAX_NUM_ROW>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW,
+ EIGEN_AVX_MAX_K_UNROL,EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U2,false, false>(zmm, &C_arr[i + j*LDC], LDC);
+ }
+ }
+ if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<2,4>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,2,4,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,2,4,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<2,4>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<2,4>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U2,true, false>(zmm, &C_arr[i + j*LDC], LDC, 4);
+ }
+ i += 4;
+ }
+ if(M - i >= 2) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<2,2>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,2,2,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,2,2,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<2,2>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<2,2>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U2,true, false>(zmm, &C_arr[i + j*LDC], LDC, 2);
+ }
+ i += 2;
+ }
+ if(M - i > 0) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<2,1>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,2,1,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,1>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,2,1,1,
+ EIGEN_AVX_MAX_B_LOAD,1>(B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<2,1>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<2,1>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U2,true, false>(zmm, &C_arr[i + j*LDC], LDC, 1);
+ }
+ }
+ j += U2;
+ }
+ if(N - j >= U1) {
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1;
+ int64_t i = 0;
+ for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)], *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1,
+ EIGEN_AVX_B_LOAD_SETS*1,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,false, false>(zmm, &C_arr[i + j*LDC], LDC);
+ }
+ }
+ if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,4>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,4,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,4>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<1,4>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, false>(zmm, &C_arr[i + j*LDC], LDC, 4);
+ }
+ i += 4;
+ }
+ if(M - i >= 2) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,2>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,2,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,2>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<1,2>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, false>(zmm, &C_arr[i + j*LDC], LDC, 2);
+ }
+ i += 2;
+ }
+ if(M - i > 0) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,1>(zmm);
+ {
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,1>(
+ B_t, A_t, LDB, LDA, zmm);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,1,1,EIGEN_AVX_B_LOAD_SETS*1,1>(B_t, A_t, LDB, LDA, zmm);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,1>(&C_arr[i*LDC + j], LDC, zmm);
+ urolls::template storeC<1,1>(&C_arr[i*LDC + j], LDC, zmm);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, false>(zmm, &C_arr[i + j*LDC], LDC, 1);
+ }
+ }
+ }
+ j += U1;
+ }
+ if(N - j > 0) {
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1;
+ int64_t i = 0;
+ for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,false, true>(zmm, &C_arr[i + j*LDC], LDC, 0, N-j);
+ }
+ }
+ if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,4>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,4,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ urolls::template storeC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, true>(zmm, &C_arr[i + j*LDC], LDC, 4, N-j);
+ }
+ i += 4;
+ }
+ if(M - i >= 2) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,2>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,2,1,
+ EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ urolls::template storeC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, true>(zmm, &C_arr[i + j*LDC], LDC, 2, N-j);
+ }
+ i += 2;
+ }
+ if(M - i > 0) {
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i,0,LDA)];
+ Scalar *B_t = &B_arr[0*LDB + j];
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
+ urolls::template setzero<1,1>(zmm);
+ for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) {
+ urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL,
+ EIGEN_AVX_MAX_B_LOAD,1,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += EIGEN_AVX_MAX_K_UNROL*LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+ }
+ EIGEN_IF_CONSTEXPR(handleKRem) {
+ for(int64_t k = K_; k < K ; k ++) {
+ urolls:: template microKernel<isARowMajor,1,1,1,
+ EIGEN_AVX_MAX_B_LOAD,1,true>(
+ B_t, A_t, LDB, LDA, zmm, N - j);
+ B_t += LDB;
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+ }
+ }
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
+ urolls::template updateC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ urolls::template storeC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j);
+ }
+ else {
+ transStoreC<Scalar,vec,EIGEN_AVX_MAX_NUM_ROW,U1,true, true>(zmm, &C_arr[i + j*LDC], LDC, 1, N-j);
+ }
+ }
+ }
+}
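
A minimal usage sketch of this kernel (caller assumed to be inside Eigen::internal with AVX512 enabled; the pointer and size names are placeholders):

    // C -= A*B for a trsm-style panel update: A is M x K (row-major here, stride LDA),
    // B is K x N row-major (stride LDB), C is M x N column-major (stride LDC).
    // With handleKRem = false, K must be a multiple of EIGEN_AVX_MAX_K_UNROL (4).
    gemmKernel<float, /*isARowMajor=*/true, /*isCRowMajor=*/false,
               /*isAdd=*/false, /*handleKRem=*/false>(A, B, C, M, N, K, LDA, LDB, LDC);
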
+
+/**
+ * Triangular solve kernel with A on the left and K right-hand sides. dim(A) = unrollM
+ *
+ * unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
+ * isFWDSolve: is forward solve?
+ * isUnitDiag: is the diagonal of A all ones?
+ * The B matrix (RHS) is assumed to be row-major
+*/
+template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
+static EIGEN_ALWAYS_INLINE
+void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
+
+  static_assert( unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be <= EIGEN_AVX_MAX_NUM_ROW" );
+ using urolls = unrolls::trsm<Scalar>;
+ constexpr int64_t U3 = urolls::PacketSize * 3;
+ constexpr int64_t U2 = urolls::PacketSize * 2;
+ constexpr int64_t U1 = urolls::PacketSize * 1;
+
+ PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
+ PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> AInPacket;
+
+ int64_t k = 0;
+ while(K - k >= U3) {
+ urolls:: template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
+ urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ urolls:: template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
+ k += U3;
+ }
+ if(K - k >= U2) {
+ urolls:: template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
+ urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ urolls:: template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
+ k += U2;
+ }
+ if(K - k >= U1) {
+ urolls:: template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
+ urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ urolls:: template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
+ k += U1;
+ }
+ if(K - k > 0) {
+ // Handle remaining number of RHS
+ urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K-k);
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K-k);
+ }
+}
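
For example, with Scalar = float the packet size is 16, so U3 = 48, U2 = 32 and U1 = 16; a call with K = 100 right-hand sides therefore runs the while loop twice (2*48 = 96 columns of B) and handles the remaining 4 columns through the masked remainder branch at the end.
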
+
+/**
+ * Triangular solve routine with A on the left, where A has dimension at most L (EIGEN_AVX_MAX_NUM_ROW) and
+ * there are K right-hand sides. This is essentially a dispatch wrapper over triSolveKernel for M = {1,2,3,4,5,6,7,8}.
+ *
+ * isFWDSolve: is forward solve?
+ * isUnitDiag: is the diagonal of A all ones?
+ * The B matrix (RHS) is assumed to be row-major
+*/
+template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
+void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
+ if (M == 8)
+ triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 7)
+ triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 6)
+ triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 5)
+ triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 4)
+ triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 3)
+ triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 2)
+ triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ else if (M == 1)
+ triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
+ return;
+}
+
+/**
+ * This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
+ *
+ * toTemp: true => copy to temporary array, false => copy from temporary array
+ * remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
+ *
+ */
+template <typename Scalar, bool toTemp = true, bool remM = false>
+static EIGEN_ALWAYS_INLINE
+void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
+ Scalar *B_temp, int64_t LDB_, int64_t remM_ = 0) {
+ EIGEN_UNUSED_VARIABLE(remM_);
+ using urolls = unrolls::transB<Scalar>;
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
+ constexpr int64_t U3 = urolls::PacketSize * 3;
+ constexpr int64_t U2 = urolls::PacketSize * 2;
+ constexpr int64_t U1 = urolls::PacketSize * 1;
+ int64_t K_ = K/U3*U3;
+ int64_t k = 0;
+
+ for(; k < K_; k += U3) {
+ urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += U3;
+ }
+ if(K - k >= U2) {
+ urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += U2; k += U2;
+ }
+ if(K - k >= U1) {
+ urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += U1; k += U1;
+ }
+ EIGEN_IF_CONSTEXPR( U1 > 8) {
+ // Note: without "if constexpr" this section of code will also be
+    // parsed by the compiler, so there is an additional check in {load/store}BBlock
+    // to make sure the counter does not go negative.
+ if(K - k >= 8) {
+ urolls::template transB_kernel<8, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += 8; k += 8;
+ }
+ }
+ EIGEN_IF_CONSTEXPR( U1 > 4) {
+ // Note: without "if constexpr" this section of code will also be
+    // parsed by the compiler, so there is an additional check in {load/store}BBlock
+    // to make sure the counter does not go negative.
+ if(K - k >= 4) {
+ urolls::template transB_kernel<4, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += 4; k += 4;
+ }
+ }
+ if(K - k >= 2) {
+ urolls::template transB_kernel<2, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += 2; k += 2;
+ }
+ if(K - k >= 1) {
+ urolls::template transB_kernel<1, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_);
+ B_temp += 1; k += 1;
+ }
+}
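
A sketch of the round-trip for a column-major B, mirroring the calls made in triSolve below (B_panel, bK and LDT stand in for the actual pointers and sizes used there):

    copyBToRowMajor<float, /*toTemp=*/true , /*remM=*/false>(B_panel, LDB, bK, B_temp, LDT); // B -> B_temp (row-major)
    // ... the triangular solve and gemm updates operate on the row-major B_temp ...
    copyBToRowMajor<float, /*toTemp=*/false, /*remM=*/false>(B_panel, LDB, bK, B_temp, LDT); // B_temp -> B
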
+
+/**
+ * Main triangular solve driver
+ *
+ * Triangular solve with A on the left.
+ * Scalar: Scalar precision, only float/double is supported.
+ * isARowMajor: is A row-major?
+ * isBRowMajor: is B row-major?
+ * isFWDSolve: is this forward solve or backward (true => forward)?
+ * isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
+ *
+ * M: dimension of A
+ * numRHS: number of right hand sides (coincides with K dimension for gemm updates)
+ *
+ * Here is the mapping between the different TRSM cases (col-major) and triSolve:
+ *
+ * LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
+ * LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true
+ * LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
+ * LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false
+ * RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true
+ * RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true
+ * RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false
+ * RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false
+ *
+ * Note: For RXX cases M,numRHS should be swapped.
+ *
+*/
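
For instance, an in-place column-major solve L*X = B with a non-transposed lower-triangular L on the left (the LLN row above) would be dispatched as in this sketch (pointer and size names are illustrative):

    triSolve<double, /*isARowMajor=*/false, /*isBRowMajor=*/false,
             /*isFWDSolve=*/true, /*isUnitDiag=*/false>(L_data, B_data, M, numRHS, LDA, LDB);
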
+template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true, bool isUnitDiag = false>
+void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
+ /**
+ * The values for kB, numM were determined experimentally.
+ * kB: Number of RHS we process at a time.
+ * numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
+ *
+ * kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
+ * performance with M=numRHS.
+ * It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
+ *
+ * numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
+ * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
+ * determine optimal values for numM.
+ */
+ const int64_t kB = (3*packet_traits<Scalar>::size)*5; // 5*U3
+ const int64_t numM = 64;
+
+ int64_t sizeBTemp = 0;
+ Scalar *B_temp = NULL;
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
+ /**
+ * If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
+ * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
+ * The updated row-major copy of B is reused in the GEMM updates.
+ */
+ sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM;
+ B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp);
+ }
+ for(int64_t k = 0; k < numRHS; k += kB) {
+ int64_t bK = numRHS - k > kB ? kB : numRHS - k;
+ int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
+
+    // bkL is bK rounded up to the next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
+ // instead of bK RHS in triSolveKernelLxK.
+ int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW-1))/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW;
+ const int64_t numScalarPerCache = 64/sizeof(Scalar);
+ // Leading dimension of B_temp, will be a multiple of the cache line size.
+ int64_t LDT = ((bkL+(numScalarPerCache-1))/numScalarPerCache)*numScalarPerCache;
+ int64_t offsetBTemp = 0;
+ for(int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
+ int64_t indA_i = isFWDSolve ? i : M - 1 - i;
+ int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
+ int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW*LDT - offsetBTemp;
+ int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
+ // Copy values from B to B_temp.
+ copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT);
+ // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
+ &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
+ // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
+ copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT);
+
+ offsetBTemp += EIGEN_AVX_MAX_NUM_ROW*LDT;
+ }
+ else {
+ int64_t ind = isFWDSolve ? i : M - 1 - i;
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
+ &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind*LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
+ }
+ if(i+EIGEN_AVX_MAX_NUM_ROW < M_) {
+ /**
+ * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
+ * to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
+ * B as follows:
+ *
+ * A B
+ * __
+ * |__|__ |__|
+ * |__|__|__ |__|
+ * |__|__|__|__ |__|
+ * |********|__| |**|
+ */
+ EIGEN_IF_CONSTEXPR(isBRowMajor) {
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW);
+ gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
+ &A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
+ B_arr + k + indB_i*LDB,
+ B_arr + k + indB_i2*LDB,
+ EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW,
+ LDA, LDB, LDB);
+ }
+ else {
+ if(offsetBTemp + EIGEN_AVX_MAX_NUM_ROW*LDT > sizeBTemp) {
+ /**
+ * Similar idea as mentioned above, but here we are limited by the number of updated values of B
+ * that can be stored (row-major) in B_temp.
+ *
+ * If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
+             * update and partially update the remaining old values of B which depend on the new values
+ * of B stored in B_temp. These values are then no longer needed and can be overwritten.
+ */
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
+ int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
+ gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
+ &A_arr[idA<isARowMajor>(indA_i, indA_j,LDA)],
+ B_temp + offB_1,
+ B_arr + indB_i + (k)*LDB,
+ M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff,
+ LDA, LDT, LDB);
+ offsetBTemp = 0; gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
+ }
+ else {
+ /**
+ * If there is enough space in B_temp, we only update the next 8xbK values of B.
+ */
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
+ int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW);
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
+ gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
+ &A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
+ B_temp + offB_1,
+ B_arr + indB_i + (k)*LDB,
+ EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff,
+ LDA, LDT, LDB);
+ }
+ }
+ }
+ }
+    // Handle M remainder.
+ int64_t bM = M-M_;
+ if (bM > 0){
+ if( M_ > 0) {
+ EIGEN_IF_CONSTEXPR(isBRowMajor) {
+ int64_t indA_i = isFWDSolve ? M_ : 0;
+ int64_t indA_j = isFWDSolve ? 0 : bM;
+ int64_t indB_i = isFWDSolve ? 0 : bM;
+ int64_t indB_i2 = isFWDSolve ? M_ : 0;
+ gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
+ &A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
+ B_arr + k +indB_i*LDB,
+ B_arr + k + indB_i2*LDB,
+ bM , bK, M_,
+ LDA, LDB, LDB);
+ }
+ else {
+ int64_t indA_i = isFWDSolve ? M_ : 0;
+ int64_t indA_j = isFWDSolve ? gemmOff : bM;
+ int64_t indB_i = isFWDSolve ? M_ : 0;
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
+ gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
+ &A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
+ B_temp + offB_1,
+ B_arr + indB_i + (k)*LDB,
+ bM , bK, M_ - gemmOff,
+ LDA, LDT, LDB);
+ }
+ }
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
+ int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
+ int64_t indB_i = isFWDSolve ? M_ : 0;
+ int64_t offB_1 = isFWDSolve ? 0 : (bM-1)*bkL;
+ copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM);
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
+ &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL);
+ copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM);
+ }
+ else {
+ int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
+ &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind*LDB, bM, bK, LDA, LDB);
+ }
+ }
+ }
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp);
+}
+
+template <typename Scalar, bool isARowMajor = true, bool isCRowMajor = true>
+void gemmKer(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
+ int64_t M, int64_t N, int64_t K,
+ int64_t LDA, int64_t LDB, int64_t LDC) {
+ gemmKernel<Scalar, isARowMajor, isCRowMajor, true, true>(B_arr, A_arr, C_arr, N, M, K, LDB, LDA, LDC);
+}
+
+
+// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
+#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+struct trsm_kernels;
+
+template <typename Index, int Mode, int TriStorageOrder>
+struct trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>{
+ static void trsmKernelL(Index size, Index otherSize, const float* _tri, Index triStride,
+ float* _other, Index otherIncr, Index otherStride);
+ static void trsmKernelR(Index size, Index otherSize, const float* _tri, Index triStride,
+ float* _other, Index otherIncr, Index otherStride);
+};
+
+template <typename Index, int Mode, int TriStorageOrder>
+struct trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>{
+ static void trsmKernelL(Index size, Index otherSize, const double* _tri, Index triStride,
+ double* _other, Index otherIncr, Index otherStride);
+ static void trsmKernelR(Index size, Index otherSize, const double* _tri, Index triStride,
+ double* _other, Index otherIncr, Index otherStride);
+};
+
+template <typename Index, int Mode, int TriStorageOrder>
+EIGEN_DONT_INLINE void trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>::trsmKernelL(
+ Index size, Index otherSize,
+ const float* _tri, Index triStride,
+ float* _other, Index otherIncr, Index otherStride)
+{
+ EIGEN_UNUSED_VARIABLE(otherIncr);
+ triSolve<float, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
+ const_cast<float*>(_tri), _other, size, otherSize, triStride, otherStride);
+}
+
+template <typename Index, int Mode, int TriStorageOrder>
+EIGEN_DONT_INLINE void trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>::trsmKernelR(
+ Index size, Index otherSize,
+ const float* _tri, Index triStride,
+ float* _other, Index otherIncr, Index otherStride)
+{
+ EIGEN_UNUSED_VARIABLE(otherIncr);
+ triSolve<float, TriStorageOrder!=RowMajor, true, (Mode&Lower)!=Lower, (Mode & UnitDiag)!=0>(
+ const_cast<float*>(_tri), _other, size, otherSize, triStride, otherStride);
+}
+
+template <typename Index, int Mode, int TriStorageOrder>
+EIGEN_DONT_INLINE void trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>::trsmKernelL(
+ Index size, Index otherSize,
+ const double* _tri, Index triStride,
+ double* _other, Index otherIncr, Index otherStride)
+{
+ EIGEN_UNUSED_VARIABLE(otherIncr);
+ triSolve<double, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
+ const_cast<double*>(_tri), _other, size, otherSize, triStride, otherStride);
+}
+
+template <typename Index, int Mode, int TriStorageOrder>
+EIGEN_DONT_INLINE void trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>::trsmKernelR(
+ Index size, Index otherSize,
+ const double* _tri, Index triStride,
+ double* _other, Index otherIncr, Index otherStride)
+{
+ EIGEN_UNUSED_VARIABLE(otherIncr);
+ triSolve<double, TriStorageOrder!=RowMajor, true, (Mode&Lower)!=Lower, (Mode & UnitDiag)!=0>(
+ const_cast<double*>(_tri), _other, size, otherSize, triStride, otherStride);
+}
+#endif //EIGEN_USE_AVX512_TRSM_KERNELS
+}
+}
+#endif //EIGEN_TRSM_KERNEL_IMPL_H
diff --git a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
new file mode 100644
index 0000000..22cb1c9
--- /dev/null
+++ b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
@@ -0,0 +1,1201 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2022 Intel Corporation
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_UNROLLS_IMPL_H
+#define EIGEN_UNROLLS_IMPL_H
+
+template <bool isARowMajor = true>
+static EIGEN_ALWAYS_INLINE
+int64_t idA(int64_t i, int64_t j, int64_t LDA) {
+ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
+ else return i + j * LDA;
+}
+
+/**
+ * This namespace contains various classes used to generate compile-time unrolls which are
+ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
+ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion.
+ *
+ * For example, the 2-D for-loop is unrolled recursively by first flattening it to a 1-D loop.
+ *
+ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
+ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
+ * func(startI,startJ) startJ = (startC)%(endJ)
+ * func(...)
+ *
+ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
+ * with a template parameter used as a counter.
+ *
+ * template <endI, endJ, counter>
+ * std::enable_if_t<(counter <= 0)> <---- tail case.
+ * aux_func {}
+ *
+ * template <endI, endJ, counter>
+ * std::enable_if_t<(counter > 0)> <---- actual for-loop
+ * aux_func {
+ * startC = endI*endJ - counter
+ * startI = (startC)/(endJ)
+ * startJ = (startC)%(endJ)
+ * func(startI, startJ)
+ * aux_func<endI, endJ, counter-1>()
+ * }
+ *
+ * Note: Additional wrapper functions are provided for aux_func; they hide the counter template
+ *       parameter, since counter usually depends on endI, endJ, etc.
+ *
+ * Conventions:
+ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
+ *
+ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
+ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
+ */
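+/**
+ * Illustrative sketch (not used by the kernels below): a minimal, self-contained form of the 1-D
+ * unroll pattern described above. The names exampleUnroll/aux_exampleUnroll/Func are hypothetical
+ * and only show the enable_if tail case plus template recursion; the real unrolls follow the same
+ * shape with kernel-specific bodies.
+ *
+ *   template <int64_t endI, int64_t counter, typename Func>
+ *   std::enable_if_t<(counter <= 0)> aux_exampleUnroll(Func&&) {}              // tail case
+ *
+ *   template <int64_t endI, int64_t counter, typename Func>
+ *   std::enable_if_t<(counter > 0)> aux_exampleUnroll(Func&& func) {
+ *     constexpr int64_t startI = endI - counter;                               // recover the loop index
+ *     func(std::integral_constant<int64_t, startI>{});                         // body sees a compile-time index
+ *     aux_exampleUnroll<endI, counter - 1>(std::forward<Func>(func));
+ *   }
+ *
+ *   template <int64_t endI, typename Func>
+ *   void exampleUnroll(Func&& func) { aux_exampleUnroll<endI, endI>(std::forward<Func>(func)); }
+ */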
+namespace unrolls {
+
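+// Returns a bitmask that selects the first m lanes of an N-lane packet; used for masked
+// loads/stores when handling remainders. For example, remMask<16>(3) == 0x0007 and remMask<8>(5) == 0x1F.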
+template <int64_t N>
+EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
+ EIGEN_IF_CONSTEXPR( N == 16) { return 0xFFFF >> (16 - m); }
+ else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); }
+ else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); }
+ return 0;
+}
+
+/***
+ * Unrolls for transposed C stores
+ */
+template <typename Scalar>
+class trans {
+public:
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
+
+
+ /***********************************
+   * Auxiliary Functions for:
+ * - storeC
+ ***********************************
+ */
+
+ /**
+ * aux_storeC
+ *
+ * 1-D unroll
+ * for(startN = 0; startN < endN; startN++)
+ *
+ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
+ *
+ **/
+ template<int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)>
+ aux_storeC(Scalar *C_arr, int64_t LDC,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
+ constexpr int64_t counterReverse = endN-counter;
+ constexpr int64_t startN = counterReverse;
+
+ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
+ EIGEN_IF_CONSTEXPR(remM) {
+ pstoreu<Scalar>(
+ C_arr + LDC*startN,
+ padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
+ }
+ else {
+ pstoreu<Scalar>(
+ C_arr + LDC*startN,
+ padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
+ }
+ }
+ else { // This block is only needed for fp32 case
+ // Reinterpret as __m512 for _mm512_shuffle_f32x4
+ vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
+ zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
+ // Swap lower and upper half of avx register.
+ zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
+ preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
+
+ EIGEN_IF_CONSTEXPR(remM) {
+ pstoreu<Scalar>(
+ C_arr + LDC*startN,
+ padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
+ }
+ else {
+ pstoreu<Scalar>(
+ C_arr + LDC*startN,
+ padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
+ }
+ }
+ aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
+ }
+
+ template<int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)>
+ aux_storeC(Scalar *C_arr, int64_t LDC,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(C_arr);
+ EIGEN_UNUSED_VARIABLE(LDC);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(remM_);
+ }
+
+ template<int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE
+ void storeC(Scalar *C_arr, int64_t LDC,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0){
+ aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
+ }
+
+ /**
+   * Transposes an LxunrollN row-major block of matrices stored in EIGEN_AVX_MAX_NUM_ACC zmm registers to
+ * "unrollN"xL ymm registers to be stored col-major into C.
+ *
+ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
+ *
+ * row0: zmm0 zmm1 zmm2
+ * row1: zmm3 zmm4 zmm5
+ * .
+ * .
+ * row7: zmm21 zmm22 zmm23
+ *
+ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
+ *
+ * row0: zmm0 zmm1
+ * row1: zmm2 zmm3
+ * .
+ * .
+ * row7: zmm14 zmm15
+ *
+ *
+ * In general we will have {1,2,3} groups of avx registers each of size
+ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of
+ * avx registers are being transposed.
+ */
+ template<int64_t unrollN, int64_t packetIndexOffset>
+ static EIGEN_ALWAYS_INLINE
+ void transpose(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
+ constexpr int64_t zmmStride = unrollN/PacketSize;
+ PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> r;
+ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride*0];
+ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride*1];
+ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride*2];
+ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride*3];
+ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride*4];
+ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
+ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
+ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
+ ptranspose(r);
+ zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
+ zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
+ zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];
+ zmm.packet[packetIndexOffset + zmmStride*3] = r.packet[3];
+ zmm.packet[packetIndexOffset + zmmStride*4] = r.packet[4];
+ zmm.packet[packetIndexOffset + zmmStride*5] = r.packet[5];
+ zmm.packet[packetIndexOffset + zmmStride*6] = r.packet[6];
+ zmm.packet[packetIndexOffset + zmmStride*7] = r.packet[7];
+ }
+};
+
+/**
+ * Unrolls for copyBToRowMajor
+ *
+ * Idea:
+ * 1) Load a block of right-hand sides to registers (using loadB).
+ * 2) Convert the block from column-major to row-major (transposeLxL)
+ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
+ *
+ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
+ * used as temps for transposing.
+ *
+ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
+ */
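+/**
+ * Worked illustration (assuming EIGEN_AVX_MAX_NUM_ROW = 8, for orientation only): for double,
+ * PacketSize = 8, so an 8x8 sub-block is handled by loading 8 packets (one per column of B when
+ * toTemp == true, one per row of B_temp when toTemp == false), transposing them in-register with
+ * transposeLxL, and storing the 8 transposed packets to the destination (B_temp or B).
+ */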
+template <typename Scalar>
+class transB {
+public:
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
+
+ /***********************************
+   * Auxiliary Functions for:
+ * - loadB
+ * - storeB
+ * - loadBBlock
+ * - storeBBlock
+ ***********************************
+ */
+
+ /**
+ * aux_loadB
+ *
+ * 1-D unroll
+ * for(startN = 0; startN < endN; startN++)
+ **/
+ template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_loadB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
+ constexpr int64_t counterReverse = endN-counter;
+ constexpr int64_t startN = counterReverse;
+
+ EIGEN_IF_CONSTEXPR(remM) {
+ ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>(
+ (const Scalar*)&B_arr[startN*LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
+ }
+ else
+ ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar*)&B_arr[startN*LDB]);
+
+ aux_loadB<endN, counter-1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+ }
+
+ template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_loadB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(ymm);
+ EIGEN_UNUSED_VARIABLE(remM_);
+ }
+
+ /**
+ * aux_storeB
+ *
+ * 1-D unroll
+ * for(startN = 0; startN < endN; startN++)
+ **/
+ template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_storeB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
+ constexpr int64_t counterReverse = endN-counter;
+ constexpr int64_t startN = counterReverse;
+
+ EIGEN_IF_CONSTEXPR( remK || remM) {
+ pstoreu<Scalar>(
+ &B_arr[startN*LDB],
+ ymm.packet[packetIndexOffset + startN],
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
+ }
+ else {
+ pstoreu<Scalar>(&B_arr[startN*LDB], ymm.packet[packetIndexOffset + startN]);
+ }
+
+ aux_storeB<endN, counter-1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
+ }
+
+ template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_storeB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(ymm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /**
+ * aux_loadBBlock
+ *
+ * 1-D unroll
+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
+ **/
+ template<int64_t endN, int64_t counter, bool toTemp, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0) {
+ constexpr int64_t counterReverse = endN-counter;
+ constexpr int64_t startN = counterReverse;
+ transB::template loadB<EIGEN_AVX_MAX_NUM_ROW,startN, false>(&B_temp[startN], LDB_, ymm);
+ aux_loadBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(
+ B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+
+ template<int64_t endN, int64_t counter, bool toTemp, bool remM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(B_temp);
+ EIGEN_UNUSED_VARIABLE(LDB_);
+ EIGEN_UNUSED_VARIABLE(ymm);
+ EIGEN_UNUSED_VARIABLE(remM_);
+ }
+
+
+ /**
+ * aux_storeBBlock
+ *
+ * 1-D unroll
+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
+ **/
+ template<int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0) {
+ constexpr int64_t counterReverse = endN-counter;
+ constexpr int64_t startN = counterReverse;
+
+ EIGEN_IF_CONSTEXPR(toTemp) {
+ transB::template storeB<EIGEN_AVX_MAX_NUM_ROW,startN, remK_ != 0, false>(
+ &B_temp[startN], LDB_, ymm, remK_);
+ }
+ else {
+ transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW,endN),startN, false, remM>(
+ &B_arr[0 + startN*LDB], LDB, ymm, remM_);
+ }
+ aux_storeBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(
+ B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+
+ template<int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(B_temp);
+ EIGEN_UNUSED_VARIABLE(LDB_);
+ EIGEN_UNUSED_VARIABLE(ymm);
+ EIGEN_UNUSED_VARIABLE(remM_);
+ }
+
+
+ /********************************************************
+ * Wrappers for aux_XXXX to hide counter parameter
+ ********************************************************/
+
+ template<int64_t endN, int64_t packetIndexOffset, bool remM>
+ static EIGEN_ALWAYS_INLINE
+ void loadB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
+ aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+ }
+
+ template<int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
+ static EIGEN_ALWAYS_INLINE
+ void storeB(Scalar *B_arr, int64_t LDB,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
+ aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
+ }
+
+ template<int64_t unrollN, bool toTemp, bool remM>
+ static EIGEN_ALWAYS_INLINE
+ void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0) {
+ EIGEN_IF_CONSTEXPR(toTemp) {
+ transB::template loadB<unrollN,0,remM>(&B_arr[0],LDB, ymm, remM_);
+ }
+ else {
+ aux_loadBBlock<unrollN, unrollN, toTemp, remM>(
+ B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ }
+
+ template<int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
+ static EIGEN_ALWAYS_INLINE
+ void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
+ int64_t remM_ = 0) {
+ aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(
+ B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+
+ template<int64_t packetIndexOffset>
+ static EIGEN_ALWAYS_INLINE
+ void transposeLxL(PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm){
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
+ PacketBlock<vecHalf,EIGEN_AVX_MAX_NUM_ROW> r;
+ r.packet[0] = ymm.packet[packetIndexOffset + 0];
+ r.packet[1] = ymm.packet[packetIndexOffset + 1];
+ r.packet[2] = ymm.packet[packetIndexOffset + 2];
+ r.packet[3] = ymm.packet[packetIndexOffset + 3];
+ r.packet[4] = ymm.packet[packetIndexOffset + 4];
+ r.packet[5] = ymm.packet[packetIndexOffset + 5];
+ r.packet[6] = ymm.packet[packetIndexOffset + 6];
+ r.packet[7] = ymm.packet[packetIndexOffset + 7];
+ ptranspose(r);
+ ymm.packet[packetIndexOffset + 0] = r.packet[0];
+ ymm.packet[packetIndexOffset + 1] = r.packet[1];
+ ymm.packet[packetIndexOffset + 2] = r.packet[2];
+ ymm.packet[packetIndexOffset + 3] = r.packet[3];
+ ymm.packet[packetIndexOffset + 4] = r.packet[4];
+ ymm.packet[packetIndexOffset + 5] = r.packet[5];
+ ymm.packet[packetIndexOffset + 6] = r.packet[6];
+ ymm.packet[packetIndexOffset + 7] = r.packet[7];
+ }
+
+ template<int64_t unrollN, bool toTemp, bool remM>
+ static EIGEN_ALWAYS_INLINE
+ void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
+ PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
+ constexpr int64_t U3 = PacketSize * 3;
+ constexpr int64_t U2 = PacketSize * 2;
+ constexpr int64_t U1 = PacketSize * 1;
+ /**
+ * Unrolls needed for each case:
+ * - AVX512 fp32 48 32 16 8 4 2 1
+ * - AVX512 fp64 24 16 8 4 2 1
+ *
+     * For fp32, L and U1 are in a 1:2 ratio, so for the U3/U2 cases the loads/stores need to be split up.
+ */
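+    /**
+     * For example (assuming L = EIGEN_AVX_MAX_NUM_ROW = 8): for fp32, U3 = 48 but maxUBlock is
+     * min(3*L, U3) = 24, so the U3 case below runs two passes of 24 columns each; for fp64,
+     * U3 = 24 = 3*L and a single pass suffices.
+     */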
+ EIGEN_IF_CONSTEXPR(unrollN == U3) {
+ // load LxU3 B col major, transpose LxU3 row major
+ constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U3);
+ transB::template loadBBlock<maxUBlock,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template storeBBlock<maxUBlock,toTemp, remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+
+ EIGEN_IF_CONSTEXPR( maxUBlock < U3) {
+ transB::template loadBBlock<maxUBlock,toTemp, remM>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
+ transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template storeBBlock<maxUBlock,toTemp, remM,0>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
+ }
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == U2) {
+ // load LxU2 B col major, transpose LxU2 row major
+ constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U2);
+ transB::template loadBBlock<maxUBlock,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ transB::template storeBBlock<maxUBlock,toTemp,remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+
+ EIGEN_IF_CONSTEXPR( maxUBlock < U2) {
+ transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp, remM>(
+ &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp,remM,0>(
+ &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
+ }
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == U1) {
+ // load LxU1 B col major, transpose LxU1 row major
+ transB::template loadBBlock<U1,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) {
+ transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
+ }
+ transB::template storeBBlock<U1,toTemp,remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
+      // load Lx8 B col major, transpose Lx8 row major
+ transB::template loadBBlock<8,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ transB::template storeBBlock<8,toTemp,remM,8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
+ // load Lx4 B col major, transpose Lx4 row major
+ transB::template loadBBlock<4,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ transB::template storeBBlock<4,toTemp,remM,4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == 2) {
+ // load Lx2 B col major, transpose Lx2 row major
+ transB::template loadBBlock<2,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ transB::template storeBBlock<2,toTemp,remM,2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ else EIGEN_IF_CONSTEXPR(unrollN == 1) {
+ // load Lx1 B col major, transpose Lx1 row major
+ transB::template loadBBlock<1,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ transB::template transposeLxL<0>(ymm);
+ transB::template storeBBlock<1,toTemp,remM,1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+ }
+ }
+};
+
+/**
+ * Unrolls for triSolveKernel
+ *
+ * Idea:
+ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
+ * stored in AInPacket (using triSolveMicroKernel).
+ * 3) Store final results (in avx registers) back into memory (using storeRHS).
+ *
+ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
+ * EIGEN_AVX_MAX_NUM_ROW registers.
+ */
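+/**
+ * For orientation, a sketch of the math (not new logic): with forward substitution, each row i of
+ * the triangular system is resolved as
+ *
+ *   x_i = (b_i - sum_{j < i} A(i,j) * x_j) / A(i,i)
+ *
+ * In the kernel below, the partial sums are accumulated with pnmadd in updateRHS, and the division
+ * is performed as a multiplication by the broadcast reciprocal 1/A(i,i) in divRHSByDiag (skipped
+ * for unit-diagonal matrices). Backward substitution follows the same pattern with reversed indexing.
+ */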
+template <typename Scalar>
+class trsm {
+public:
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value,
+ vecFullFloat,
+ vecFullDouble>::type;
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
+
+ /***********************************
+   * Auxiliary Functions for:
+ * - loadRHS
+ * - storeRHS
+ * - divRHSByDiag
+ * - updateRHS
+ * - triSolveMicroKernel
+ ************************************/
+ /**
+ * aux_loadRHS
+ *
+ * 2-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startK = 0; startK < endK; startK++)
+ **/
+ template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
+
+ constexpr int64_t counterReverse = endM*endK-counter;
+ constexpr int64_t startM = counterReverse/(endK);
+ constexpr int64_t startK = counterReverse%endK;
+
+ constexpr int64_t packetIndex = startM*endK + startK;
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
+ const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB;
+ EIGEN_IF_CONSTEXPR(krem) {
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
+ }
+ else {
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
+ }
+ aux_loadRHS<isFWDSolve,endM, endK, counter-1, krem>(B_arr, LDB, RHSInPacket, rem);
+ }
+
+ template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
+ EIGEN_UNUSED_VARIABLE(rem);
+ }
+
+ /**
+ * aux_storeRHS
+ *
+ * 2-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startK = 0; startK < endK; startK++)
+ **/
+ template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
+ constexpr int64_t counterReverse = endM*endK-counter;
+ constexpr int64_t startM = counterReverse/(endK);
+ constexpr int64_t startK = counterReverse%endK;
+
+ constexpr int64_t packetIndex = startM*endK + startK;
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
+ const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB;
+ EIGEN_IF_CONSTEXPR(krem) {
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
+ }
+ else {
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
+ }
+ aux_storeRHS<isFWDSolve,endM, endK, counter-1, krem>(B_arr, LDB, RHSInPacket, rem);
+ }
+
+ template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_arr);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
+ EIGEN_UNUSED_VARIABLE(rem);
+ }
+
+ /**
+ * aux_divRHSByDiag
+ *
+   * currM may be -1; the (currM >= 0) condition in enable_if guards against this.
+ *
+ * 1-D unroll
+ * for(startK = 0; startK < endK; startK++)
+ **/
+ template<int64_t currM, int64_t endK, int64_t counter>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)>
+ aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ constexpr int64_t counterReverse = endK-counter;
+ constexpr int64_t startK = counterReverse;
+
+ constexpr int64_t packetIndex = currM*endK + startK;
+ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
+ aux_divRHSByDiag<currM, endK, counter-1>(RHSInPacket, AInPacket);
+ }
+
+ template<int64_t currM, int64_t endK, int64_t counter>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)>
+ aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
+ EIGEN_UNUSED_VARIABLE(AInPacket);
+ }
+
+ /**
+ * aux_updateRHS
+ *
+ * 2-D unroll
+ * for(startM = initM; startM < endM; startM++)
+ * for(startK = 0; startK < endK; startK++)
+ **/
+ template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+
+ constexpr int64_t counterReverse = (endM-initM)*endK-counter;
+ constexpr int64_t startM = initM + counterReverse/(endK);
+ constexpr int64_t startK = counterReverse%endK;
+
+ // For each row of A, first update all corresponding RHS
+ constexpr int64_t packetIndex = startM*endK + startK;
+ EIGEN_IF_CONSTEXPR(currentM > 0) {
+ RHSInPacket.packet[packetIndex] =
+ pnmadd(AInPacket.packet[startM],
+ RHSInPacket.packet[(currentM-1)*endK+startK],
+ RHSInPacket.packet[packetIndex]);
+ }
+
+ EIGEN_IF_CONSTEXPR(startK == endK - 1) {
+      // Once all RHS for the previous row of A are updated, we broadcast the next element in the column A_{i, currentM}.
+ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
+        // If the diagonal is not unit, we broadcast the reciprocal of the diagonal to AInPacket.packet[currentM].
+        // This will be used in divRHSByDiag.
+ EIGEN_IF_CONSTEXPR(isFWDSolve)
+ AInPacket.packet[currentM] = pset1<vec>(Scalar(1)/A_arr[idA<isARowMajor>(currentM,currentM,LDA)]);
+ else
+ AInPacket.packet[currentM] = pset1<vec>(Scalar(1)/A_arr[idA<isARowMajor>(-currentM,-currentM,LDA)]);
+ }
+ else {
+        // Broadcast the next off-diagonal element of A.
+ EIGEN_IF_CONSTEXPR(isFWDSolve)
+ AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM,currentM,LDA)]);
+ else
+ AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM,-currentM,LDA)]);
+ }
+ }
+
+ aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(A_arr, LDA, RHSInPacket, AInPacket);
+ }
+
+ template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ EIGEN_UNUSED_VARIABLE(A_arr);
+ EIGEN_UNUSED_VARIABLE(LDA);
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
+ EIGEN_UNUSED_VARIABLE(AInPacket);
+ }
+
+ /**
+ * aux_triSolverMicroKernel
+ *
+ * 1-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ **/
+ template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+
+ constexpr int64_t counterReverse = endM-counter;
+ constexpr int64_t startM = counterReverse;
+
+ constexpr int64_t currentM = startM;
+    // Divides the right-hand sides in row startM by the diagonal value of A
+    // broadcast to AInPacket.packet[startM-1] in the previous iteration.
+ //
+    // Without "if constexpr" the compiler still instantiates the case <-1, numK>;
+    // this is handled with enable_if to prevent out-of-bounds warnings
+    // from the compiler.
+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
+ trsm::template divRHSByDiag<startM-1, numK>(RHSInPacket, AInPacket);
+
+ // After division, the rhs corresponding to subsequent rows of A can be partially updated
+ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
+ // to be used in the next iteration.
+ trsm::template
+ updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+
+ // Handle division for the RHS corresponding to the final row of A.
+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM-1)
+ trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
+
+ aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket, AInPacket);
+ }
+
+ template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket)
+ {
+ EIGEN_UNUSED_VARIABLE(A_arr);
+ EIGEN_UNUSED_VARIABLE(LDA);
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
+ EIGEN_UNUSED_VARIABLE(AInPacket);
+ }
+
+ /********************************************************
+ * Wrappers for aux_XXXX to hide counter parameter
+ ********************************************************/
+
+ /**
+   * Load the endM x endK block of B into RHSInPacket.
+   * Masked loads are used when the number of right-hand sides is not a multiple of PacketSize.
+ */
+ template<bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
+ static EIGEN_ALWAYS_INLINE
+ void loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
+ aux_loadRHS<isFWDSolve, endM, endK, endM*endK, krem>(B_arr, LDB, RHSInPacket, rem);
+ }
+
+ /**
+   * Store the endM x endK block from RHSInPacket back to B.
+   * Masked stores are used when the number of right-hand sides is not a multiple of PacketSize.
+ */
+ template<bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
+ static EIGEN_ALWAYS_INLINE
+ void storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
+ aux_storeRHS<isFWDSolve, endM, endK, endM*endK, krem>(B_arr, LDB, RHSInPacket, rem);
+ }
+
+ /**
+   * Only used if the triangular matrix has non-unit diagonal values.
+ */
+ template<int64_t currM, int64_t endK>
+ static EIGEN_ALWAYS_INLINE
+ void divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
+ }
+
+ /**
+ * Update right-hand sides (stored in avx registers)
+ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
+ **/
+ template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK, int64_t currentM>
+ static EIGEN_ALWAYS_INLINE
+ void updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM-startM)*endK, currentM>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ }
+
+ /**
+ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
+   * numK: number of avx registers to use for each row of B (e.g. fp32: 48 rhs => 3 avx registers used). 1 <= numK <= 3.
+ * isFWDSolve: true => forward substitution, false => backwards substitution
+ * isUnitDiag: true => triangular matrix has unit diagonal.
+ */
+ template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
+ static EIGEN_ALWAYS_INLINE
+ void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
+ static_assert( numK >= 1 && numK <= 3, "numK out of range" );
+ aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(
+ A_arr, LDA, RHSInPacket, AInPacket);
+ }
+};
+
+/**
+ * Unrolls for gemm kernel
+ *
+ * isAdd: true => C += A*B, false => C -= A*B
+ */
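+/**
+ * Register-budget note (an observation, not a constraint enforced here): the accumulators occupy
+ * endM*endN registers, and the kernel additionally keeps numLoad packets of B and numBCast
+ * broadcasts of A live. For example, the 8x48 fp32 unroll uses 24 + 6 + 2 = 32 zmm registers,
+ * which matches the 32 vector registers available with AVX512.
+ */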
+template <typename Scalar, bool isAdd>
+class gemm {
+public:
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
+
+ /***********************************
+   * Auxiliary Functions for:
+ * - setzero
+ * - updateC
+ * - storeC
+ * - startLoadB
+   *  - startBCastA
+   *  - loadB
+   *  - microKernel
+ ************************************/
+
+ /**
+ * aux_setzero
+ *
+ * 2-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startN = 0; startN < endN; startN++)
+ **/
+ template<int64_t endM, int64_t endN, int64_t counter>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
+ constexpr int64_t counterReverse = endM*endN-counter;
+ constexpr int64_t startM = counterReverse/(endN);
+ constexpr int64_t startN = counterReverse%endN;
+
+ zmm.packet[startN*endM + startM] = pzero(zmm.packet[startN*endM + startM]);
+ aux_setzero<endM, endN, counter-1>(zmm);
+ }
+
+ template<int64_t endM, int64_t endN, int64_t counter>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm)
+ {
+ EIGEN_UNUSED_VARIABLE(zmm);
+ }
+
+ /**
+ * aux_updateC
+ *
+ * 2-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startN = 0; startN < endN; startN++)
+ **/
+ template<int64_t endM, int64_t endN, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
+ EIGEN_UNUSED_VARIABLE(rem_);
+ constexpr int64_t counterReverse = endM*endN-counter;
+ constexpr int64_t startM = counterReverse/(endN);
+ constexpr int64_t startN = counterReverse%endN;
+
+ EIGEN_IF_CONSTEXPR(rem)
+ zmm.packet[startN*endM + startM] =
+ padd(ploadu<vec>(&C_arr[(startN) * LDC + startM*PacketSize], remMask<PacketSize>(rem_)),
+ zmm.packet[startN*endM + startM],
+ remMask<PacketSize>(rem_));
+ else
+ zmm.packet[startN*endM + startM] =
+ padd(ploadu<vec>(&C_arr[(startN) * LDC + startM*PacketSize]), zmm.packet[startN*endM + startM]);
+ aux_updateC<endM, endN, counter-1, rem>(C_arr, LDC, zmm, rem_);
+ }
+
+ template<int64_t endM, int64_t endN, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(C_arr);
+ EIGEN_UNUSED_VARIABLE(LDC);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /**
+ * aux_storeC
+ *
+ * 2-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startN = 0; startN < endN; startN++)
+ **/
+ template<int64_t endM, int64_t endN, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
+ EIGEN_UNUSED_VARIABLE(rem_);
+ constexpr int64_t counterReverse = endM*endN-counter;
+ constexpr int64_t startM = counterReverse/(endN);
+ constexpr int64_t startN = counterReverse%endN;
+
+ EIGEN_IF_CONSTEXPR(rem)
+ pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM], remMask<PacketSize>(rem_));
+ else
+ pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM]);
+ aux_storeC<endM, endN, counter-1, rem>(C_arr, LDC, zmm, rem_);
+ }
+
+ template<int64_t endM, int64_t endN, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(C_arr);
+ EIGEN_UNUSED_VARIABLE(LDC);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /**
+ * aux_startLoadB
+ *
+ * 1-D unroll
+ * for(startL = 0; startL < endL; startL++)
+ **/
+ template<int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
+ EIGEN_UNUSED_VARIABLE(rem_);
+ constexpr int64_t counterReverse = endL-counter;
+ constexpr int64_t startL = counterReverse;
+
+ EIGEN_IF_CONSTEXPR(rem)
+ zmm.packet[unrollM*unrollN+startL] =
+ ploadu<vec>(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize], remMask<PacketSize>(rem_));
+ else
+ zmm.packet[unrollM*unrollN+startL] = ploadu<vec>(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize]);
+
+ aux_startLoadB<unrollM, unrollN, endL, counter-1, rem>(B_t, LDB, zmm, rem_);
+ }
+
+ template<int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_startLoadB(
+ Scalar *B_t, int64_t LDB,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_t);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /**
+ * aux_startBCastA
+ *
+ * 1-D unroll
+ * for(startB = 0; startB < endB; startB++)
+ **/
+ template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
+ constexpr int64_t counterReverse = endB-counter;
+ constexpr int64_t startB = counterReverse;
+
+ zmm.packet[unrollM*unrollN+numLoad+startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0,LDA)]);
+
+ aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter-1, numLoad>(A_t, LDA, zmm);
+ }
+
+ template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm)
+ {
+ EIGEN_UNUSED_VARIABLE(A_t);
+ EIGEN_UNUSED_VARIABLE(LDA);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ }
+
+ /**
+ * aux_loadB
+ * currK: current K
+ *
+ * 1-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ **/
+ template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_loadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
+ EIGEN_UNUSED_VARIABLE(rem_);
+ if ((numLoad/endM + currK < unrollK)) {
+ constexpr int64_t counterReverse = endM-counter;
+ constexpr int64_t startM = counterReverse;
+
+ EIGEN_IF_CONSTEXPR(rem) {
+ zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] =
+ ploadu<vec>(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize], remMask<PacketSize>(rem_));
+ }
+ else {
+ zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] =
+ ploadu<vec>(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize]);
+ }
+
+ aux_loadB<endM, counter-1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
+ }
+ }
+
+ template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_loadB(
+ Scalar *B_t, int64_t LDB,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_t);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /**
+ * aux_microKernel
+ *
+ * 3-D unroll
+ * for(startM = 0; startM < endM; startM++)
+ * for(startN = 0; startN < endN; startN++)
+ * for(startK = 0; startK < endK; startK++)
+ **/
+ template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
+ aux_microKernel(
+ Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
+ EIGEN_UNUSED_VARIABLE(rem_);
+ constexpr int64_t counterReverse = endM*endN*endK-counter;
+ constexpr int startK = counterReverse/(endM*endN);
+ constexpr int startN = (counterReverse/(endM))%endN;
+ constexpr int startM = counterReverse%endM;
+
+ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
+ gemm:: template
+ startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
+ gemm:: template
+ startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
+ }
+
+ {
+ // Interleave FMA and Bcast
+ EIGEN_IF_CONSTEXPR(isAdd) {
+ zmm.packet[startN*endM + startM] =
+ pmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
+ zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
+ }
+ else {
+ zmm.packet[startN*endM + startM] =
+ pnmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
+ zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
+ }
+ // Bcast
+ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK*endN < endK*endN)) {
+ zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast] =
+ pload1<vec>(&A_t[idA<isARowMajor>((numBCast + startN + startK*endN)%endN,
+ (numBCast + startN + startK*endN)/endN, LDA)]);
+ }
+ }
+
+    // We have updated all accumulators; time to load the next set of B's.
+ EIGEN_IF_CONSTEXPR( (startN == endN - 1) && (startM == endM - 1) ) {
+ gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
+ }
+ aux_microKernel<isARowMajor, endM, endN, endK, counter-1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
+
+ }
+
+ template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem>
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
+ aux_microKernel(
+ Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
+ {
+ EIGEN_UNUSED_VARIABLE(B_t);
+ EIGEN_UNUSED_VARIABLE(A_t);
+ EIGEN_UNUSED_VARIABLE(LDB);
+ EIGEN_UNUSED_VARIABLE(LDA);
+ EIGEN_UNUSED_VARIABLE(zmm);
+ EIGEN_UNUSED_VARIABLE(rem_);
+ }
+
+ /********************************************************
+ * Wrappers for aux_XXXX to hide counter parameter
+ ********************************************************/
+
+ template<int64_t endM, int64_t endN>
+ static EIGEN_ALWAYS_INLINE
+ void setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){
+ aux_setzero<endM, endN, endM*endN>(zmm);
+ }
+
+ /**
+ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
+ */
+ template<int64_t endM, int64_t endN, bool rem = false>
+ static EIGEN_ALWAYS_INLINE
+ void updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
+ EIGEN_UNUSED_VARIABLE(rem_);
+ aux_updateC<endM, endN, endM*endN, rem>(C_arr, LDC, zmm, rem_);
+ }
+
+ template<int64_t endM, int64_t endN, bool rem = false>
+ static EIGEN_ALWAYS_INLINE
+ void storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
+ EIGEN_UNUSED_VARIABLE(rem_);
+ aux_storeC<endM, endN, endM*endN, rem>(C_arr, LDC, zmm, rem_);
+ }
+
+ /**
+ * Use numLoad registers for loading B at start of microKernel
+ */
+ template<int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
+ static EIGEN_ALWAYS_INLINE
+ void startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
+ EIGEN_UNUSED_VARIABLE(rem_);
+ aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
+ }
+
+ /**
+ * Use numBCast registers for broadcasting A at start of microKernel
+ */
+ template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
+ static EIGEN_ALWAYS_INLINE
+ void startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){
+ aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
+ }
+
+ /**
+ * Loads next set of B into vector registers between each K unroll.
+ */
+ template<int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
+ static EIGEN_ALWAYS_INLINE
+ void loadB(
+ Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
+ EIGEN_UNUSED_VARIABLE(rem_);
+ aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
+ }
+
+ /**
+   * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute
+   * C += A*B (isAdd == true) or C -= A*B (isAdd == false).
+ * A matrix can be row/col-major. B matrix is assumed row-major.
+ *
+ * isARowMajor: is A row major
+   * endM: Number of registers per row
+ * endN: Number of rows
+ * endK: Loop unroll for K.
+ * numLoad: Number of registers for loading B.
+ * numBCast: Number of registers for broadcasting A.
+ *
+   * Ex: microKernel<isARowMajor,3,8,4,6,2>: 8x48 unroll (24 accumulators), k unrolled 4 times,
+   * 6 registers for loading B, 2 for broadcasting A.
+ *
+ * Note: Ideally the microkernel should not have any register spilling.
+ * The avx instruction counts should be:
+ * - endK*endN vbroadcasts{s,d}
+ * - endK*endM vmovup{s,d}
+ * - endK*endN*endM FMAs
+ *
+   * From testing, there are no register spills with clang. There are register spills with GCC, which
+ * causes a performance hit.
+ */
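+  /**
+   * Concrete counts for the example above (derived from the formulas, assuming no spills):
+   * the 8x48 fp32 unroll (endM=3, endN=8, endK=4) should emit 4*8 = 32 broadcasts,
+   * 4*3 = 12 loads of B and 4*8*3 = 96 FMAs per call.
+   */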
+ template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast, bool rem = false>
+ static EIGEN_ALWAYS_INLINE
+ void microKernel(
+ Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
+ PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
+ EIGEN_UNUSED_VARIABLE(rem_);
+ aux_microKernel<isARowMajor,endM, endN, endK, endM*endN*endK, numLoad, numBCast, rem>(
+ B_t, A_t, LDB, LDA, zmm, rem_);
+ }
+
+};
+} // namespace unrolls
+
+
+#endif //EIGEN_UNROLLS_IMPL_H
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index 2f299e2..8baced1 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -32,6 +32,26 @@
return _mm512_castsi512_ps(a);
}
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
+ return _mm512_castps_pd(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
+ return _mm512_castpd_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
+ return _mm512_castps512_ps256(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
+ return a;
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
+ return a;
+}
+
template <>
struct type_casting_traits<half, float> {
enum {
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index 4fd923e..ba5a3fd 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -114,11 +114,19 @@
template<> EIGEN_STRONG_INLINE Packet2cf pset1<Packet2cf>(const std::complex<float>& from)
{
Packet2cf res;
+#ifdef __VSX__
+ // Load a single std::complex<float> from memory and duplicate
+ //
+ // Using pload would read past the end of the reference in this case
+ // Using vec_xl_len + vec_splat generates poor assembly
+ __asm__ ("lxvdsx %x0,%y1" : "=wa" (res.v) : "Z" (from));
+#else
if((std::ptrdiff_t(&from) % 16) == 0)
res.v = pload<Packet4f>((const float *)&from);
else
res.v = ploadu<Packet4f>((const float *)&from);
res.v = vec_perm(res.v, res.v, p16uc_PSET64_HI);
+#endif
return res;
}
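Illustration (not part of the patch): what the lxvdsx path produces, expressed in plain scalar C++. The 64-bit (real, imag) pair is loaded once and splatted into both halves of the 128-bit register:

#include <complex>
#include <cstdio>
#include <cstring>

int main() {
  std::complex<float> from(1.5f, -2.0f);
  float pair[2];
  std::memcpy(pair, &from, sizeof(pair));                // the 64-bit doubleword
  float v[4] = { pair[0], pair[1], pair[0], pair[1] };   // splat into both halves
  std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);  // 1.5 -2 1.5 -2
  return 0;
}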
@@ -133,6 +141,7 @@
{
Packet4f res0, res1;
#ifdef __VSX__
+ // Load two std::complex<float> from memory and combine
__asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0));
__asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1));
#ifdef _BIG_ENDIAN
@@ -186,7 +195,7 @@
template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a)
{
Packet4f rev_a;
- rev_a = vec_perm(a.v, a.v, p16uc_COMPLEX32_REV2);
+ rev_a = vec_sld(a.v, a.v, 8);
return Packet2cf(rev_a);
}
@@ -222,8 +231,8 @@
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
{
- Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI);
- kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
+ Packet4f tmp = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2d>(kernel.packet[0].v), reinterpret_cast<Packet2d>(kernel.packet[1].v)));
+ kernel.packet[1].v = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(kernel.packet[0].v), reinterpret_cast<Packet2d>(kernel.packet[1].v)));
kernel.packet[0].v = tmp;
}
@@ -358,7 +367,7 @@
template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet1cd>(const Packet1cd& a)
{
- EIGEN_ALIGN16 std::complex<double> res[2];
+ EIGEN_ALIGN16 std::complex<double> res[1];
pstore<std::complex<double> >(res, a);
return res[0];
@@ -384,8 +393,8 @@
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet1cd,2>& kernel)
{
- Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI);
- kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO);
+ Packet2d tmp = vec_mergeh(kernel.packet[0].v, kernel.packet[1].v);
+ kernel.packet[1].v = vec_mergel(kernel.packet[0].v, kernel.packet[1].v);
kernel.packet[0].v = tmp;
}
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 2919dda..e24b5d5 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -17,24 +17,35 @@
#include "MatrixProductCommon.h"
-// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX
-#if EIGEN_COMP_LLVM
-#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
-#ifdef __MMA__
-#define EIGEN_ALTIVEC_MMA_ONLY
-#else
-#define EIGEN_ALTIVEC_DISABLE_MMA
-#endif
-#endif
+#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif
-#ifdef __has_builtin
+// Check for MMA builtin support.
+#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
#if __has_builtin(__builtin_mma_assemble_acc)
- #define ALTIVEC_MMA_SUPPORT
+ #define EIGEN_ALTIVEC_MMA_SUPPORT
#endif
#endif
-#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+// Check if and how we should actually use MMA if supported.
+#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)
+
+#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
+#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
+#endif
+
+// Check if we want to enable dynamic dispatch. Not supported by LLVM.
+#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
+#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
+// Otherwise, use MMA by default if available.
+#elif defined(__MMA__)
+#define EIGEN_ALTIVEC_MMA_ONLY 1
+#endif
+
+#endif // EIGEN_ALTIVEC_MMA_SUPPORT
+
+#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#include "MatrixProductMMA.h"
#endif
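Illustration (not part of the patch): the new flag handling above can be read as a small decision tree. A sketch with plain booleans standing in for the preprocessor symbols (the function and parameter names are made up for the example):

#include <cstdio>

enum class MmaMode { None, DynamicDispatch, MmaOnly };

// Inputs stand in for EIGEN_ALTIVEC_DISABLE_MMA, __has_builtin(__builtin_mma_assemble_acc),
// EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH, EIGEN_COMP_LLVM and __MMA__.
MmaMode resolveMma(bool disableMma, bool hasMmaBuiltin, bool wantDynamicDispatch,
                   bool isLlvm, bool builtWithMma) {
  if (disableMma || !hasMmaBuiltin) return MmaMode::None;
  if (wantDynamicDispatch && !isLlvm) return MmaMode::DynamicDispatch;  // LLVM: no dynamic dispatch
  if (builtWithMma) return MmaMode::MmaOnly;                            // default when __MMA__ is set
  return MmaMode::None;
}

int main() {
  // e.g. clang with MMA enabled at compile time: builtin present, __MMA__ set -> MMA only
  std::printf("%d\n", static_cast<int>(resolveMma(false, true, false, true, true)));
  return 0;
}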
@@ -1009,26 +1020,12 @@
{
if(NegativeAccumulate)
{
- acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
- if (N > 1) {
- acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
- }
- if (N > 2) {
- acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
- }
- if (N > 3) {
- acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ for (int M = 0; M < N; M++) {
+ acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
}
} else {
- acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
- if (N > 1) {
- acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
- }
- if (N > 2) {
- acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
- }
- if (N > 3) {
- acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ for (int M = 0; M < N; M++) {
+ acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
}
}
}
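Illustration (not part of the patch): the hand-unrolled if (N > 1)/if (N > 2) ladders are replaced throughout this file by plain loops with a compile-time trip count, which the compiler unrolls just as well. A generic sketch of the pattern:

#include <cstdio>

template <int N>
void scale_accumulate(float* acc, const float* accZ, float alpha) {
  for (int M = 0; M < N; M++) {   // N is a compile-time constant, so this unrolls
    acc[M] += alpha * accZ[M];
  }
}

int main() {
  float acc[4] = {1, 1, 1, 1}, accZ[4] = {1, 2, 3, 4};
  scale_accumulate<4>(acc, accZ, 2.0f);
  std::printf("%g %g %g %g\n", acc[0], acc[1], acc[2], acc[3]);  // 3 5 7 9
  return 0;
}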
@@ -1041,31 +1038,9 @@
pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
-template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
-{
-#ifdef _ARCH_PWR9
- lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
-#else
- Index i = 0;
- do {
- lhsV[i] = lhs[i];
- } while (++i < remaining_rows);
-#endif
-}
-
-template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
-EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
-{
- Packet lhsV;
- loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);
-
- pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
-}
-
// 512-bit rank-1 update of complex acc. It takes decoupled accumulators as entries. It also takes care of mixed types real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
+EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
pger_common<Packet, false, N>(accReal, lhsV, rhsV);
if(LhsIsReal)
@@ -1086,97 +1061,56 @@
template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
- Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
+ Packet lhsV = ploadLhs<Packet>(lhs_ptr);
Packet lhsVi;
- if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
+ if(!LhsIsReal) lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
-template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
-{
-#ifdef _ARCH_PWR9
- lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
- if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
- else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
-#else
- Index i = 0;
- do {
- lhsV[i] = lhs_ptr[i];
- if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
- } while (++i < remaining_rows);
- if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
-#endif
-}
-
-template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
-EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
-{
- Packet lhsV, lhsVi;
- loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);
-
- pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
-}
-
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs)
{
return ploadu<Packet>(lhs);
}
// Zero the accumulator on PacketBlock.
-template<typename Scalar, typename Packet, int N>
+template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
- acc.packet[0] = pset1<Packet>((Scalar)0);
- if (N > 1) {
- acc.packet[1] = pset1<Packet>((Scalar)0);
- }
- if (N > 2) {
- acc.packet[2] = pset1<Packet>((Scalar)0);
- }
- if (N > 3) {
- acc.packet[3] = pset1<Packet>((Scalar)0);
- }
-}
-
-// Scale the PacketBlock vectors by alpha.
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
-{
- acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
- if (N > 1) {
- acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
- }
- if (N > 2) {
- acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
- }
- if (N > 3) {
- acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
}
}
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
- acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
- if (N > 1) {
- acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
}
- if (N > 2) {
- acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
- }
- if (N > 3) {
- acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
+}
+
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
+{
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
}
}
// Complex version of PacketBlock scaling.
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
+template<typename Packet, int N, bool mask>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
+ if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) {
+ band<Packet, N>(aReal, pMask);
+ band<Packet, N>(aImag, pMask);
+ } else {
+ EIGEN_UNUSED_VARIABLE(pMask);
+ }
+
bscalec_common<Packet, N>(cReal, aReal, bReal);
bscalec_common<Packet, N>(cImag, aImag, bReal);
@@ -1186,213 +1120,253 @@
pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
-{
- acc.packet[0] = pand(acc.packet[0], pMask);
- if (N > 1) {
- acc.packet[1] = pand(acc.packet[1], pMask);
- }
- if (N > 2) {
- acc.packet[2] = pand(acc.packet[2], pMask);
- }
- if (N > 3) {
- acc.packet[3] = pand(acc.packet[3], pMask);
- }
-}
-
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
-{
- band<Packet, N>(aReal, pMask);
- band<Packet, N>(aImag, pMask);
-
- bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
-}
-
// Load a PacketBlock, the N parameters make tuning gemm easier so we can add more accumulators as needed.
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
+//
+// full = operate (load) on the entire PacketBlock or only half
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
if (StorageOrder == RowMajor) {
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
- if (N > 1) {
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
- }
- if (N > 2) {
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
- }
- if (N > 3) {
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
}
if (Complex) {
- acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
- if (N > 1) {
- acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
- }
- if (N > 2) {
- acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
- }
- if (N > 3) {
- acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M+N] = res.template loadPacket<Packet>(row + M, col + accCols);
}
}
} else {
- acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
- if (N > 1) {
- acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
}
- if (N > 2) {
- acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
- }
- if (N > 3) {
- acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
- }
- if (Complex) {
- acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
- if (N > 1) {
- acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
- }
- if (N > 2) {
- acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
- }
- if (N > 3) {
- acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
+ if (Complex && full) {
+ for (int M = 0; M < N; M++) {
+ acc.packet[M+N] = res.template loadPacket<Packet>(row + accCols, col + M);
}
}
}
}
-const static Packet4i mask41 = { -1, 0, 0, 0 };
-const static Packet4i mask42 = { -1, -1, 0, 0 };
-const static Packet4i mask43 = { -1, -1, -1, 0 };
+template<typename DataMapper, typename Packet, typename Index, int N>
+EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row)
+{
+ for (int M = 0; M < N; M++) {
+ res.template storePacket<Packet>(row, M, acc.packet[M]);
+ }
+}
-const static Packet2l mask21 = { -1, 0 };
+#ifdef _ARCH_PWR10
+#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
+#else
+#define USE_P10_AND_PVIPR2_0 0
+#endif
+
+#if !USE_P10_AND_PVIPR2_0
+const static Packet4i mask4[4] = { { 0, 0, 0, 0 }, { -1, 0, 0, 0 }, { -1, -1, 0, 0 }, { -1, -1, -1, 0 } };
+#endif
template<typename Packet, typename Index>
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
{
- if (remaining_rows == 0) {
- return pset1<Packet>(float(0.0)); // Not used
- } else {
- switch (remaining_rows) {
- case 1: return Packet(mask41);
- case 2: return Packet(mask42);
- default: return Packet(mask43);
- }
- }
+#if USE_P10_AND_PVIPR2_0
+#ifdef _BIG_ENDIAN
+ return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
+#else
+ return Packet(vec_genwm((1 << remaining_rows) - 1));
+#endif
+#else
+ return Packet(mask4[remaining_rows]);
+#endif
}
template<>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d,Index>(const Index remaining_rows)
{
- if (remaining_rows == 0) {
- return pset1<Packet2d>(double(0.0)); // Not used
- } else {
- return Packet2d(mask21);
- }
+#if USE_P10_AND_PVIPR2_0
+ Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
+#ifdef _BIG_ENDIAN
+ return preverse(mask2);
+#else
+ return mask2;
+#endif
+#else
+ Packet2l ret = { -remaining_rows, 0 };
+ return Packet2d(ret);
+#endif
}
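Illustration (not part of the patch): both bmask variants keep the first remaining_rows lanes. On Power10 the lane set comes from vec_genwm((1 << remaining_rows) - 1); otherwise it is read from the mask4 table. A scalar sketch showing the two encodings describe the same lanes:

#include <cstdio>

int main() {
  for (int r = 0; r <= 3; r++) {
    unsigned bits = (1u << r) - 1u;   // argument passed to vec_genwm on Power10
    int lanes[4];                     // equivalent entry of the mask4 table
    for (int i = 0; i < 4; i++) lanes[i] = ((bits >> i) & 1u) ? -1 : 0;
    std::printf("remaining_rows=%d -> { %d, %d, %d, %d }\n",
                r, lanes[0], lanes[1], lanes[2], lanes[3]);
  }
  return 0;
}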
-template<typename Packet, int N>
+// Scale the PacketBlock vectors by alpha.
+template<typename Packet, int N, bool mask>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
- band<Packet, N>(accZ, pMask);
+ if (mask) {
+ band<Packet, N>(accZ, pMask);
+ } else {
+ EIGEN_UNUSED_VARIABLE(pMask);
+ }
- bscale<Packet, N>(acc, accZ, pAlpha);
+ for (int M = 0; M < N; M++) {
+ acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
+ }
}
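Illustration (not part of the patch): with the new mask template parameter, bscale computes acc = pAlpha * (accZ & pMask) + acc, so lanes beyond remaining_rows keep their previous values. A scalar sketch of the masked update:

#include <cstdio>
#include <cstring>
#include <cstdint>

int main() {
  float acc[4]  = {10, 20, 30, 40};   // existing values of C
  float accZ[4] = {1, 2, 3, 4};       // fresh partial sums
  int32_t mask[4] = {-1, -1, 0, 0};   // bmask(remaining_rows = 2)
  float alpha = 2.0f;

  for (int i = 0; i < 4; i++) {
    int32_t bits;
    std::memcpy(&bits, &accZ[i], sizeof(bits));
    bits &= mask[i];                  // band(): zero lanes past remaining_rows
    float z;
    std::memcpy(&z, &bits, sizeof(z));
    acc[i] = alpha * z + acc[i];      // pmadd(pAlpha, accZ, acc)
  }
  std::printf("%g %g %g %g\n", acc[0], acc[1], acc[2], acc[3]);  // 12 24 30 40
  return 0;
}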
-template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
-pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
- Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+template<typename Packet, int N, bool real>
+EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) *ap0,
+ const __UNPACK_TYPE__(Packet) *ap1, const __UNPACK_TYPE__(Packet) *ap2,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
- a0 = pset1<Packet>(a[0]);
- if (N > 1) {
- a1 = pset1<Packet>(a[1]);
+ a0 = pset1<Packet>(ap0[0]);
+ if (N == 4) {
+ a1 = pset1<Packet>(ap0[1]);
+ a2 = pset1<Packet>(ap0[2]);
+ a3 = pset1<Packet>(ap0[3]);
+ EIGEN_UNUSED_VARIABLE(ap1);
+ EIGEN_UNUSED_VARIABLE(ap2);
} else {
- EIGEN_UNUSED_VARIABLE(a1);
+ if (N > 1) {
+ a1 = pset1<Packet>(ap1[0]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a1);
+ EIGEN_UNUSED_VARIABLE(ap1);
+ }
+ if (N > 2) {
+ a2 = pset1<Packet>(ap2[0]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a2);
+ EIGEN_UNUSED_VARIABLE(ap2);
+ }
}
- if (N > 2) {
- a2 = pset1<Packet>(a[2]);
- } else {
- EIGEN_UNUSED_VARIABLE(a2);
- }
- if (N > 3) {
- a3 = pset1<Packet>(a[3]);
- } else {
- EIGEN_UNUSED_VARIABLE(a3);
- }
+}
+
+template<> EIGEN_ALWAYS_INLINE void
+pbroadcastN<Packet4f,4,true>(const float *ap0, const float *, const float *,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ pbroadcast4<Packet4f>(ap0, a0, a1, a2, a3);
+}
+
+template<> EIGEN_ALWAYS_INLINE void
+pbroadcastN<Packet4f,4,false>(const float *ap0, const float *ap1, const float *ap2,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+ pbroadcastN<Packet4f,4,true>(ap0, ap1, ap2, a0, a1, a2, a3);
}
template<>
-EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d,4,false>(const double* ap0, const double *,
+ const double *, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
- pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
-}
-
-template<>
-EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
-{
- a1 = pload<Packet2d>(a);
- a3 = pload<Packet2d>(a + 2);
+ a1 = pload<Packet2d>(ap0);
+ a3 = pload<Packet2d>(ap0 + 2);
a0 = vec_splat(a1, 0);
a1 = vec_splat(a1, 1);
a2 = vec_splat(a3, 0);
a3 = vec_splat(a3, 1);
}
-template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
-pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
- Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+// Grab two decoupled real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
+template<typename Packet, typename Packetc, int N, bool full>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
- a0 = pset1<Packet>(a[0]);
- if (N > 1) {
- a1 = pset1<Packet>(a[1]);
- } else {
- EIGEN_UNUSED_VARIABLE(a1);
+ for (int M = 0; M < N; M++) {
+ acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
}
- if (N > 2) {
- a2 = pset1<Packet>(a[2]);
- } else {
- EIGEN_UNUSED_VARIABLE(a2);
- }
- if (N > 3) {
- a3 = pset1<Packet>(a[3]);
- } else {
- EIGEN_UNUSED_VARIABLE(a3);
+
+ if (full) {
+ for (int M = 0; M < N; M++) {
+ acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
+ }
}
}
-template<> EIGEN_ALWAYS_INLINE void
-pbroadcastN<Packet4f,4>(const float *a,
- Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+template<typename Packet, typename Packetc, int N, bool full>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
- a3 = pload<Packet4f>(a);
- a0 = vec_splat(a3, 0);
- a1 = vec_splat(a3, 1);
- a2 = vec_splat(a3, 2);
- a3 = vec_splat(a3, 3);
+ bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);
+
+ for (int M = 0; M < N; M++) {
+ acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
+ }
+
+ if (full) {
+ for (int M = 0; M < N; M++) {
+ acc2.packet[M] = padd<Packetc>(tRes.packet[M+N], acc2.packet[M]);
+ }
+ }
}
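Illustration (not part of the patch): bcouple_common re-interleaves the decoupled real/imaginary accumulators back into complex lanes. vec_mergeh produces (r0, i0, r1, i1) and vec_mergel produces (r2, i2, r3, i3). A scalar sketch:

#include <cstdio>

int main() {
  float real[4] = {1, 2, 3, 4};       // taccReal lanes
  float imag[4] = {10, 20, 30, 40};   // taccImag lanes
  float acc1[4], acc2[4];
  for (int k = 0; k < 2; k++) {
    acc1[2*k]     = real[k];          // vec_mergeh(real, imag)
    acc1[2*k + 1] = imag[k];
    acc2[2*k]     = real[k + 2];      // vec_mergel(real, imag)
    acc2[2*k + 1] = imag[k + 2];
  }
  std::printf("acc1: %g %g %g %g\n", acc1[0], acc1[1], acc1[2], acc1[3]);  // 1 10 2 20
  std::printf("acc2: %g %g %g %g\n", acc2[0], acc2[1], acc2[2], acc2[3]);  // 3 30 4 40
  return 0;
}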
// PEEL loop factor.
#define PEEL 7
#define PEEL_ROW 7
-#define MICRO_UNROLL_PEEL(func) \
+#define MICRO_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+#define MICRO_NORMAL_ROWS \
+ accRows == quad_traits<Scalar>::rows || accRows == 1
+
+#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)
+
+#define MICRO_RHS(ptr, N) rhs_##ptr##N
+
#define MICRO_ZERO_PEEL(peel) \
if ((PEEL_ROW > peel) && (peel != 0)) { \
- bsetzero<Scalar, Packet, accRows>(accZero##peel); \
+ bsetzero<Packet, accRows>(accZero##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##peel); \
}
-#define MICRO_ZERO_PEEL_ROW \
- MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);
+#define MICRO_ADD(ptr, N) \
+ if (MICRO_NORMAL_ROWS) { \
+ MICRO_RHS(ptr,0) += (accRows * N); \
+ } else { \
+ MICRO_RHS(ptr,0) += N; \
+ MICRO_RHS(ptr,1) += N; \
+ if (accRows == 3) { \
+ MICRO_RHS(ptr,2) += N; \
+ } \
+ }
+
+#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
+
+#define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
+ if (MICRO_NORMAL_ROWS) { \
+ pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0) + (accRows * peel), MICRO_RHS(ptr,0), MICRO_RHS(ptr,0), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ } else { \
+ pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0) + peel, MICRO_RHS(ptr,1) + peel, MICRO_RHS(ptr,2) + peel, rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ }
+
+#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)
+
+#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \
+ pbroadcastN<Packet,accRows,real>(MICRO_RHS(ptr,0), MICRO_RHS(ptr,1), MICRO_RHS(ptr,2), rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+
+#define MICRO_BROADCAST_EXTRA \
+ Packet rhsV[4]; \
+ MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
+ MICRO_ADD_ROWS(1)
+
+#define MICRO_SRC2(ptr, N, M) \
+ if (MICRO_NORMAL_ROWS) { \
+ EIGEN_UNUSED_VARIABLE(strideB); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,1)); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \
+ } else { \
+ MICRO_RHS(ptr,1) = rhs_base + N + M; \
+ if (accRows == 3) { \
+ MICRO_RHS(ptr,2) = rhs_base + N*2 + M; \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \
+ } \
+ }
+
+#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)
+
+#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)
#define MICRO_WORK_PEEL(peel) \
if (PEEL_ROW > peel) { \
- pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ MICRO_BROADCAST(peel) \
pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
@@ -1400,9 +1374,9 @@
#define MICRO_WORK_PEEL_ROW \
Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
- MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
+ MICRO_UNROLL(MICRO_WORK_PEEL) \
lhs_ptr += (remaining_rows * PEEL_ROW); \
- rhs_ptr += (accRows * PEEL_ROW);
+ MICRO_ADD_ROWS(PEEL_ROW)
#define MICRO_ADD_PEEL(peel, sum) \
if (PEEL_ROW > peel) { \
@@ -1415,17 +1389,34 @@
MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
+#define MICRO_PREFETCHN1(ptr, N) \
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,0)); \
+ if (N == 2 || N == 3) { \
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,1)); \
+ if (N == 3) { \
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,2)); \
+ } \
+ }
+
+#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)
+
+#define MICRO_COMPLEX_PREFETCHN(N) \
+ MICRO_PREFETCHN1(ptr_real, N); \
+ if(!RhsIsReal) { \
+ MICRO_PREFETCHN1(ptr_imag, N); \
+ }
+
template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
const Scalar* &lhs_ptr,
- const Scalar* &rhs_ptr,
+ const Scalar* &rhs_ptr0,
+ const Scalar* &rhs_ptr1,
+ const Scalar* &rhs_ptr2,
PacketBlock<Packet,accRows> &accZero)
{
- Packet rhsV[4];
- pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ MICRO_BROADCAST_EXTRA
pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;
- rhs_ptr += accRows;
}
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
@@ -1436,60 +1427,71 @@
Index depth,
Index strideA,
Index offsetA,
+ Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
const Packet& pAlpha,
const Packet& pMask)
{
- const Scalar* rhs_ptr = rhs_base;
+ const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL;
const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
- bsetzero<Scalar, Packet, accRows>(accZero0);
+ MICRO_SRC2_PTR
+ bsetzero<Packet, accRows>(accZero0);
- Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
+ Index remaining_depth = depth & -quad_traits<Scalar>::rows;
Index k = 0;
if (remaining_depth >= PEEL_ROW) {
MICRO_ZERO_PEEL_ROW
do
{
- EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCHN(accRows)
EIGEN_POWER_PREFETCH(lhs_ptr);
MICRO_WORK_PEEL_ROW
} while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
MICRO_ADD_PEEL_ROW
}
- for(; k < remaining_depth; k++)
+ for(; k < depth; k++)
{
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
}
- if ((remaining_depth == depth) && (rows >= accCols))
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
+ if ((accRows == 1) || (rows >= accCols))
{
- bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
- bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
- res.template storePacketBlock<Packet,accRows>(row, 0, acc);
+ bscale<Packet,accRows,true>(acc, accZero0, pAlpha, pMask);
+ bstore<DataMapper, Packet, Index, accRows>(acc, res, row);
} else {
- for(; k < depth; k++)
- {
- Packet rhsV[4];
- pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
- lhs_ptr += remaining_rows;
- rhs_ptr += accRows;
- }
-
+ bscale<Packet,accRows,false>(acc, accZero0, pAlpha, pMask);
for(Index j = 0; j < accRows; j++) {
- accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
for(Index i = 0; i < remaining_rows; i++) {
- res(row + i, j) += accZero0.packet[j][i];
+ res(row + i, j) = acc.packet[j][i];
}
}
}
}
+#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
+ switch(value) { \
+ default: \
+ MICRO_EXTRA_UNROLL(1) \
+ break; \
+ case 2: \
+ if (is_col || (sizeof(Scalar) == sizeof(float))) { \
+ MICRO_EXTRA_UNROLL(2) \
+ } \
+ break; \
+ case 3: \
+ if (is_col || (sizeof(Scalar) == sizeof(float))) { \
+ MICRO_EXTRA_UNROLL(3) \
+ } \
+ break; \
+ }
+
+#define MICRO_EXTRA_ROWS(N) \
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
+
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_extra_row(
const DataMapper& res,
@@ -1498,46 +1500,20 @@
Index depth,
Index strideA,
Index offsetA,
+ Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
{
- switch(remaining_rows) {
- case 1:
- gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
- break;
- case 2:
- if (sizeof(Scalar) == sizeof(float)) {
- gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
- }
- break;
- default:
- if (sizeof(Scalar) == sizeof(float)) {
- gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
- }
- break;
- }
+ MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
}
-#define MICRO_UNROLL(func) \
- func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
-
#define MICRO_UNROLL_WORK(func, func2, peel) \
- MICRO_UNROLL(func2); \
- func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
- func(4,peel) func(5,peel) func(6,peel) func(7,peel)
-
-#define MICRO_LOAD_ONE(iter) \
- if (unroll_factor > iter) { \
- lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
- lhs_ptr##iter += accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsV##iter); \
- }
+ MICRO_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
+ func(4,peel) func(5,peel) func(6,peel) func(7,peel)
#define MICRO_WORK_ONE(iter, peel) \
if (unroll_factor > iter) { \
@@ -1547,7 +1523,7 @@
#define MICRO_TYPE_PEEL4(func, func2, peel) \
if (PEEL > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ MICRO_BROADCAST(peel) \
MICRO_UNROLL_WORK(func, func2, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
@@ -1555,79 +1531,71 @@
#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
- func(func1,func2,0); func(func1,func2,1); \
- func(func1,func2,2); func(func1,func2,3); \
- func(func1,func2,4); func(func1,func2,5); \
- func(func1,func2,6); func(func1,func2,7);
+ func(func1,func2,0) func(func1,func2,1) \
+ func(func1,func2,2) func(func1,func2,3) \
+ func(func1,func2,4) func(func1,func2,5) \
+ func(func1,func2,6) func(func1,func2,7)
#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
Packet rhsV0[M]; \
- func(func1,func2,0);
+ func(func1,func2,0)
-#define MICRO_ONE_PEEL4 \
- MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += (accRows * PEEL);
+#define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \
+ MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
+ MICRO_ADD_ROWS(size)
-#define MICRO_ONE4 \
- MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += accRows;
+#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)
+
+#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)
#define MICRO_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzero<Scalar, Packet, accRows>(accZero##iter); \
+ bsetzero<Packet, accRows>(accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
}
#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
-#define MICRO_SRC_PTR_ONE(iter) \
- if (unroll_factor > iter) { \
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
- }
-
#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
-#define MICRO_PREFETCH_ONE(iter) \
- if (unroll_factor > iter) { \
- EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
- }
-
#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
#define MICRO_STORE_ONE(iter) \
if (unroll_factor > iter) { \
bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
- bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
- res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
+ bscale<Packet,accRows,!(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
+ bstore<DataMapper, Packet, Index, accRows>(acc, res, row + iter*accCols); \
}
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
-template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
+template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2>
+EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
+ Index offsetA,
+ Index strideB,
Index& row,
- const Packet& pAlpha)
+ const Packet& pAlpha,
+ const Packet& pMask)
{
- const Scalar* rhs_ptr = rhs_base;
- const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
+ const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL;
+ const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
PacketBlock<Packet,accRows> acc;
+ MICRO_SRC2_PTR
MICRO_SRC_PTR
MICRO_DST_PTR
Index k = 0;
for(; k + PEEL <= depth; k+= PEEL)
{
- EIGEN_POWER_PREFETCH(rhs_ptr);
+ MICRO_PREFETCHN(accRows)
MICRO_PREFETCH
MICRO_ONE_PEEL4
}
@@ -1637,9 +1605,13 @@
}
MICRO_STORE
- row += unroll_factor*accCols;
+ MICRO_UPDATE
}
+#define MICRO_UNROLL_ITER2(N, M) \
+ gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, Index, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
+ if (M) return;
+
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemm_cols(
const DataMapper& res,
@@ -1652,55 +1624,54 @@
Index offsetB,
Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
{
const DataMapper res3 = res.getSubMapper(0, col);
- const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* rhs_base = blockB + col*strideB + MICRO_NEW_ROWS*offsetB;
const Scalar* lhs_base = blockA + accCols*offsetA;
Index row = 0;
-#define MAX_UNROLL 6
+#define MAX_UNROLL 7
while(row + MAX_UNROLL*accCols <= rows) {
- gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER2(MAX_UNROLL, 0);
}
switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
case 7:
- gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 7)
break;
#endif
#if MAX_UNROLL > 6
case 6:
- gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 6)
break;
#endif
#if MAX_UNROLL > 5
case 5:
- gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 5)
break;
#endif
#if MAX_UNROLL > 4
case 4:
- gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 4)
break;
#endif
#if MAX_UNROLL > 3
case 3:
- gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 3)
break;
#endif
#if MAX_UNROLL > 2
case 2:
- gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 2)
break;
#endif
#if MAX_UNROLL > 1
case 1:
- gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 1)
break;
#endif
default:
@@ -1710,10 +1681,13 @@
if(remaining_rows > 0)
{
- gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
}
}
+#define MICRO_EXTRA_COLS(N) \
+ gemm_cols<Scalar, Packet, DataMapper, Index, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
+
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_cols(
const DataMapper& res,
@@ -1731,9 +1705,7 @@
const Packet& pAlpha,
const Packet& pMask)
{
- for (; col < cols; col++) {
- gemm_cols<Scalar, Packet, DataMapper, Index, 1, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
- }
+ MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true)
}
/****************
@@ -1753,10 +1725,13 @@
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
}
- gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ if (col != cols)
+ {
+ gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
}
#define accColsC (accCols / 2)
@@ -1767,41 +1742,79 @@
#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3
-#define MICRO_COMPLEX_UNROLL_PEEL(func) \
+#define MICRO_COMPLEX_UNROLL(func) \
func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_ZERO_PEEL(peel) \
if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
- bsetzero<Scalar, Packet, accRows>(accReal##peel); \
- bsetzero<Scalar, Packet, accRows>(accImag##peel); \
+ bsetzero<Packet, accRows>(accReal##peel); \
+ bsetzero<Packet, accRows>(accImag##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(accReal##peel); \
EIGEN_UNUSED_VARIABLE(accImag##peel); \
}
-#define MICRO_COMPLEX_ZERO_PEEL_ROW \
- MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL);
+#define MICRO_COMPLEX_ADD_ROWS(N, used) \
+ MICRO_ADD(ptr_real, N) \
+ if (!RhsIsReal) { \
+ MICRO_ADD(ptr_imag, N) \
+ } else if (used) { \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \
+ }
+
+#define MICRO_COMPLEX_BROADCAST(peel) \
+ MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \
+ if (!RhsIsReal) { \
+ MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_BROADCAST_EXTRA \
+ Packet rhsV[4], rhsVi[4]; \
+ MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \
+ if(!RhsIsReal) { \
+ MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi); \
+ } \
+ MICRO_COMPLEX_ADD_ROWS(1, true)
+
+#define MICRO_COMPLEX_SRC2_PTR \
+ MICRO_SRC2(ptr_real, strideB*advanceCols, 0) \
+ if (!RhsIsReal) { \
+ MICRO_RHS(ptr_imag,0) = rhs_base + MICRO_NEW_ROWS*strideB; \
+ MICRO_SRC2(ptr_imag, strideB*advanceCols, strideB) \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \
+ }
+
+#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)
#define MICRO_COMPLEX_WORK_PEEL(peel) \
if (PEEL_COMPLEX_ROW > peel) { \
- pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
- if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
+ MICRO_COMPLEX_BROADCAST(peel) \
pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
}
+#define MICRO_COMPLEX_ADD_COLS(size) \
+ lhs_ptr_real += (remaining_rows * size); \
+ if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * size); \
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
+
#define MICRO_COMPLEX_WORK_PEEL_ROW \
Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
- MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \
- lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \
- if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \
- else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \
- rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \
- if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \
- else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \
+ MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \
+ MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)
#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
if (PEEL_COMPLEX_ROW > peel) { \
@@ -1818,19 +1831,13 @@
template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
- const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
+ const Scalar* &rhs_ptr_real0, const Scalar* &rhs_ptr_real1, const Scalar* &rhs_ptr_real2,
+ const Scalar* &rhs_ptr_imag0, const Scalar* &rhs_ptr_imag1, const Scalar* &rhs_ptr_imag2,
PacketBlock<Packet,accRows> &accReal, PacketBlock<Packet,accRows> &accImag)
{
- Packet rhsV[4], rhsVi[4];
- pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ MICRO_COMPLEX_BROADCAST_EXTRA
pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
- lhs_ptr_real += remaining_rows;
- if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
- else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
- rhs_ptr_real += accRows;
- if(!RhsIsReal) rhs_ptr_imag += accRows;
- else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ MICRO_COMPLEX_ADD_COLS(1)
}
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
@@ -1843,17 +1850,13 @@
Index offsetA,
Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
const Packet& pMask)
{
- const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag = NULL;
- if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
- else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
+ const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL;
+ const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL;
const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
const Scalar* lhs_ptr_imag = NULL;
if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
@@ -1863,19 +1866,18 @@
PacketBlock<Packetc,accRows> acc0, acc1;
PacketBlock<Packetc,accRows*2> tRes;
- bsetzero<Scalar, Packet, accRows>(accReal0);
- bsetzero<Scalar, Packet, accRows>(accImag0);
+ MICRO_COMPLEX_SRC2_PTR
- Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
+ bsetzero<Packet, accRows>(accReal0);
+ bsetzero<Packet, accRows>(accImag0);
+
+ Index remaining_depth = depth & -quad_traits<Scalar>::rows;
Index k = 0;
if (remaining_depth >= PEEL_COMPLEX_ROW) {
MICRO_COMPLEX_ZERO_PEEL_ROW
do
{
- EIGEN_POWER_PREFETCH(rhs_ptr_real);
- if(!RhsIsReal) {
- EIGEN_POWER_PREFETCH(rhs_ptr_imag);
- }
+ MICRO_COMPLEX_PREFETCHN(accRows)
EIGEN_POWER_PREFETCH(lhs_ptr_real);
if(!LhsIsReal) {
EIGEN_POWER_PREFETCH(lhs_ptr_imag);
@@ -1884,52 +1886,44 @@
} while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
MICRO_COMPLEX_ADD_PEEL_ROW
}
- for(; k < remaining_depth; k++)
+ for(; k < depth; k++)
{
- MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0);
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1, rhs_ptr_imag2, accReal0, accImag0);
}
- if ((remaining_depth == depth) && (rows >= accCols))
+ const bool full = (remaining_rows > accColsC);
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
+ if ((accRows == 1) || (rows >= accCols))
{
- bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row, 0);
- bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
- bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1);
- res.template storePacketBlock<Packetc,accRows>(row + 0, 0, acc0);
- res.template storePacketBlock<Packetc,accRows>(row + accColsC, 0, acc1);
- } else {
- for(; k < depth; k++)
- {
- Packet rhsV[4], rhsVi[4];
- pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
- pgerc<accRows, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
- lhs_ptr_real += remaining_rows;
- if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
- rhs_ptr_real += accRows;
- if(!RhsIsReal) rhs_ptr_imag += accRows;
+ bscalec<Packet,accRows,true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
+ bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
+ bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + 0);
+ if (full) {
+ bstore<DataMapper, Packetc, Index, accRows>(acc1, res, row + accColsC);
}
-
- bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag);
- bcouple_common<Packet, Packetc, accRows>(taccReal, taccImag, acc0, acc1);
+ } else {
+ bscalec<Packet,accRows,false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
+ bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
{
for(Index j = 0; j < accRows; j++) {
- res(row + 0, j) += pfirst<Packetc>(acc0.packet[j]);
+ res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
}
} else {
- for(Index j = 0; j < accRows; j++) {
- PacketBlock<Packetc,1> acc2;
- acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, j) + acc0.packet[j];
- res.template storePacketBlock<Packetc,1>(row + 0, j, acc2);
- if(remaining_rows > accColsC) {
- res(row + accColsC, j) += pfirst<Packetc>(acc1.packet[j]);
+ bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + 0);
+ if (full) {
+ for(Index j = 0; j < accRows; j++) {
+ res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
}
}
}
}
}
+#define MICRO_COMPLEX_EXTRA_ROWS(N) \
+ gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
+
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
const DataMapper& res,
@@ -1940,51 +1934,18 @@
Index offsetA,
Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
const Packet& pMask)
{
- switch(remaining_rows) {
- case 1:
- gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
- break;
- case 2:
- if (sizeof(Scalar) == sizeof(float)) {
- gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
- }
- break;
- default:
- if (sizeof(Scalar) == sizeof(float)) {
- gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
- }
- break;
- }
+ MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
}
-#define MICRO_COMPLEX_UNROLL(func) \
- func(0) func(1) func(2) func(3)
-
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
- MICRO_COMPLEX_UNROLL(func2); \
- func(0,peel) func(1,peel) func(2,peel) func(3,peel)
-
-#define MICRO_COMPLEX_LOAD_ONE(iter) \
- if (unroll_factor > iter) { \
- lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
- if(!LhsIsReal) { \
- lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
- } \
- lhs_ptr_real##iter += accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsV##iter); \
- EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
- }
+ MICRO_COMPLEX_UNROLL(func2); \
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel)
#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
if (unroll_factor > iter) { \
@@ -1995,12 +1956,7 @@
if (PEEL_COMPLEX > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
- pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
- if(!RhsIsReal) { \
- pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
- } else { \
- EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
- } \
+ MICRO_COMPLEX_BROADCAST(peel) \
MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
@@ -2010,27 +1966,25 @@
#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
- func(func1,func2,0); func(func1,func2,1); \
- func(func1,func2,2); func(func1,func2,3);
+ func(func1,func2,0) func(func1,func2,1) \
+ func(func1,func2,2) func(func1,func2,3)
#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
Packet rhsV0[M], rhsVi0[M];\
- func(func1,func2,0);
+ func(func1,func2,0)
-#define MICRO_COMPLEX_ONE_PEEL4 \
- MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
- rhs_ptr_real += (accRows * PEEL_COMPLEX); \
- if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);
+#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size) \
+ MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
+ MICRO_COMPLEX_ADD_ROWS(size, false)
-#define MICRO_COMPLEX_ONE4 \
- MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
- rhs_ptr_real += accRows; \
- if(!RhsIsReal) rhs_ptr_imag += accRows;
+#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)
+
+#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzero<Scalar, Packet, accRows>(accReal##iter); \
- bsetzero<Scalar, Packet, accRows>(accImag##iter); \
+ bsetzero<Packet, accRows>(accReal##iter); \
+ bsetzero<Packet, accRows>(accImag##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accReal##iter); \
EIGEN_UNUSED_VARIABLE(accImag##iter); \
@@ -2038,53 +1992,42 @@
#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
-#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
- if (unroll_factor > iter) { \
- lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
- }
-
#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
-#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
- if (unroll_factor > iter) { \
- EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
- }
-
#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row + iter*accCols, 0); \
- bscalec<Packet,accRows>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
- bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1); \
- res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + 0, 0, acc0); \
- res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + accColsC, 0, acc1); \
+ const bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter*accCols, 0); \
+ bscalec<Packet,accRows,!(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); \
+ bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
+ bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + iter*accCols + 0); \
+ if (full) { \
+ bstore<DataMapper, Packetc, Index, accRows>(acc1, res, row + iter*accCols + accColsC); \
+ } \
}
#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
-template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
+ Index offsetA,
Index strideB,
Index& row,
const Packet& pAlphaReal,
- const Packet& pAlphaImag)
+ const Packet& pAlphaImag,
+ const Packet& pMask)
{
- const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag = NULL;
+ const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL;
+ const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL;
const Index imag_delta = accCols*strideA;
- if(!RhsIsReal) {
- rhs_ptr_imag = rhs_base + accRows*strideB;
- } else {
- EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
- }
+ const Index imag_delta2 = accCols2*strideA;
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1;
@@ -2093,16 +2036,14 @@
PacketBlock<Packetc,accRows> acc0, acc1;
PacketBlock<Packetc,accRows*2> tRes;
+ MICRO_COMPLEX_SRC2_PTR
MICRO_COMPLEX_SRC_PTR
MICRO_COMPLEX_DST_PTR
Index k = 0;
for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
{
- EIGEN_POWER_PREFETCH(rhs_ptr_real);
- if(!RhsIsReal) {
- EIGEN_POWER_PREFETCH(rhs_ptr_imag);
- }
+ MICRO_COMPLEX_PREFETCHN(accRows)
MICRO_COMPLEX_PREFETCH
MICRO_COMPLEX_ONE_PEEL4
}
@@ -2112,9 +2053,13 @@
}
MICRO_COMPLEX_STORE
- row += unroll_factor*accCols;
+ MICRO_COMPLEX_UPDATE
}
+#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
+ gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
+ if (M) return;
+
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_cols(
const DataMapper& res,
@@ -2127,7 +2072,6 @@
Index offsetB,
Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
@@ -2135,33 +2079,33 @@
{
const DataMapper res3 = res.getSubMapper(0, col);
- const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + MICRO_NEW_ROWS*offsetB;
const Scalar* lhs_base = blockA + accCols*offsetA;
Index row = 0;
-#define MAX_COMPLEX_UNROLL 3
+#define MAX_COMPLEX_UNROLL 4
while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
- gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0);
}
switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_UNROLL > 4
case 4:
- gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4)
break;
#endif
#if MAX_COMPLEX_UNROLL > 3
case 3:
- gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3)
break;
#endif
#if MAX_COMPLEX_UNROLL > 2
case 2:
- gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2)
break;
#endif
#if MAX_COMPLEX_UNROLL > 1
case 1:
- gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1)
break;
#endif
default:
@@ -2171,10 +2115,13 @@
if(remaining_rows > 0)
{
- gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
}
+#define MICRO_COMPLEX_EXTRA_COLS(N) \
+ gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
const DataMapper& res,
@@ -2193,9 +2140,7 @@
const Packet& pAlphaImag,
const Packet& pMask)
{
- for (; col < cols; col++) {
- gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, 1, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
- }
+ MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols-col, true)
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
@@ -2216,10 +2161,13 @@
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
- gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ if (col != cols)
+ {
+ gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
}
#undef accColsC
@@ -2481,10 +2429,10 @@
const Index accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
//generate with MMA only
gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
@@ -2494,7 +2442,7 @@
#else
gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2520,20 +2468,20 @@
void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2558,20 +2506,20 @@
const Index accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2596,20 +2544,20 @@
const Index accCols = quad_traits<float>::size;
void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2633,10 +2581,10 @@
const Index accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
//generate with MMA only
gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
}
@@ -2646,7 +2594,7 @@
#else
gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
#endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2671,20 +2619,20 @@
const Index accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2709,20 +2657,20 @@
const Index accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
@@ -2747,20 +2695,20 @@
const Index accCols = quad_traits<double>::size;
void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #ifdef EIGEN_ALTIVEC_MMA_ONLY
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #endif
- gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ //generate with MMA only
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+ if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
+ gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ else{
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ }
+ #else
+ gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ #endif
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
} // end namespace internal
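Note on the dispatch pattern used by the wrappers above: with EIGEN_ALTIVEC_MMA_ONLY the MMA kernel is selected unconditionally at compile time, while EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH defers the choice to runtime via __builtin_cpu_supports and otherwise falls back to the generic AltiVec kernel. A minimal standalone sketch of that runtime selection follows; the kernel names are placeholders rather than Eigen symbols, and the GCC-specific __BUILTIN_CPU_SUPPORTS__ guard is an assumption added so the sketch also compiles where the builtin is unavailable.

    // Sketch only: runtime kernel selection via __builtin_cpu_supports.
    // gemm_generic/gemm_mma are invented stand-ins for Eigen's internal kernels.
    #include <cstdio>

    static void gemm_generic() { std::puts("generic AltiVec kernel"); }
    static void gemm_mma()     { std::puts("Power10 MMA kernel"); }

    void run_gemm() {
      void (*gemm_function)() = &gemm_generic;
    #if defined(__BUILTIN_CPU_SUPPORTS__)
      // Only take the MMA kernel when the running CPU reports ISA 3.1 and MMA.
      if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma"))
        gemm_function = &gemm_mma;
    #endif
      gemm_function();
    }

    int main() { run_gemm(); }

The EIGEN_ALTIVEC_MMA_ONLY path skips the runtime check entirely, which is why it is a separate compile-time branch rather than a default inside the runtime test.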
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index d92b678..e68c595 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -19,10 +19,9 @@
Index depth,
Index strideA,
Index offsetA,
+ Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask);
@@ -57,9 +56,7 @@
Index offsetA,
Index strideB,
Index row,
- Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
@@ -83,79 +80,100 @@
const Packet& pAlphaImag,
const Packet& pMask);
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N, bool full = true>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col);
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha);
+template<typename DataMapper, typename Packet, typename Index, int N>
+EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row);
-template<typename Packet, int N>
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
+template<typename Packet, int N, bool mask>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask);
-// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
-template<typename Packet, typename Packetc, int N>
-EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
-{
- acc1.packet[0].v = vec_mergeh(taccReal.packet[0], taccImag.packet[0]);
- if (N > 1) {
- acc1.packet[1].v = vec_mergeh(taccReal.packet[1], taccImag.packet[1]);
- }
- if (N > 2) {
- acc1.packet[2].v = vec_mergeh(taccReal.packet[2], taccImag.packet[2]);
- }
- if (N > 3) {
- acc1.packet[3].v = vec_mergeh(taccReal.packet[3], taccImag.packet[3]);
+template<typename Packet, int N, bool mask>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask);
+
+template<typename Packet, typename Packetc, int N, bool full>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2);
+
+#define MICRO_NORMAL(iter) \
+ (accCols == accCols2) || (unroll_factor != (iter + 1))
+
+#define MICRO_UNROLL_ITER(func, N) \
+ switch (remaining_rows) { \
+ default: \
+ func(N, 0) \
+ break; \
+ case 1: \
+ func(N, 1) \
+ break; \
+ case 2: \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ func(N, 2) \
+ } \
+ break; \
+ case 3: \
+ if (sizeof(Scalar) == sizeof(float)) { \
+ func(N, 3) \
+ } \
+ break; \
}
- acc2.packet[0].v = vec_mergel(taccReal.packet[0], taccImag.packet[0]);
- if (N > 1) {
- acc2.packet[1].v = vec_mergel(taccReal.packet[1], taccImag.packet[1]);
- }
- if (N > 2) {
- acc2.packet[2].v = vec_mergel(taccReal.packet[2], taccImag.packet[2]);
- }
- if (N > 3) {
- acc2.packet[3].v = vec_mergel(taccReal.packet[3], taccImag.packet[3]);
- }
-}
+#define MICRO_NORMAL_COLS(iter, a, b) ((MICRO_NORMAL(iter)) ? a : b)
-template<typename Packet, typename Packetc, int N>
-EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
-{
- bcouple_common<Packet, Packetc, N>(taccReal, taccImag, acc1, acc2);
-
- acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
- if (N > 1) {
- acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
- }
- if (N > 2) {
- acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
- }
- if (N > 3) {
- acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
+#define MICRO_LOAD1(lhs_ptr, iter) \
+ if (unroll_factor > iter) { \
+ lhsV##iter = ploadLhs<Packet>(lhs_ptr##iter); \
+ lhs_ptr##iter += MICRO_NORMAL_COLS(iter, accCols, accCols2); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV##iter); \
}
- acc2.packet[0] = padd<Packetc>(tRes.packet[0+N], acc2.packet[0]);
- if (N > 1) {
- acc2.packet[1] = padd<Packetc>(tRes.packet[1+N], acc2.packet[1]);
- }
- if (N > 2) {
- acc2.packet[2] = padd<Packetc>(tRes.packet[2+N], acc2.packet[2]);
- }
- if (N > 3) {
- acc2.packet[3] = padd<Packetc>(tRes.packet[3+N], acc2.packet[3]);
- }
-}
+#define MICRO_LOAD_ONE(iter) MICRO_LOAD1(lhs_ptr, iter)
-// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs)
-{
- return ploadu<Packet>(rhs);
-}
+#define MICRO_COMPLEX_LOAD_ONE(iter) \
+ if (!LhsIsReal && (unroll_factor > iter)) { \
+ lhsVi##iter = ploadLhs<Packet>(lhs_ptr_real##iter + MICRO_NORMAL_COLS(iter, imag_delta, imag_delta2)); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
+ } \
+ MICRO_LOAD1(lhs_ptr_real, iter) \
+
+#define MICRO_SRC_PTR1(lhs_ptr, advRows, iter) \
+ if (unroll_factor > iter) { \
+ lhs_ptr##iter = lhs_base + (row+(iter*accCols))*strideA*advRows - MICRO_NORMAL_COLS(iter, 0, (accCols-accCols2)*offsetA); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
+ }
+
+#define MICRO_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr, 1, iter)
+
+#define MICRO_COMPLEX_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr_real, advanceRows, iter)
+
+#define MICRO_PREFETCH1(lhs_ptr, iter) \
+ if (unroll_factor > iter) { \
+ EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
+ }
+
+#define MICRO_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr, iter)
+
+#define MICRO_COMPLEX_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr_real, iter)
+
+#define MICRO_UPDATE \
+ if (accCols == accCols2) { \
+ EIGEN_UNUSED_VARIABLE(pMask); \
+ EIGEN_UNUSED_VARIABLE(offsetA); \
+ row += unroll_factor*accCols; \
+ }
+
+#define MICRO_COMPLEX_UPDATE \
+ MICRO_UPDATE \
+ if(LhsIsReal || (accCols == accCols2)) { \
+ EIGEN_UNUSED_VARIABLE(imag_delta2); \
+ }
+
} // end namespace internal
} // end namespace Eigen
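The MICRO_UNROLL_ITER helper introduced above switches on remaining_rows and re-invokes the unrolled iteration with that value as a compile-time constant (the M argument, which feeds the accCols2 template parameter), so the masked tail iteration is specialized rather than tested per element; the 2- and 3-row cases are only taken for float, since double packets hold at most two elements. A minimal sketch of the same promote-to-constant technique, with invented names:

    // Sketch only: promoting a runtime remainder to a template constant,
    // as the MICRO_UNROLL_ITER switch does for the masked tail. Names are hypothetical.
    #include <cstddef>
    #include <iostream>

    template <int Remainder>
    void tail_kernel(std::size_t full_blocks) {
      // Remainder is a constant here, so masked loads/stores guarded by it
      // can be resolved and unrolled at compile time.
      std::cout << full_blocks << " full blocks, tail of " << Remainder << " rows\n";
    }

    void dispatch(std::size_t rows, std::size_t block) {
      const std::size_t full_blocks = rows / block;
      switch (rows % block) {  // mirrors switch (remaining_rows) in the macro
        case 1: tail_kernel<1>(full_blocks); break;
        case 2: tail_kernel<2>(full_blocks); break;
        case 3: tail_kernel<3>(full_blocks); break;
        default: tail_kernel<0>(full_blocks); break;
      }
    }

    int main() { dispatch(/*rows=*/11, /*block=*/4); }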
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
index 8104697..1cb82ee 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -11,7 +11,9 @@
#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
-#if !EIGEN_COMP_LLVM
+// If using dynamic dispatch, set the CPU target.
+#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif
@@ -19,6 +21,9 @@
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
+#if !__has_builtin(__builtin_vsx_disassemble_pair)
+#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
+#endif
#endif
#include "../../InternalHeaderCheck.h"
@@ -27,44 +32,48 @@
namespace internal {
-template<typename Scalar, typename Packet>
+#define accColsC (accCols / 2)
+
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
__builtin_mma_xxsetaccz(acc);
}
-template<typename DataMapper, typename Index, typename Packet, const Index accCols>
-EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
+template<typename DataMapper, typename Index, typename Packet, const Index accCols, const Index accCols2>
+EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
{
PacketBlock<Packet, 4> result;
__builtin_mma_disassemble_acc(&result.packet, acc);
PacketBlock<Packet, 4> tRes;
- bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, 4>(tRes, data, i, 0);
- bscale<Packet, 4>(tRes, result, alpha);
+ bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
- data.template storePacketBlock<Packet, 4>(i, 0, tRes);
+ bstore<DataMapper, Packet, Index, 4>(tRes, data, i);
}
-template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
-EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
+template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
+EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal, __vector_quad* accImag)
{
+ const bool full = (accCols2 > accColsC);
PacketBlock<Packet, 4> resultReal, resultImag;
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
PacketBlock<Packetc, 8> tRes;
- bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);
- PacketBlock<Packet,4> taccReal, taccImag;
- bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
+ PacketBlock<Packet, 4> taccReal, taccImag;
+ bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);
PacketBlock<Packetc, 4> acc1, acc2;
- bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);
+ bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);
- data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
- data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
+ bstore<DataMapper, Packetc, Index, 4>(acc1, data, i);
+ if (full) {
+ bstore<DataMapper, Packetc, Index, 4>(acc2, data, i + accColsC);
+ }
}
// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments
@@ -80,18 +89,6 @@
}
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
-{
- __vector_pair* a0 = reinterpret_cast<__vector_pair *>(const_cast<Packet2d *>(&a.packet[0]));
- if(NegativeAccumulate)
- {
- __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
- } else {
- __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
- }
-}
-
-template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
if(NegativeAccumulate)
@@ -102,18 +99,13 @@
}
}
-template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
-{
- // Just for compilation
-}
-
-template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
+template<typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi, const RhsPacket& rhsV, RhsPacket& rhsVi)
{
pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
if(LhsIsReal) {
pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
+ EIGEN_UNUSED_VARIABLE(lhsVi);
} else {
if(!RhsIsReal) {
pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
@@ -126,35 +118,33 @@
}
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
+template<typename Packet>
+EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet)* rhs)
+{
+ return ploadu<Packet>(rhs);
+}
+
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
- rhsV = ploadRhs<Scalar, Packet>(rhs);
+ rhsV = ploadRhs<Packet>(rhs);
}
template<>
-EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
-{
- rhsV.packet[0] = ploadRhs<double, Packet2d>(rhs);
- rhsV.packet[1] = ploadRhs<double, Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)));
-}
-
-template<>
-EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
+EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
__builtin_vsx_assemble_pair(&rhsV,
- reinterpret_cast<__vector unsigned char>(ploadRhs<double, Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
- reinterpret_cast<__vector unsigned char>(ploadRhs<double, Packet2d>(rhs)));
+ reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
+ reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
#else
__asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
-template<>
-EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
+EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
{
- // Just for compilation
+ ploadRhsMMA(lhs, lhsV);
}
// PEEL_MMA loop factor.
@@ -163,98 +153,116 @@
#define MICRO_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
-#define MICRO_MMA_LOAD_ONE(iter) \
- if (unroll_factor > iter) { \
- lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
- lhs_ptr##iter += accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsV##iter); \
- }
+#define MICRO_MMA_WORK(func, type, peel) \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
+ func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel)
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
}
-#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
+#define MICRO_MMA_WORK_TWO(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV2##iter.packet[peel & 1]); \
+ }
+
+#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
+ if (unroll_factor > iter) { \
+ if (MICRO_NORMAL(iter)) { \
+ ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##iter), plhsV##iter); \
+ __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##iter.packet), &plhsV##iter); \
+ lhs_ptr##iter += accCols*2; \
+ } else { \
+ lhsV2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr##iter); \
+ lhsV2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr##iter + accCols2); \
+ lhs_ptr##iter += accCols2*2; \
+ EIGEN_UNUSED_VARIABLE(plhsV##iter) \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsV2##iter); \
+ EIGEN_UNUSED_VARIABLE(plhsV##iter) \
+ }
+
+#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter)
+
+#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
- MICRO_MMA_UNROLL(func2); \
- func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
- func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
+ ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \
+ MICRO_MMA_UNROLL(funcl) \
+ MICRO_MMA_WORK(funcw, type, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
-#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
+ if (PEEL_MMA > peel2) { \
+ PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
+ __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
+ ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV##peel1); \
+ ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV##peel2); \
+ MICRO_MMA_UNROLL(funcl2) \
+ MICRO_MMA_WORK(funcw2, type, peel1) \
+ MICRO_MMA_WORK(funcw2, type, peel2) \
+ } else { \
+ MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
+ }
+
+#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
- MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
- MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
- MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
- MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);
+ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
+ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
+ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,4,5) \
+ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,6,7)
-#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV0; \
- MICRO_MMA_TYPE_PEEL(func,func2,type,0);
+ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)
-#define MICRO_MMA_ONE_PEEL \
- if (sizeof(Scalar) == sizeof(float)) { \
- MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
- } else { \
- MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
- } \
- rhs_ptr += (accRows * PEEL_MMA);
+#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
+ MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
+ rhs_ptr += (accRows * size);
-#define MICRO_MMA_ONE \
- if (sizeof(Scalar) == sizeof(float)) { \
- MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
- } else { \
- MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
- } \
- rhs_ptr += accRows;
+#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
+ MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
+ rhs_ptr += (accRows * size);
+
+#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
+
+#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
+ bsetzeroMMA(&accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
}
#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)
-#define MICRO_MMA_SRC_PTR_ONE(iter) \
- if (unroll_factor > iter) { \
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
- }
+#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)
-#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)
-
-#define MICRO_MMA_PREFETCH_ONE(iter) \
- if (unroll_factor > iter) { \
- EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
- }
-
-#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)
+#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
+ storeAccumulator<DataMapper, Index, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \
}
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
-template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
+ Index offsetA,
Index& row,
- const Packet& pAlpha)
+ const Packet& pAlpha,
+ const Packet& pMask)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
@@ -263,8 +271,8 @@
MICRO_MMA_SRC_PTR
MICRO_MMA_DST_PTR
- Index k = 0;
- for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
+ Index k = 0, depth2 = depth - PEEL_MMA;
+ for(; k <= depth2; k += PEEL_MMA)
{
EIGEN_POWER_PREFETCH(rhs_ptr);
MICRO_MMA_PREFETCH
@@ -276,9 +284,13 @@
}
MICRO_MMA_STORE
- row += unroll_factor*accCols;
+ MICRO_UPDATE
}
+#define MICRO_MMA_UNROLL_ITER2(N, M) \
+ gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
+ if (M) return;
+
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
const DataMapper& res,
@@ -291,7 +303,6 @@
Index offsetB,
Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
@@ -304,42 +315,42 @@
#define MAX_MMA_UNROLL 7
while(row + MAX_MMA_UNROLL*accCols <= rows) {
- gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0);
}
switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
case 7:
- gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
break;
#endif
#if MAX_MMA_UNROLL > 6
case 6:
- gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
break;
#endif
#if MAX_MMA_UNROLL > 5
case 5:
- gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
break;
#endif
#if MAX_MMA_UNROLL > 4
case 4:
- gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
break;
#endif
#if MAX_MMA_UNROLL > 3
case 3:
- gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
break;
#endif
#if MAX_MMA_UNROLL > 2
case 2:
- gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
break;
#endif
#if MAX_MMA_UNROLL > 1
case 1:
- gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 1)
break;
#endif
default:
@@ -349,7 +360,7 @@
if(remaining_rows > 0)
{
- gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
}
}
@@ -364,16 +375,20 @@
const Packet pAlpha = pset1<Packet>(alpha);
const Packet pMask = bmask<Packet>(remaining_rows);
+ typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
+
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
}
- gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ if (col != cols)
+ {
+ gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
}
-#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)
@@ -383,74 +398,104 @@
#define MICRO_COMPLEX_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3)
-#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
- if (unroll_factor > iter) { \
- lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
- if(!LhsIsReal) { \
- lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
- } \
- lhs_ptr_real##iter += accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhsV##iter); \
- EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
- }
+#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel)
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
- pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
}
-#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
+#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
+ if (unroll_factor > iter) { \
+ pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV##peel, rhsVi##peel); \
+ }
+
+#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
+ if (!LhsIsReal && (unroll_factor > iter)) { \
+ if (MICRO_NORMAL(iter)) { \
+ ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \
+ __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##iter.packet), &plhsVi##iter); \
+ } else { \
+ lhsVi2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2); \
+ lhsVi2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2 + accCols2); \
+ EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \
+ EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
+ } \
+ MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter)
+
+#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter)
+
+#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_COMPLEX_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
- ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
+ ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \
if(!RhsIsReal) { \
- ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
+ ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
} else { \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
} \
- MICRO_COMPLEX_MMA_UNROLL(func2); \
- func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
+ MICRO_COMPLEX_MMA_UNROLL(funcl) \
+ MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
}
-#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
+#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
+ if (PEEL_COMPLEX_MMA > peel2) { \
+ PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23; \
+ PacketBlock<Packet,2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
+ __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
+ __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
+ ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV##peel1); \
+ ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV##peel2); \
+ if(!RhsIsReal) { \
+ ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi##peel1); \
+ ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi##peel2); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel1); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel2); \
+ } \
+ MICRO_COMPLEX_MMA_UNROLL(funcl2) \
+ MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
+ MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
+ } else { \
+ MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
+ }
+
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV0, rhsV1, rhsV2, rhsV3; \
type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);
+ MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
+ MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3)
-#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV0, rhsVi0; \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);
+ MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)
-#define MICRO_COMPLEX_MMA_ONE_PEEL \
- if (sizeof(Scalar) == sizeof(float)) { \
- MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
- } else { \
- MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
- } \
- rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
- if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
+ MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
+ rhs_ptr_real += (accRows * size); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
-#define MICRO_COMPLEX_MMA_ONE \
- if (sizeof(Scalar) == sizeof(float)) { \
- MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
- } else { \
- MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
- } \
- rhs_ptr_real += accRows; \
- if(!RhsIsReal) rhs_ptr_imag += accRows;
+#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
+ MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
+ rhs_ptr_real += (accRows * size); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
+
+#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
+
+#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
- bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
+ bsetzeroMMA(&accReal##iter); \
+ bsetzeroMMA(&accImag##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accReal##iter); \
EIGEN_UNUSED_VARIABLE(accImag##iter); \
@@ -458,44 +503,35 @@
#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)
-#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
- if (unroll_factor > iter) { \
- lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
- }
+#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
-#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
-
-#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
- if (unroll_factor > iter) { \
- EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
- }
-
-#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
+#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
+ storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
}
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
-template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
+ Index offsetA,
Index strideB,
Index& row,
const Packet& pAlphaReal,
- const Packet& pAlphaImag)
+ const Packet& pAlphaImag,
+ const Packet& pMask)
{
const Scalar* rhs_ptr_real = rhs_base;
const Scalar* rhs_ptr_imag = NULL;
const Index imag_delta = accCols*strideA;
+ const Index imag_delta2 = accCols2*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
} else {
@@ -508,8 +544,8 @@
MICRO_COMPLEX_MMA_SRC_PTR
MICRO_COMPLEX_MMA_DST_PTR
- Index k = 0;
- for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
+ Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
+ for(; k <= depth2; k += PEEL_COMPLEX_MMA)
{
EIGEN_POWER_PREFETCH(rhs_ptr_real);
if(!RhsIsReal) {
@@ -524,9 +560,13 @@
}
MICRO_COMPLEX_MMA_STORE
- row += unroll_factor*accCols;
+ MICRO_COMPLEX_UPDATE
}
+#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
+ gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
+ if (M) return;
+
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
const DataMapper& res,
@@ -539,7 +579,6 @@
Index offsetB,
Index col,
Index rows,
- Index cols,
Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
@@ -553,27 +592,27 @@
#define MAX_COMPLEX_MMA_UNROLL 4
while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
- gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0);
}
switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
case 4:
- gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4)
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
case 3:
- gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
case 2:
- gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
case 1:
- gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
break;
#endif
default:
@@ -583,7 +622,7 @@
if(remaining_rows > 0)
{
- gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
}
@@ -602,25 +641,31 @@
const Scalar* blockA = (Scalar *) blockAc;
const Scalar* blockB = (Scalar *) blockBc;
+ typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
+
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
- gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ if (col != cols)
+ {
+ gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
}
#undef accColsC
#undef advanceRows
#undef advanceCols
-#if !EIGEN_COMP_LLVM
-#pragma GCC reset_options
-#endif
} // end namespace internal
} // end namespace Eigen
+#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
+#pragma GCC pop_options
+#endif
+
#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
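
Note on the `RhsPacket2` typedef introduced above: the per-call `sizeof(Scalar)` branches that used to pick between `RhsPacket` and `__vector_pair` are now resolved once at the type level. A minimal sketch of the `std::conditional_t` pattern, with placeholder types standing in for the real packet types:

    #include <type_traits>

    struct NarrowRhs {};  // stands in for RhsPacket (float path)
    struct WideRhs {};    // stands in for __vector_pair (double path)

    // Select the packet type at compile time instead of branching on sizeof(Scalar)
    // inside every macro expansion.
    template <typename Scalar>
    using rhs_packet_t =
        std::conditional_t<sizeof(Scalar) == sizeof(float), NarrowRhs, WideRhs>;

    static_assert(std::is_same<rhs_packet_t<float>, NarrowRhs>::value, "");
    static_assert(std::is_same<rhs_packet_t<double>, WideRhs>::value, "");
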
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index d40ae53..6ab4d0b 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -12,7 +12,7 @@
#include "../../InternalHeaderCheck.h"
-#if defined(__MMA__) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
+#if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
#if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)
#define USE_GEMV_MMA
#endif
@@ -120,8 +120,8 @@
GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
#else
#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
- const LhsScalar& src##iter1 = lhs(i + 0, j); \
- __asm__ ("lxvp %x0,%1(%2)" : "=wa" (b##iter1) : "K" (iter1 * 32), "a" (&src##iter1));
+ const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
+ b##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src##iter1));
#endif
#define GEMV_LOAD1A_COL_MMA(iter, N) \
@@ -503,7 +503,11 @@
/** \internal packet conjugate with real & imaginary operation inverted */
EIGEN_ALWAYS_INLINE Packet2cf pconjinv(const Packet2cf& a) {
+#ifdef __POWER8_VECTOR__
+ return Packet2cf(Packet4f(vec_neg(Packet2d(a.v))));
+#else
return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR2)));
+#endif
}
EIGEN_ALWAYS_INLINE Packet1cd pconjinv(const Packet1cd& a) {
@@ -555,12 +559,20 @@
/** \internal packet negate */
EIGEN_ALWAYS_INLINE Packet2cf pnegate2(Packet2cf a)
{
+#ifdef __POWER8_VECTOR__
+ return Packet2cf(vec_neg(a.v));
+#else
return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_NEGATE)));
+#endif
}
EIGEN_ALWAYS_INLINE Packet1cd pnegate2(Packet1cd a)
{
+#ifdef __POWER8_VECTOR__
+ return Packet1cd(vec_neg(a.v));
+#else
return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_NEGATE)));
+#endif
}
/** \internal flip the real & imaginary results and negate */
@@ -637,13 +649,24 @@
#endif
}
+#ifndef __POWER8_VECTOR__
+const Packet16uc p16uc_MERGEE = { 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B };
+
+const Packet16uc p16uc_MERGEO = { 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F };
+#endif
+
/** \internal load two vectors from the interleaved real & imaginary values of src */
template<typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i)
{
Packet4f t = ploadu<Packet4f>(reinterpret_cast<float*>(src));
+#ifdef __POWER8_VECTOR__
r = vec_mergee(t, t);
i = vec_mergeo(t, t);
+#else
+ r = vec_perm(t, t, p16uc_MERGEE);
+ i = vec_perm(t, t, p16uc_MERGEO);
+#endif
}
template<typename RhsScalar>
@@ -909,7 +932,7 @@
{
PResPacket c2 = pcplxflipconj(c0);
PResPacket c3 = pcplxflipconj(c1);
-#if EIGEN_COMP_LLVM
+#if EIGEN_COMP_LLVM || !defined(_ARCH_PWR10)
ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
@@ -1389,8 +1412,8 @@
#else
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
if (sizeof(LhsPacket) == 16) { \
- const LhsScalar& src = lhs(i + 0, j); \
- __asm__ ("lxvp %x0,%1(%2)" : "=wa" (a##iter1) : "K" (iter1 * 32), "a" (&src)); \
+ const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
+ a##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src)); \
EIGEN_UNUSED_VARIABLE(f##iter1); \
} else { \
f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
@@ -1696,17 +1719,26 @@
};
#ifdef USE_GEMV_MMA
+static Packet16uc p16uc_ELEMENT_3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
+
/** \internal predux (add elements of a vector) from a MMA accumulator - real results */
template<typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1)
{
- ScalarBlock<ResScalar, 2> cc0;
+ union {
+ ScalarBlock<ResScalar, 2> cs;
+ double cd;
+ } cc0;
PacketBlock<ResPacket, 4> result0, result1;
__builtin_mma_disassemble_acc(&result0.packet, acc0);
__builtin_mma_disassemble_acc(&result1.packet, acc1);
- cc0.scalar[0] = result0.packet[0][0] + result0.packet[1][1] + result0.packet[2][2] + result0.packet[3][3];
- cc0.scalar[1] = result1.packet[0][0] + result1.packet[1][1] + result1.packet[2][2] + result1.packet[3][3];
- return cc0;
+ result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
+ result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
+ result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
+ result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
+ result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
+ cc0.cd = pfirst(reinterpret_cast<Packet2d>(result0.packet[0]));
+ return cc0.cs;
}
template<>
@@ -1996,9 +2028,9 @@
}
}
-#define EIGEN_POWER_GEMV_REAL_SPECIALIZE(Scalar, Major) \
+#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
-struct general_matrix_vector_product<Index, Scalar, LhsMapper, Major, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
+struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
\
@@ -2008,18 +2040,30 @@
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
- if (Major == ColMajor) { \
- gemv_col<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
- } else { \
- gemv_row<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
- } \
+ gemv_col<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
};
-EIGEN_POWER_GEMV_REAL_SPECIALIZE(float, ColMajor)
-EIGEN_POWER_GEMV_REAL_SPECIALIZE(double, ColMajor)
-EIGEN_POWER_GEMV_REAL_SPECIALIZE(float, RowMajor)
-EIGEN_POWER_GEMV_REAL_SPECIALIZE(double, RowMajor)
+#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
+{ \
+ typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
+\
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
+ Index rows, Index cols, \
+ const LhsMapper& lhs, \
+ const RhsMapper& rhs, \
+ ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_row<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+};
+
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
@@ -2311,9 +2355,9 @@
}
}
-#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(Scalar, LhsScalar, RhsScalar, Major) \
+#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
-struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, Major, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
+struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
\
@@ -2323,26 +2367,38 @@
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
- if (Major == ColMajor) { \
- gemv_complex_col<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
- } else { \
- gemv_complex_row<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
- } \
+ gemv_complex_col<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
};
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, float, std::complex<float>, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, float, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, std::complex<float>, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, double, std::complex<double>, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, double, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, std::complex<double>, ColMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, float, std::complex<float>, RowMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, float, RowMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, std::complex<float>, RowMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, double, std::complex<double>, RowMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, double, RowMajor)
-EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, std::complex<double>, RowMajor)
+#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
+{ \
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
+\
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
+ Index rows, Index cols, \
+ const LhsMapper& lhs, \
+ const RhsMapper& rhs, \
+ ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_complex_row<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+};
+
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, float, std::complex<float>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, float)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, double, std::complex<double>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, std::complex<double>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, float)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, std::complex<float>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, double, std::complex<double>)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, double)
+EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, std::complex<double>)
#endif // EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
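
The GEMV specialization macros above now generate separate ColMajor and RowMajor partial specializations of `general_matrix_vector_product`, so the column/row-major choice is made by the type system rather than by an `if (Major == ColMajor)` branch inside `run()`. A simplified sketch of that dispatch shape (illustrative names, not Eigen's real signatures):

    enum StorageOrder { ColMajor = 0, RowMajor = 1 };

    template <typename Scalar, int Major>
    struct gemv_dispatch;  // primary template left undefined

    template <typename Scalar>
    struct gemv_dispatch<Scalar, ColMajor> {
      static void run() { /* column-major kernel, analogous to gemv_col */ }
    };

    template <typename Scalar>
    struct gemv_dispatch<Scalar, RowMajor> {
      static void run() { /* row-major kernel, analogous to gemv_row */ }
    };
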
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index efb0242..945f36b 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -83,8 +83,10 @@
static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
static EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u);
static EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu);
+#ifndef __POWER8_VECTOR__
static EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1}
static EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1);
+#endif
static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000}
#ifndef __VSX__
static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0}
@@ -102,11 +104,13 @@
static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 };
static Packet16uc p16uc_REVERSE16 = { 14,15, 12,13, 10,11, 8,9, 6,7, 4,5, 2,3, 0,1 };
+#ifndef _ARCH_PWR9
static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 };
+#endif
+#ifdef _BIG_ENDIAN
static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 };
-static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 };
-static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 };
+#endif
static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 };
static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 };
@@ -116,15 +120,11 @@
// Define global static constants:
#ifdef _BIG_ENDIAN
static Packet16uc p16uc_FORWARD = vec_lvsl(0, (float*)0);
-#ifdef __VSX__
-static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
-#endif
static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 };
static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16};
#else
static Packet16uc p16uc_FORWARD = p16uc_REVERSE32;
-static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 1), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 };
static Packet16uc p16uc_PSET32_WEVEN = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 };
static Packet16uc p16uc_HALF64_0_16 = vec_sld(vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 0), (Packet16uc)p4i_ZERO, 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16};
@@ -137,12 +137,6 @@
static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 };
-#ifdef _BIG_ENDIAN
-static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
-#else
-static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_PSET64_HI, p16uc_PSET64_LO, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 };
-#endif // _BIG_ENDIAN
-
#if EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC
#define EIGEN_PPC_PREFETCH(ADDR) __builtin_prefetch(ADDR);
#else
@@ -788,8 +782,22 @@
template<> EIGEN_STRONG_INLINE Packet16c psub<Packet16c> (const Packet16c& a, const Packet16c& b) { return a - b; }
template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a - b; }
-template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return p4f_ZERO - a; }
-template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_ZERO - a; }
+template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
+{
+#ifdef __POWER8_VECTOR__
+ return vec_neg(a);
+#else
+ return p4f_ZERO - a;
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a)
+{
+#ifdef __POWER8_VECTOR__
+ return vec_neg(a);
+#else
+ return p4i_ZERO - a;
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; }
@@ -831,6 +839,10 @@
template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c) { return vec_madd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c) { return vec_madd(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_msub(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_nmsub(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_nmadd(a,b,c); }
+
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b)
{
#ifdef __VSX__
@@ -953,7 +965,10 @@
template<typename Packet> EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPACK_TYPE__(Packet)* from)
{
EIGEN_DEBUG_ALIGNED_LOAD
-#ifdef _BIG_ENDIAN
+#if defined(__VSX__) || !defined(_BIG_ENDIAN)
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
+#else
Packet16uc MSQ, LSQ;
Packet16uc mask;
MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword
@@ -961,9 +976,6 @@
mask = vec_lvsl(0, from); // create the permute mask
//TODO: Add static_cast here
return (Packet) vec_perm(MSQ, LSQ, mask); // align the data
-#else
- EIGEN_DEBUG_UNALIGNED_LOAD
- return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from));
#endif
}
@@ -1001,7 +1013,7 @@
Packet p;
if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet>(from);
else p = ploadu<Packet>(from);
- return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+ return vec_mergeh(p, p);
}
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
{
@@ -1017,7 +1029,7 @@
Packet8s p;
if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8s>(from);
else p = ploadu<Packet8s>(from);
- return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+ return vec_mergeh(p, p);
}
template<> EIGEN_STRONG_INLINE Packet8us ploaddup<Packet8us>(const unsigned short int* from)
@@ -1025,7 +1037,7 @@
Packet8us p;
if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet8us>(from);
else p = ploadu<Packet8us>(from);
- return vec_perm(p, p, p16uc_DUPLICATE16_HI);
+ return vec_mergeh(p, p);
}
template<> EIGEN_STRONG_INLINE Packet8s ploadquad<Packet8s>(const short int* from)
@@ -1054,7 +1066,7 @@
Packet16c p;
if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16c>(from);
else p = ploadu<Packet16c>(from);
- return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+ return vec_mergeh(p, p);
}
template<> EIGEN_STRONG_INLINE Packet16uc ploaddup<Packet16uc>(const unsigned char* from)
@@ -1062,13 +1074,15 @@
Packet16uc p;
if((std::ptrdiff_t(from) % 16) == 0) p = pload<Packet16uc>(from);
else p = ploadu<Packet16uc>(from);
- return vec_perm(p, p, p16uc_DUPLICATE8_HI);
+ return vec_mergeh(p, p);
}
template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE__(Packet)* to, const Packet& from)
{
EIGEN_DEBUG_UNALIGNED_STORE
-#ifdef _BIG_ENDIAN
+#if defined(__VSX__) || !defined(_BIG_ENDIAN)
+ vec_xst(from, 0, to);
+#else
// Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html
// Warning: not thread safe!
Packet16uc MSQ, LSQ, edges;
@@ -1083,8 +1097,6 @@
LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ)
vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second
-#else
- vec_xst(from, 0, to);
#endif
}
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from)
@@ -1164,11 +1176,19 @@
}
template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
{
+#ifdef _ARCH_PWR9
+ return vec_revb(a);
+#else
return vec_perm(a, a, p16uc_REVERSE8);
+#endif
}
template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
{
+#ifdef _ARCH_PWR9
+ return vec_revb(a);
+#else
return vec_perm(a, a, p16uc_REVERSE8);
+#endif
}
template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
{
@@ -2102,7 +2122,11 @@
template<typename Packet> EIGEN_STRONG_INLINE
Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) {
Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
+#ifdef __POWER8_VECTOR__
+ Packet4ui mask = reinterpret_cast<Packet4ui>(vec_neg(reinterpret_cast<Packet4i>(select)));
+#else
Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
+#endif
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2117,7 +2141,11 @@
template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) {
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+#ifdef __POWER8_VECTOR__
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
+#else
Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(select, p8us_ONE));
+#endif
Packet8s result = vec_sel(elsePacket, thenPacket, mask);
return result;
}
@@ -2125,7 +2153,11 @@
template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) {
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
+#ifdef __POWER8_VECTOR__
+ Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
+#else
Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(reinterpret_cast<Packet8us>(select), p8us_ONE));
+#endif
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2139,7 +2171,11 @@
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+#ifdef __POWER8_VECTOR__
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
+#else
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
+#endif
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2149,7 +2185,11 @@
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
+#ifdef __POWER8_VECTOR__
+ Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
+#else
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
+#endif
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2395,7 +2435,14 @@
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return a - b; }
-template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return p2d_ZERO - a; }
+template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a)
+{
+#ifdef __POWER8_VECTOR__
+ return vec_neg(a);
+#else
+ return p2d_ZERO - a;
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; }
@@ -2404,6 +2451,9 @@
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); }
+template<> EIGEN_STRONG_INLINE Packet2d pmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_msub(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_nmsub(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_nmadd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b)
{
@@ -2487,7 +2537,7 @@
template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
{
- return reinterpret_cast<Packet2d>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE64));
+ return vec_sld(a, a, 8);
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
@@ -2573,7 +2623,7 @@
struct plogical_shift_left_impl;
template<int N>
-struct plogical_shift_left_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+struct plogical_shift_left_impl<N, std::enable_if_t<(N < 32) && (N >= 0)>> {
static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
static const unsigned n = static_cast<unsigned>(N);
const Packet4ui shift = {n, n, n, n};
@@ -2587,7 +2637,7 @@
};
template<int N>
-struct plogical_shift_left_impl<N, typename enable_if<(N >= 32)>::type> {
+struct plogical_shift_left_impl<N, std::enable_if_t<(N >= 32)>> {
static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
static const unsigned m = static_cast<unsigned>(N - 32);
const Packet4ui shift = {m, m, m, m};
@@ -2605,7 +2655,7 @@
struct plogical_shift_right_impl;
template<int N>
-struct plogical_shift_right_impl<N, typename enable_if<(N < 32) && (N >= 0)>::type> {
+struct plogical_shift_right_impl<N, std::enable_if_t<(N < 32) && (N >= 0)>> {
static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
static const unsigned n = static_cast<unsigned>(N);
const Packet4ui shift = {n, n, n, n};
@@ -2619,7 +2669,7 @@
};
template<int N>
-struct plogical_shift_right_impl<N, typename enable_if<(N >= 32)>::type> {
+struct plogical_shift_right_impl<N, std::enable_if_t<(N >= 32)>> {
static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) {
static const unsigned m = static_cast<unsigned>(N - 32);
const Packet4ui shift = {m, m, m, m};
@@ -2692,8 +2742,8 @@
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2d,2>& kernel) {
Packet2d t0, t1;
- t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI);
- t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO);
+ t0 = vec_mergeh(kernel.packet[0], kernel.packet[1]);
+ t1 = vec_mergel(kernel.packet[0], kernel.packet[1]);
kernel.packet[0] = t0;
kernel.packet[1] = t1;
}
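
The new `pmsub`, `pnmadd` and `pnmsub` specializations above map to the VSX fused operations `vec_msub`, `vec_nmsub` and `vec_nmadd`. As a rough scalar reference for the intended semantics (assuming the usual fused multiply-add family conventions; the point of the fused forms is a single rounding step):

    // Scalar reference semantics (illustrative only):
    //   pmadd(a,b,c)  ~   a*b + c
    //   pmsub(a,b,c)  ~   a*b - c      (vec_msub)
    //   pnmadd(a,b,c) ~ -(a*b) + c     (vec_nmsub, the negation of a*b - c)
    //   pnmsub(a,b,c) ~ -(a*b) - c     (vec_nmadd, the negation of a*b + c)
    double pmsub_ref (double a, double b, double c) { return   a * b  - c; }
    double pnmadd_ref(double a, double b, double c) { return -(a * b) + c; }
    double pnmsub_ref(double a, double b, double c) { return -(a * b) - c; }
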
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index dedf976..822113b 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -28,11 +28,11 @@
template<> struct make_integer<half> { typedef numext::int16_t type; };
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
-template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
+ static constexpr int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
}
@@ -42,42 +42,41 @@
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
- enum {
+ static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
- ExponentBits = int(TotalBits) - int(MantissaBits) - 1
- };
+ ExponentBits = TotalBits - MantissaBits - 1;
- EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
- ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
- const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
+ EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
+ ~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits); // ~0x7f800000
+ const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet half = pset1<Packet>(Scalar(0.5));
const Packet zero = pzero(a);
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
-
+
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
- EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
+ EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(MantissaBits + 1); // 24
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
- const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
+ const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
-
+
// Determine exponent offset: -126 if normal, -126-24 if denormal
- const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
+ const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(ExponentBits-1)) - ScalarUI(2)); // -126
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
-
+
// Determine exponent and mantissa from normalized_a.
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
- const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
+ const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1)); // 255
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
- exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
+ exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
return m;
}
@@ -110,25 +109,24 @@
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
- enum {
+ static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
- ExponentBits = int(TotalBits) - int(MantissaBits) - 1
- };
+ ExponentBits = TotalBits - MantissaBits - 1;
- const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
- const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
+ const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) + ScalarI(MantissaBits - 1))); // 278
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1)); // 127
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
- Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
+ Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
- c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
+ c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c);
return out;
}
-// Explicitly multiplies
+// Explicitly multiplies
// a * (2^e)
// clamping e to the range
// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
@@ -142,20 +140,19 @@
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
- enum {
+ static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
- ExponentBits = int(TotalBits) - int(MantissaBits) - 1
- };
-
+ ExponentBits = TotalBits - MantissaBits - 1;
+
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet run(const Packet& a, const Packet& exponent) {
- const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
- const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
+ const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1))); // 127
+ const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) - ScalarI(1))); // 255
// restrict biased exponent between 0 and 255 for float.
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
// return a * (2^e)
- return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
+ return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
}
};
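
In the frexp/ldexp helpers above, the `enum { ... }` constants become `static constexpr int`, so the bit counts keep a plain integer type and the defensive `int(...)` casts at each use site can be dropped. A standalone illustration of the layout constants involved:

    #include <climits>
    #include <limits>

    template <typename Scalar>
    struct float_layout {
      static constexpr int TotalBits    = sizeof(Scalar) * CHAR_BIT;
      static constexpr int MantissaBits = std::numeric_limits<Scalar>::digits - 1;
      static constexpr int ExponentBits = TotalBits - MantissaBits - 1;
    };

    static_assert(float_layout<float>::MantissaBits == 23, "");
    static_assert(float_layout<float>::ExponentBits == 8, "");
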
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 402b8d4..7f9e6c1 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -591,7 +591,12 @@
#elif defined(EIGEN_HAS_FP16_C)
__half_raw h;
- h.x = _cvtss_sh(ff, 0);
+ #if EIGEN_COMP_MSVC
+ // MSVC does not have scalar instructions.
+ h.x =_mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(ff), 0), 0);
+ #else
+ h.x = _cvtss_sh(ff, 0);
+ #endif
return h;
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
@@ -652,7 +657,12 @@
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __half2float(h);
#elif defined(EIGEN_HAS_FP16_C)
- return _cvtsh_ss(h.x);
+ #if EIGEN_COMP_MSVC
+ // MSVC does not have scalar instructions.
+ return _mm_cvtss_f32(_mm_cvtph_ps(_mm_set1_epi16(h.x)));
+ #else
+ return _cvtsh_ss(h.x);
+ #endif
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return static_cast<float>(h.x);
#else
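
The MSVC branch above works around the missing scalar `_cvtss_sh`/`_cvtsh_ss` intrinsics by going through the 128-bit F16C conversions. A self-contained sketch of the same roundtrip (requires F16C support, e.g. -mf16c or /arch:AVX2):

    #include <immintrin.h>
    #include <cstdint>

    // float -> IEEE-754 half bits, via lane 0 of the vector conversion.
    inline std::uint16_t float_to_half_bits(float f) {
      return static_cast<std::uint16_t>(
          _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(f), 0), 0));
    }

    // IEEE-754 half bits -> float, again via lane 0.
    inline float half_bits_to_float(std::uint16_t h) {
      return _mm_cvtss_f32(_mm_cvtph_ps(_mm_set1_epi16(static_cast<short>(h))));
    }
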
diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h
index b71cbdf..dc779a7 100644
--- a/Eigen/src/Core/arch/Default/TypeCasting.h
+++ b/Eigen/src/Core/arch/Default/TypeCasting.h
@@ -19,7 +19,6 @@
template<>
struct scalar_cast_op<float, Eigen::half> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
@@ -38,7 +37,6 @@
template<>
struct scalar_cast_op<int, Eigen::half> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::half result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
@@ -57,7 +55,6 @@
template<>
struct scalar_cast_op<Eigen::half, float> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
@@ -76,7 +73,6 @@
template<>
struct scalar_cast_op<float, Eigen::bfloat16> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
return Eigen::bfloat16(a);
@@ -90,7 +86,6 @@
template<>
struct scalar_cast_op<int, Eigen::bfloat16> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef Eigen::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
return Eigen::bfloat16(static_cast<float>(a));
@@ -104,7 +99,6 @@
template<>
struct scalar_cast_op<Eigen::bfloat16, float> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
return static_cast<float>(a);
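
With the empty-constructor macro removed from these cast functors, they fall back to their implicitly-defaulted constructors; because the functors are stateless, nothing changes for callers. A minimal illustration with a hypothetical functor (not one of Eigen's):

    #include <type_traits>

    struct cast_int_to_float {
      float operator()(int a) const { return static_cast<float>(a); }
    };

    // A stateless functor needs no user-declared constructors.
    static_assert(std::is_trivially_default_constructible<cast_int_to_float>::value, "");
    static_assert(std::is_empty<cast_int_to_float>::value, "");
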
diff --git a/Eigen/src/Core/arch/GPU/Tuple.h b/Eigen/src/Core/arch/GPU/Tuple.h
index 97c54e5..e223ca1 100644
--- a/Eigen/src/Core/arch/GPU/Tuple.h
+++ b/Eigen/src/Core/arch/GPU/Tuple.h
@@ -31,22 +31,22 @@
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
// Default constructor, enable if all types are default-constructible.
- template<typename U1 = T1, typename EnableIf = typename std::enable_if<
+ template<typename U1 = T1, typename EnableIf = std::enable_if_t<
std::is_default_constructible<U1>::value
&& reduce_all<std::is_default_constructible<Ts>::value...>::value
- >::type>
+ >>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC
TupleImpl() : head_{}, tail_{} {}
// Element constructor.
template<typename U1, typename... Us,
// Only enable if...
- typename EnableIf = typename std::enable_if<
+ typename EnableIf = std::enable_if_t<
// the number of input arguments match, and ...
sizeof...(Us) == sizeof...(Ts) && (
// this does not look like a copy/move constructor.
N > 1 || std::is_convertible<U1, T1>::value)
- >::type>
+ >>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC
TupleImpl(U1&& arg1, Us&&... args)
: head_(std::forward<U1>(arg1)), tail_(std::forward<Us>(args)...) {}
@@ -253,9 +253,9 @@
* \return concatenated tuple.
*/
template<typename... Tuples,
- typename EnableIf = typename std::enable_if<
+ typename EnableIf = std::enable_if_t<
internal::reduce_all<
- is_tuple<typename std::decay<Tuples>::type>::value...>::value>::type>
+ is_tuple<typename std::decay<Tuples>::type>::value...>::value>>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename tuple_cat_impl<sizeof...(Tuples), typename std::decay<Tuples>::type...>::ReturnType
tuple_cat(Tuples&&... tuples) {
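
The SFINAE constraints in the tuple code above switch to the C++14 `std::enable_if_t` alias, which is just shorthand for `typename std::enable_if<...>::type`; the constraints themselves are unchanged. A tiny before/after sketch:

    #include <type_traits>

    // C++11 spelling:
    template <typename T,
              typename = typename std::enable_if<std::is_integral<T>::value>::type>
    T twice_old(T x) { return x + x; }

    // Equivalent C++14 alias, as used in the updated header:
    template <typename T,
              typename = std::enable_if_t<std::is_integral<T>::value>>
    T twice_new(T x) { return x + x; }
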
diff --git a/Eigen/src/Core/arch/NEON/UnaryFunctors.h b/Eigen/src/Core/arch/NEON/UnaryFunctors.h
index 131746d..67f9dcf 100644
--- a/Eigen/src/Core/arch/NEON/UnaryFunctors.h
+++ b/Eigen/src/Core/arch/NEON/UnaryFunctors.h
@@ -20,7 +20,6 @@
*/
template <>
struct scalar_logistic_op<Eigen::half> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Eigen::half operator()(const Eigen::half& x) const {
// Convert to float and call scalar_logistic_op<float>.
diff --git a/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h
index 2b96587..54eedfa 100644
--- a/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h
+++ b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h
@@ -141,7 +141,7 @@
/* basic type for all buffers
*/
- using buffer_t = cl::sycl::buffer_mem;
+ using buffer_t = cl::sycl::buffer<buffer_data_type_t>;
/**
* Node that stores information about a device allocation.
@@ -237,17 +237,14 @@
template <typename buffer_data_type = buffer_data_type_t>
cl::sycl::buffer<buffer_data_type, 1> get_buffer(
const virtual_pointer_t ptr) {
- using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
- // get_node() returns a `buffer_mem`, so we need to cast it to a `buffer<>`.
- // We can do this without the `buffer_mem` being a pointer, as we
- // only declare member variables in the base class (`buffer_mem`) and not in
- // the child class (`buffer<>).
auto node = get_node(ptr);
+ auto& map_node = node->second;
eigen_assert(node->first == ptr || node->first < ptr);
- eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
+ eigen_assert(ptr < static_cast<virtual_pointer_t>(map_node.m_size +
node->first));
- return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
+ return map_node.m_buffer.reinterpret<buffer_data_type>(
+ cl::sycl::range<1>{map_node.m_size / sizeof(buffer_data_type)});
}
/**
@@ -429,8 +426,11 @@
template <class BufferT>
virtual_pointer_t add_pointer_impl(BufferT b) {
virtual_pointer_t retVal = nullptr;
- size_t bufSize = b.get_count();
- pMapNode_t p{b, bufSize, false};
+ size_t bufSize = b.get_count() * sizeof(buffer_data_type_t);
+ auto byte_buffer =
+ b.template reinterpret<buffer_data_type_t>(cl::sycl::range<1>{bufSize});
+ pMapNode_t p{byte_buffer, bufSize, false};
+
// If this is the first pointer:
if (m_pointerMap.empty()) {
virtual_pointer_t initialVal{m_baseAddress};
diff --git a/Eigen/src/Core/functors/AssignmentFunctors.h b/Eigen/src/Core/functors/AssignmentFunctors.h
index 6bf755f..c9d80e6 100644
--- a/Eigen/src/Core/functors/AssignmentFunctors.h
+++ b/Eigen/src/Core/functors/AssignmentFunctors.h
@@ -22,9 +22,8 @@
*/
template<typename DstScalar,typename SrcScalar> struct assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a = b; }
-
+
template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<DstScalar,Packet,Alignment>(a,b); }
@@ -47,9 +46,8 @@
*/
template<typename DstScalar,typename SrcScalar> struct add_assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(add_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a += b; }
-
+
template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<DstScalar,Packet,Alignment>(a,internal::padd(internal::ploadt<Packet,Alignment>(a),b)); }
@@ -68,9 +66,8 @@
*/
template<typename DstScalar,typename SrcScalar> struct sub_assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(sub_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a -= b; }
-
+
template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<DstScalar,Packet,Alignment>(a,internal::psub(internal::ploadt<Packet,Alignment>(a),b)); }
@@ -90,9 +87,8 @@
template<typename DstScalar, typename SrcScalar=DstScalar>
struct mul_assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(mul_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a *= b; }
-
+
template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<DstScalar,Packet,Alignment>(a,internal::pmul(internal::ploadt<Packet,Alignment>(a),b)); }
@@ -111,9 +107,8 @@
*/
template<typename DstScalar, typename SrcScalar=DstScalar> struct div_assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(div_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a /= b; }
-
+
template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<DstScalar,Packet,Alignment>(a,internal::pdiv(internal::ploadt<Packet,Alignment>(a),b)); }
@@ -143,7 +138,6 @@
*/
template<typename Scalar> struct swap_assign_op {
- EIGEN_EMPTY_STRUCT_CTOR(swap_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(Scalar& a, const Scalar& b) const
{
#ifdef EIGEN_GPUCC
diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h
index 88e2e8a..094acb4 100644
--- a/Eigen/src/Core/functors/BinaryFunctors.h
+++ b/Eigen/src/Core/functors/BinaryFunctors.h
@@ -34,9 +34,7 @@
struct scalar_sum_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_sum_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_sum_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
@@ -72,9 +70,7 @@
struct scalar_product_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_product_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_product_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
@@ -115,7 +111,6 @@
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_conj_product_op>::ReturnType result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const
{ return conj_helper<LhsScalar,RhsScalar,Conj,false>().pmul(a,b); }
@@ -140,7 +135,6 @@
struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
return internal::pmin<NaNPropagation>(a, b);
}
@@ -173,7 +167,6 @@
struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
return internal::pmax<NaNPropagation>(a,b);
}
@@ -225,7 +218,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_EQ> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a==b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -235,7 +227,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -245,7 +236,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -255,7 +245,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -265,7 +254,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -275,7 +263,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return !(a<=b || b<=a);}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -285,7 +272,6 @@
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
@@ -300,8 +286,6 @@
template<typename Scalar>
struct scalar_hypot_op<Scalar,Scalar> : binary_op_base<Scalar,Scalar>
{
- EIGEN_EMPTY_STRUCT_CTOR(scalar_hypot_op)
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar &x, const Scalar &y) const
{
// This functor is used by hypotNorm only for which it is faster to first apply abs
@@ -331,9 +315,7 @@
struct scalar_pow_op : binary_op_base<Scalar,Exponent>
{
typedef typename ScalarBinaryOpTraits<Scalar,Exponent,scalar_pow_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_pow_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_pow_op() {
typedef Scalar LhsScalar;
typedef Exponent RhsScalar;
@@ -376,9 +358,7 @@
struct scalar_difference_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_difference_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_difference_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
@@ -405,9 +385,7 @@
struct scalar_quotient_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_quotient_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_quotient_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
@@ -434,7 +412,6 @@
* \sa class CwiseBinaryOp, ArrayBase::operator&&
*/
struct scalar_boolean_and_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_and_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a && b; }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
@@ -453,7 +430,6 @@
* \sa class CwiseBinaryOp, ArrayBase::operator||
*/
struct scalar_boolean_or_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_or_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a || b; }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
@@ -472,7 +448,6 @@
* \sa class CwiseBinaryOp, ArrayBase::operator^
*/
struct scalar_boolean_xor_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_xor_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a, const bool& b) const { return a ^ b; }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
@@ -494,9 +469,7 @@
struct scalar_absolute_difference_op : binary_op_base<LhsScalar,RhsScalar>
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_absolute_difference_op>::ReturnType result_type;
-#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
- EIGEN_EMPTY_STRUCT_CTOR(scalar_absolute_difference_op)
-#else
+#ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
scalar_absolute_difference_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
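
Note on the functor hunks above: with the empty-constructor macro gone, these functors rely on their implicitly-defaulted constructors, and a user-defined constructor remains only where EIGEN_SCALAR_BINARY_OP_PLUGIN injects extra code. A minimal sketch of the resulting pattern, assuming the plugin macro is not defined; my_scalar_sum_op is a hypothetical name, not part of this patch:

    #include <type_traits>

    template<typename LhsScalar, typename RhsScalar>
    struct my_scalar_sum_op {
    #ifdef EIGEN_SCALAR_BINARY_OP_PLUGIN
      my_scalar_sum_op() { EIGEN_SCALAR_BINARY_OP_PLUGIN }   // plugin hook, as in the hunks above
    #endif
      LhsScalar operator()(const LhsScalar& a, const RhsScalar& b) const { return a + b; }
    };

    // With no user-declared constructor, the functor stays trivially default-constructible.
    static_assert(std::is_trivially_default_constructible<my_scalar_sum_op<float, float>>::value, "");
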
diff --git a/Eigen/src/Core/functors/NullaryFunctors.h b/Eigen/src/Core/functors/NullaryFunctors.h
index 0293a99..e099d4a 100644
--- a/Eigen/src/Core/functors/NullaryFunctors.h
+++ b/Eigen/src/Core/functors/NullaryFunctors.h
@@ -31,7 +31,6 @@
PacketAccess = packet_traits<Scalar>::Vectorizable, IsRepeatable = true }; };
template<typename Scalar> struct scalar_identity_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_identity_op)
template<typename IndexType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (IndexType row, IndexType col) const { return row==col ? Scalar(1) : Scalar(0); }
};
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index e7bccaf..fafb533 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -22,7 +22,6 @@
* \sa class CwiseUnaryOp, MatrixBase::operator-
*/
template<typename Scalar> struct scalar_opposite_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_opposite_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return -a; }
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const
@@ -41,7 +40,6 @@
* \sa class CwiseUnaryOp, Cwise::abs
*/
template<typename Scalar> struct scalar_abs_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_abs_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a) const { return numext::abs(a); }
template<typename Packet>
@@ -72,14 +70,12 @@
/* Avoid recomputing abs when we know the score and they are the same. Not a true Eigen functor. */
template<typename Scalar, typename=void> struct abs_knowing_score
{
- EIGEN_EMPTY_STRUCT_CTOR(abs_knowing_score)
typedef typename NumTraits<Scalar>::Real result_type;
template<typename Score>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a, const Score&) const { return numext::abs(a); }
};
template<typename Scalar> struct abs_knowing_score<Scalar, typename scalar_score_coeff_op<Scalar>::Score_is_abs>
{
- EIGEN_EMPTY_STRUCT_CTOR(abs_knowing_score)
typedef typename NumTraits<Scalar>::Real result_type;
template<typename Scal>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scal&, const result_type& a) const { return a; }
@@ -91,7 +87,6 @@
* \sa class CwiseUnaryOp, Cwise::abs2
*/
template<typename Scalar> struct scalar_abs2_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_abs2_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a) const { return numext::abs2(a); }
@@ -109,7 +104,6 @@
* \sa class CwiseUnaryOp, MatrixBase::conjugate()
*/
template<typename Scalar> struct scalar_conjugate_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_conjugate_op)
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::conj(a); }
template<typename Packet>
@@ -138,7 +132,6 @@
* \sa class CwiseUnaryOp, Cwise::arg
*/
template<typename Scalar> struct scalar_arg_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const Scalar& a) const { return numext::arg(a); }
template<typename Packet>
@@ -160,7 +153,6 @@
*/
template<typename Scalar, typename NewType>
struct scalar_cast_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef NewType result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
};
@@ -175,7 +167,6 @@
*/
template<typename Scalar, int N>
struct scalar_shift_right_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_shift_right_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const
{ return a >> N; }
@@ -194,8 +185,6 @@
*/
template<typename Scalar, int N>
struct scalar_shift_left_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_shift_left_op)
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const
{ return a << N; }
template<typename Packet>
@@ -213,7 +202,6 @@
*/
template<typename Scalar>
struct scalar_real_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_real_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { return numext::real(a); }
@@ -229,7 +217,6 @@
*/
template<typename Scalar>
struct scalar_imag_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_imag_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { return numext::imag(a); }
@@ -245,7 +232,6 @@
*/
template<typename Scalar>
struct scalar_real_ref_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_real_ref_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE result_type& operator() (const Scalar& a) const { return numext::real_ref(*const_cast<Scalar*>(&a)); }
@@ -261,7 +247,6 @@
*/
template<typename Scalar>
struct scalar_imag_ref_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_imag_ref_op)
typedef typename NumTraits<Scalar>::Real result_type;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE result_type& operator() (const Scalar& a) const { return numext::imag_ref(*const_cast<Scalar*>(&a)); }
@@ -277,7 +262,6 @@
* \sa class CwiseUnaryOp, Cwise::exp()
*/
template<typename Scalar> struct scalar_exp_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_exp_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return internal::pexp(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pexp(a); }
@@ -317,7 +301,6 @@
* \sa class CwiseUnaryOp, ArrayBase::expm1()
*/
template<typename Scalar> struct scalar_expm1_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_expm1_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::expm1(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pexpm1(a); }
@@ -337,7 +320,6 @@
* \sa class CwiseUnaryOp, ArrayBase::log()
*/
template<typename Scalar> struct scalar_log_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_log_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::log(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog(a); }
@@ -368,7 +350,6 @@
* \sa class CwiseUnaryOp, ArrayBase::log1p()
*/
template<typename Scalar> struct scalar_log1p_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_log1p_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::log1p(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog1p(a); }
@@ -388,7 +369,6 @@
* \sa class CwiseUnaryOp, Cwise::log10()
*/
template<typename Scalar> struct scalar_log10_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_log10_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { EIGEN_USING_STD(log10) return log10(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog10(a); }
@@ -404,7 +384,6 @@
* \sa class CwiseUnaryOp, Cwise::log2()
*/
template<typename Scalar> struct scalar_log2_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * numext::log(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); }
@@ -418,7 +397,6 @@
* \sa class CwiseUnaryOp, Cwise::sqrt()
*/
template<typename Scalar> struct scalar_sqrt_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::sqrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psqrt(a); }
@@ -442,7 +420,6 @@
// Boolean specialization to eliminate -Wimplicit-conversion-floating-point-to-bool warnings.
template<> struct scalar_sqrt_op<bool> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_op)
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
template <typename Packet>
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return a; }
@@ -457,7 +434,6 @@
* \sa class CwiseUnaryOp, Cwise::rsqrt()
*/
template<typename Scalar> struct scalar_rsqrt_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::rsqrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); }
@@ -476,7 +452,6 @@
* \sa class CwiseUnaryOp, ArrayBase::cos()
*/
template<typename Scalar> struct scalar_cos_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cos_op)
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return numext::cos(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pcos(a); }
@@ -495,7 +470,6 @@
* \sa class CwiseUnaryOp, ArrayBase::sin()
*/
template<typename Scalar> struct scalar_sin_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sin_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::sin(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psin(a); }
@@ -515,7 +489,6 @@
* \sa class CwiseUnaryOp, ArrayBase::tan()
*/
template<typename Scalar> struct scalar_tan_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_tan_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::tan(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::ptan(a); }
@@ -534,7 +507,6 @@
* \sa class CwiseUnaryOp, ArrayBase::acos()
*/
template<typename Scalar> struct scalar_acos_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_acos_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::acos(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pacos(a); }
@@ -553,7 +525,6 @@
* \sa class CwiseUnaryOp, ArrayBase::asin()
*/
template<typename Scalar> struct scalar_asin_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_asin_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::asin(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pasin(a); }
@@ -573,7 +544,6 @@
* \sa class CwiseUnaryOp, ArrayBase::atan()
*/
template<typename Scalar> struct scalar_atan_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_atan_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::atan(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::patan(a); }
@@ -593,7 +563,6 @@
*/
template <typename Scalar>
struct scalar_tanh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::tanh(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& x) const { return ptanh(x); }
@@ -631,7 +600,6 @@
*/
template <typename Scalar>
struct scalar_atanh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::atanh(a); }
};
@@ -646,7 +614,6 @@
* \sa class CwiseUnaryOp, ArrayBase::sinh()
*/
template<typename Scalar> struct scalar_sinh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sinh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::sinh(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psinh(a); }
@@ -667,7 +634,6 @@
*/
template <typename Scalar>
struct scalar_asinh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::asinh(a); }
};
@@ -682,7 +648,6 @@
* \sa class CwiseUnaryOp, ArrayBase::cosh()
*/
template<typename Scalar> struct scalar_cosh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cosh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::cosh(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pcosh(a); }
@@ -703,7 +668,6 @@
*/
template <typename Scalar>
struct scalar_acosh_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::acosh(a); }
};
@@ -719,7 +683,6 @@
*/
template<typename Scalar>
struct scalar_inverse_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_op)
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return Scalar(1)/a; }
template<typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
@@ -744,7 +707,6 @@
*/
template<typename Scalar>
struct scalar_square_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_square_op)
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return a*a; }
template<typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
@@ -757,7 +719,6 @@
// Boolean specialization to avoid -Wint-in-bool-context warnings on GCC.
template<>
struct scalar_square_op<bool> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_square_op)
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
template<typename Packet>
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
@@ -773,7 +734,6 @@
*/
template<typename Scalar>
struct scalar_cube_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cube_op)
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return a*a*a; }
template<typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
@@ -786,7 +746,6 @@
// Boolean specialization to avoid -Wint-in-bool-context warnings on GCC.
template<>
struct scalar_cube_op<bool> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_cube_op)
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline bool operator() (const bool& a) const { return a; }
template<typename Packet>
EIGEN_DEPRECATED EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
@@ -801,7 +760,6 @@
* \sa class CwiseUnaryOp, ArrayBase::round()
*/
template<typename Scalar> struct scalar_round_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_round_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::round(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pround(a); }
@@ -820,7 +778,6 @@
* \sa class CwiseUnaryOp, ArrayBase::floor()
*/
template<typename Scalar> struct scalar_floor_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_floor_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::floor(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pfloor(a); }
@@ -839,7 +796,6 @@
* \sa class CwiseUnaryOp, ArrayBase::rint()
*/
template<typename Scalar> struct scalar_rint_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::rint(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::print(a); }
@@ -858,7 +814,6 @@
* \sa class CwiseUnaryOp, ArrayBase::ceil()
*/
template<typename Scalar> struct scalar_ceil_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_ceil_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::ceil(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::pceil(a); }
@@ -877,7 +832,6 @@
* \sa class CwiseUnaryOp, ArrayBase::isnan()
*/
template<typename Scalar> struct scalar_isnan_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_isnan_op)
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
@@ -901,7 +855,6 @@
* \sa class CwiseUnaryOp, ArrayBase::isinf()
*/
template<typename Scalar> struct scalar_isinf_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_isinf_op)
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
@@ -925,7 +878,6 @@
* \sa class CwiseUnaryOp, ArrayBase::isfinite()
*/
template<typename Scalar> struct scalar_isfinite_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_isfinite_op)
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
@@ -950,7 +902,6 @@
* \sa class CwiseUnaryOp, ArrayBase::operator!
*/
template<typename Scalar> struct scalar_boolean_not_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_boolean_not_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const bool& a) const { return !a; }
};
template<typename Scalar>
@@ -968,7 +919,6 @@
template<typename Scalar,bool is_complex=(NumTraits<Scalar>::IsComplex!=0), bool is_integer=(NumTraits<Scalar>::IsInteger!=0) > struct scalar_sign_op;
template<typename Scalar>
struct scalar_sign_op<Scalar, false, true> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
@@ -980,7 +930,6 @@
template<typename Scalar>
struct scalar_sign_op<Scalar, false, false> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
return (numext::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
@@ -992,7 +941,6 @@
template<typename Scalar, bool is_integer>
struct scalar_sign_op<Scalar,true, is_integer> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
typedef typename NumTraits<Scalar>::Real real_type;
@@ -1023,7 +971,6 @@
*/
template <typename T>
struct scalar_logistic_op {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
return packetOp(x);
}
@@ -1059,7 +1006,6 @@
*/
template <>
struct scalar_logistic_op<float> {
- EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
// Truncate at the first point where the interpolant is exactly one.
const float cst_exp_hi = 16.6355324f;
diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
index 076b95f..b1a1277 100644
--- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h
+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h
@@ -356,7 +356,7 @@
private:
static const int remaining_registers = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS - registers_taken;
public:
- typedef typename conditional<remaining_registers>=4, RhsPacketx4, RhsPacket>::type type;
+ typedef std::conditional_t<remaining_registers>=4, RhsPacketx4, RhsPacket> type;
};
template <typename Packet>
@@ -459,9 +459,9 @@
};
- typedef typename conditional<Vectorizable,LhsPacket_,LhsScalar>::type LhsPacket;
- typedef typename conditional<Vectorizable,RhsPacket_,RhsScalar>::type RhsPacket;
- typedef typename conditional<Vectorizable,ResPacket_,ResScalar>::type ResPacket;
+ typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
+ typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
+ typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
typedef LhsPacket LhsPacket4Packing;
typedef QuadPacket<RhsPacket> RhsPacketx4;
@@ -578,9 +578,9 @@
RhsProgress = 1
};
- typedef typename conditional<Vectorizable,LhsPacket_,LhsScalar>::type LhsPacket;
- typedef typename conditional<Vectorizable,RhsPacket_,RhsScalar>::type RhsPacket;
- typedef typename conditional<Vectorizable,ResPacket_,ResScalar>::type ResPacket;
+ typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
+ typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
+ typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
typedef LhsPacket LhsPacket4Packing;
typedef QuadPacket<RhsPacket> RhsPacketx4;
@@ -614,7 +614,7 @@
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
- loadRhsQuad_impl(b,dest, typename conditional<RhsPacketSize==16,true_type,false_type>::type());
+ loadRhsQuad_impl(b,dest, std::conditional_t<RhsPacketSize==16,true_type,false_type>());
}
EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const true_type&) const
@@ -645,7 +645,7 @@
template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const
{
- madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
+ madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable,true_type,false_type>());
}
template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
@@ -703,7 +703,7 @@
template<typename Packet>
const DoublePacket<Packet>&
predux_half_dowto4(const DoublePacket<Packet> &a,
- typename enable_if<unpacket_traits<Packet>::size<=8>::type* = 0)
+ std::enable_if_t<unpacket_traits<Packet>::size<=8>* = 0)
{
return a;
}
@@ -711,7 +711,7 @@
template<typename Packet>
DoublePacket<typename unpacket_traits<Packet>::half>
predux_half_dowto4(const DoublePacket<Packet> &a,
- typename enable_if<unpacket_traits<Packet>::size==16>::type* = 0)
+ std::enable_if_t<unpacket_traits<Packet>::size==16>* = 0)
{
// yes, that's pretty hackish :(
DoublePacket<typename unpacket_traits<Packet>::half> res;
@@ -725,7 +725,7 @@
// same here, "quad" actually means "8" in terms of real coefficients
template<typename Scalar, typename RealPacket>
void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
- typename enable_if<unpacket_traits<RealPacket>::size<=8>::type* = 0)
+ std::enable_if_t<unpacket_traits<RealPacket>::size<=8>* = 0)
{
dest.first = pset1<RealPacket>(numext::real(*b));
dest.second = pset1<RealPacket>(numext::imag(*b));
@@ -733,7 +733,7 @@
template<typename Scalar, typename RealPacket>
void loadQuadToDoublePacket(const Scalar* b, DoublePacket<RealPacket>& dest,
- typename enable_if<unpacket_traits<RealPacket>::size==16>::type* = 0)
+ std::enable_if_t<unpacket_traits<RealPacket>::size==16>* = 0)
{
// yes, that's pretty hackish too :(
typedef typename NumTraits<Scalar>::Real RealScalar;
@@ -791,11 +791,11 @@
typedef DoublePacket<RealPacket> DoublePacketType;
- typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type LhsPacket4Packing;
- typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
- typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type RhsPacket;
- typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
- typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type AccPacket;
+ typedef std::conditional_t<Vectorizable,ScalarPacket,Scalar> LhsPacket4Packing;
+ typedef std::conditional_t<Vectorizable,RealPacket, Scalar> LhsPacket;
+ typedef std::conditional_t<Vectorizable,DoublePacketType,Scalar> RhsPacket;
+ typedef std::conditional_t<Vectorizable,ScalarPacket,Scalar> ResPacket;
+ typedef std::conditional_t<Vectorizable,DoublePacketType,Scalar> AccPacket;
// this actually holds 8 packets!
typedef QuadPacket<RhsPacket> RhsPacketx4;
@@ -868,7 +868,7 @@
template<typename LhsPacketType, typename RhsPacketType, typename ResPacketType, typename TmpType, typename LaneIdType>
EIGEN_STRONG_INLINE
- typename enable_if<!is_same<RhsPacketType,RhsPacketx4>::value>::type
+ std::enable_if_t<!is_same<RhsPacketType,RhsPacketx4>::value>
madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket<ResPacketType>& c, TmpType& /*tmp*/, const LaneIdType&) const
{
c.first = padd(pmul(a,b.first), c.first);
@@ -960,9 +960,9 @@
RhsProgress = 1
};
- typedef typename conditional<Vectorizable,LhsPacket_,LhsScalar>::type LhsPacket;
- typedef typename conditional<Vectorizable,RhsPacket_,RhsScalar>::type RhsPacket;
- typedef typename conditional<Vectorizable,ResPacket_,ResScalar>::type ResPacket;
+ typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
+ typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
+ typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
typedef LhsPacket LhsPacket4Packing;
typedef QuadPacket<RhsPacket> RhsPacketx4;
typedef ResPacket AccPacket;
@@ -1011,7 +1011,7 @@
template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType, typename LaneIdType>
EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const LaneIdType&) const
{
- madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
+ madd_impl(a, b, c, tmp, std::conditional_t<Vectorizable,true_type,false_type>());
}
template <typename LhsPacketType, typename RhsPacketType, typename AccPacketType>
@@ -1976,10 +1976,10 @@
if(SwappedTraits::LhsProgress==8)
{
// Special case where we have to first reduce the accumulation register C0
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SRhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
- typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
+ typedef std::conditional_t<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket> SResPacketHalf;
+ typedef std::conditional_t<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket> SLhsPacketHalf;
+ typedef std::conditional_t<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SRhsPacket>::half,SRhsPacket> SRhsPacketHalf;
+ typedef std::conditional_t<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket> SAccPacketHalf;
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
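
A quick illustration of the type-selection change running through the hunks above: std::conditional_t is the C++14 alias for the old internal::conditional<...>::type spelling, so the dependent typename disappears while the selected type stays identical. Sketch only; packet_or_scalar_t is a hypothetical helper, not something this patch adds:

    #include <type_traits>

    template<bool Vectorizable, typename Packet, typename Scalar>
    using packet_or_scalar_t = std::conditional_t<Vectorizable, Packet, Scalar>;

    static_assert(std::is_same<packet_or_scalar_t<true,  double, float>, double>::value, "");
    static_assert(std::is_same<packet_or_scalar_t<false, double, float>, float >::value, "");
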
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index df64232..5262428 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -277,16 +277,16 @@
template<int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, true /* == FiniteAtCompileTime */>
: public level3_blocking<
- typename conditional<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>::type,
- typename conditional<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>::type>
+ std::conditional_t<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>,
+ std::conditional_t<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>>
{
enum {
Transpose = StorageOrder==RowMajor,
ActualRows = Transpose ? MaxCols : MaxRows,
ActualCols = Transpose ? MaxRows : MaxCols
};
- typedef typename conditional<Transpose,RhsScalar_,LhsScalar_>::type LhsScalar;
- typedef typename conditional<Transpose,LhsScalar_,RhsScalar_>::type RhsScalar;
+ typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
+ typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
enum {
SizeA = ActualRows * MaxDepth,
@@ -328,14 +328,14 @@
template<int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, false>
: public level3_blocking<
- typename conditional<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>::type,
- typename conditional<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>::type>
+ std::conditional_t<StorageOrder==RowMajor,RhsScalar_,LhsScalar_>,
+ std::conditional_t<StorageOrder==RowMajor,LhsScalar_,RhsScalar_>>
{
enum {
Transpose = StorageOrder==RowMajor
};
- typedef typename conditional<Transpose,RhsScalar_,LhsScalar_>::type LhsScalar;
- typedef typename conditional<Transpose,LhsScalar_,RhsScalar_>::type RhsScalar;
+ typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
+ typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
Index m_sizeA;
@@ -415,11 +415,11 @@
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+ typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
enum {
MaxDepthAtCompileTime = min_size_prefer_fixed(Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime)
@@ -485,8 +485,8 @@
::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);
}
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+ add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
+ add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
index 9728f30..716f2ca 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
@@ -210,17 +210,17 @@
{
typedef typename MatrixType::Scalar Scalar;
- typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
+ typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
- typedef typename internal::remove_all<ActualLhs>::type ActualLhs_;
- typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
+ typedef internal::remove_all_t<ActualLhs> ActualLhs_;
+ internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
- typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
+ typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
- typedef typename internal::remove_all<ActualRhs>::type ActualRhs_;
- typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
+ typedef internal::remove_all_t<ActualRhs> ActualRhs_;
+ internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
@@ -256,17 +256,17 @@
{
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta)
{
- typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
+ typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
- typedef typename internal::remove_all<ActualLhs>::type ActualLhs_;
- typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
+ typedef internal::remove_all_t<ActualLhs> ActualLhs_;
+ internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
- typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
+ typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
- typedef typename internal::remove_all<ActualRhs>::type ActualRhs_;
- typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
+ typedef internal::remove_all_t<ActualRhs> ActualRhs_;
+ internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h
index e7dd61d..7307994 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -58,9 +58,9 @@
ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
};
- typedef typename conditional<Vectorizable,LhsPacket_,LhsScalar>::type LhsPacket;
- typedef typename conditional<Vectorizable,RhsPacket_,RhsScalar>::type RhsPacket;
- typedef typename conditional<Vectorizable,ResPacket_,ResScalar>::type ResPacket;
+ typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
+ typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
+ typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
};
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
index f6fdbca..c7bb445 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h
@@ -511,8 +511,8 @@
{
eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+ add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
+ add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
* RhsBlasTraits::extractScalarFactor(a_rhs);
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h
index 086638e..a62b6b5 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h
@@ -169,11 +169,11 @@
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+ typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
enum { LhsUpLo = LhsMode&(Upper|Lower) };
@@ -187,8 +187,8 @@
eigen_assert(dest.rows()==a_lhs.rows() && dest.cols()==a_rhs.cols());
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+ add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
+ add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
* RhsBlasTraits::extractScalarFactor(a_rhs);
diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h
index 026bc19..4cbc1f7 100644
--- a/Eigen/src/Core/products/SelfadjointProduct.h
+++ b/Eigen/src/Core/products/SelfadjointProduct.h
@@ -28,7 +28,7 @@
{
internal::conj_if<ConjRhs> cj;
typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
- typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType;
+ typedef std::conditional_t<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&> ConjLhsType;
for (Index i=0; i<size; ++i)
{
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
@@ -57,8 +57,8 @@
typedef typename MatrixType::Scalar Scalar;
typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
- typedef typename internal::remove_all<ActualOtherType>::type ActualOtherType_;
- typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived());
+ typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
+ internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
@@ -89,8 +89,8 @@
typedef typename MatrixType::Scalar Scalar;
typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
- typedef typename internal::remove_all<ActualOtherType>::type ActualOtherType_;
- typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived());
+ typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
+ internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
diff --git a/Eigen/src/Core/products/SelfadjointRank2Update.h b/Eigen/src/Core/products/SelfadjointRank2Update.h
index 44c3381..fb199ad 100644
--- a/Eigen/src/Core/products/SelfadjointRank2Update.h
+++ b/Eigen/src/Core/products/SelfadjointRank2Update.h
@@ -52,9 +52,8 @@
}
};
-template<bool Cond, typename T> struct conj_expr_if
- : conditional<!Cond, const T&,
- CwiseUnaryOp<scalar_conjugate_op<typename traits<T>::Scalar>,T> > {};
+template<bool Cond, typename T>
+using conj_expr_if = std::conditional<!Cond, const T&, CwiseUnaryOp<scalar_conjugate_op<typename traits<T>::Scalar>,T>>;
} // end namespace internal
@@ -65,13 +64,13 @@
{
typedef internal::blas_traits<DerivedU> UBlasTraits;
typedef typename UBlasTraits::DirectLinearAccessType ActualUType;
- typedef typename internal::remove_all<ActualUType>::type ActualUType_;
- typename internal::add_const_on_value_type<ActualUType>::type actualU = UBlasTraits::extract(u.derived());
+ typedef internal::remove_all_t<ActualUType> ActualUType_;
+ internal::add_const_on_value_type_t<ActualUType> actualU = UBlasTraits::extract(u.derived());
typedef internal::blas_traits<DerivedV> VBlasTraits;
typedef typename VBlasTraits::DirectLinearAccessType ActualVType;
- typedef typename internal::remove_all<ActualVType>::type ActualVType_;
- typename internal::add_const_on_value_type<ActualVType>::type actualV = VBlasTraits::extract(v.derived());
+ typedef internal::remove_all_t<ActualVType> ActualVType_;
+ internal::add_const_on_value_type_t<ActualVType> actualV = VBlasTraits::extract(v.derived());
// If MatrixType is row major, then we use the routine for lower triangular in the upper triangular case and
// vice versa, and take the complex conjugate of all coefficients and vector entries.
@@ -82,8 +81,8 @@
if (IsRowMajor)
actualAlpha = numext::conj(actualAlpha);
- typedef typename internal::remove_all<typename internal::conj_expr_if<int(IsRowMajor) ^ int(UBlasTraits::NeedToConjugate), ActualUType_>::type>::type UType;
- typedef typename internal::remove_all<typename internal::conj_expr_if<int(IsRowMajor) ^ int(VBlasTraits::NeedToConjugate), ActualVType_>::type>::type VType;
+ typedef internal::remove_all_t<typename internal::conj_expr_if<int(IsRowMajor) ^ int(UBlasTraits::NeedToConjugate), ActualUType_>::type> UType;
+ typedef internal::remove_all_t<typename internal::conj_expr_if<int(IsRowMajor) ^ int(VBlasTraits::NeedToConjugate), ActualVType_>::type> VType;
internal::selfadjoint_rank2_update_selector<Scalar, Index, UType, VType,
(IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)>
::run(_expression().const_cast_derived().data(),_expression().outerStride(),UType(actualU),VType(actualV),actualAlpha);
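
The conj_expr_if rewrite above deliberately keeps std::conditional rather than std::conditional_t: the call sites in the same hunk still spell conj_expr_if<...>::type, so only the struct-inheritance boilerplate goes away. A self-contained sketch with hypothetical names (conj_view stands in for the conjugating CwiseUnaryOp expression):

    #include <type_traits>

    template<typename T> struct conj_view {};   // stand-in only

    template<bool Cond, typename T>
    using conj_expr_if_sketch = std::conditional<!Cond, const T&, conj_view<T>>;

    static_assert(std::is_same<conj_expr_if_sketch<false, int>::type, const int&   >::value, "");
    static_assert(std::is_same<conj_expr_if_sketch<true,  int>::type, conj_view<int>>::value, "");
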
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h
index 26bee63..770107a 100644
--- a/Eigen/src/Core/products/TriangularMatrixMatrix.h
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h
@@ -414,13 +414,13 @@
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
+ typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
-
- typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
+ typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
+
+ internal::add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
+ internal::add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs);
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs);
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index 72b3c13..df15e81 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -219,8 +219,8 @@
typedef Map<Matrix<ResScalar,Dynamic,1>, plain_enum_min(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
- typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
- typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
+ add_const_on_value_type_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
+ add_const_on_value_type_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
@@ -298,10 +298,10 @@
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+ typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
- typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
- typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
+ std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
+ std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index 520cfc9..def6a28 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Modifications Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -12,10 +13,118 @@
#include "../InternalHeaderCheck.h"
-namespace Eigen {
+namespace Eigen {
namespace internal {
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+struct trsm_kernels {
+ // Generic Implementation of triangular solve for triangular matrix on left and multiple rhs.
+ // Handles non-packed matrices.
+ static void trsmKernelL(
+ Index size, Index otherSize,
+ const Scalar* _tri, Index triStride,
+ Scalar* _other, Index otherIncr, Index otherStride);
+
+ // Generic Implementation of triangular solve for triangular matrix on right and multiple lhs.
+ // Handles non-packed matrices.
+ static void trsmKernelR(
+ Index size, Index otherSize,
+ const Scalar* _tri, Index triStride,
+ Scalar* _other, Index otherIncr, Index otherStride);
+};
+
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
+EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
+ Index size, Index otherSize,
+ const Scalar* _tri, Index triStride,
+ Scalar* _other, Index otherIncr, Index otherStride)
+ {
+ typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
+ typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
+ TriMapper tri(_tri, triStride);
+ OtherMapper other(_other, otherStride, otherIncr);
+
+ enum { IsLower = (Mode&Lower) == Lower };
+ conj_if<Conjugate> conj;
+
+ // tr solve
+ for (Index k=0; k<size; ++k)
+ {
+ // TODO write a small kernel handling this (can be shared with trsv)
+ Index i = IsLower ? k : -k-1;
+ Index rs = size - k - 1; // remaining size
+ Index s = TriStorageOrder==RowMajor ? (IsLower ? 0 : i+1)
+ : IsLower ? i+1 : i-rs;
+
+ Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
+ for (Index j=0; j<otherSize; ++j)
+ {
+ if (TriStorageOrder==RowMajor)
+ {
+ Scalar b(0);
+ const Scalar* l = &tri(i,s);
+ typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
+ for (Index i3=0; i3<k; ++i3)
+ b += conj(l[i3]) * r(i3);
+
+ other(i,j) = (other(i,j) - b)*a;
+ }
+ else
+ {
+ Scalar& otherij = other(i,j);
+ otherij *= a;
+ Scalar b = otherij;
+ typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
+ typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
+ for (Index i3=0;i3<rs;++i3)
+ r(i3) -= b * conj(l(i3));
+ }
+ }
+ }
+ }
+
+
+template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
+EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelR(
+ Index size, Index otherSize,
+ const Scalar* _tri, Index triStride,
+ Scalar* _other, Index otherIncr, Index otherStride)
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
+ typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
+ LhsMapper lhs(_other, otherStride, otherIncr);
+ RhsMapper rhs(_tri, triStride);
+
+ enum {
+ RhsStorageOrder = TriStorageOrder,
+ IsLower = (Mode&Lower) == Lower
+ };
+ conj_if<Conjugate> conj;
+
+ for (Index k=0; k<size; ++k)
+ {
+ Index j = IsLower ? size-k-1 : k;
+
+ typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0,j);
+ for (Index k3=0; k3<k; ++k3)
+ {
+ Scalar b = conj(rhs(IsLower ? j+1+k3 : k3,j));
+ typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0,IsLower ? j+1+k3 : k3);
+ for (Index i=0; i<otherSize; ++i)
+ r(i) -= a(i) * b;
+ }
+ if((Mode & UnitDiag)==0)
+ {
+ Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
+ for (Index i=0; i<otherSize; ++i)
+ r(i) *= inv_rjj;
+ }
+ }
+}
+
+
// if the rhs is row major, let's transpose the product
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
@@ -46,6 +155,7 @@
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
+
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -55,6 +165,25 @@
{
Index cols = otherSize;
+ std::ptrdiff_t l1, l2, l3;
+ manage_caching_sizes(GetAction, &l1, &l2, &l3);
+
+#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
+ EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
+ (std::is_same<Scalar,float>::value ||
+ std::is_same<Scalar,double>::value)) ) {
+ // Very rough cutoffs to determine when to call trsm w/o packing
+ // For small problem sizes trsmKernel compiled with clang is generally faster.
+ // TODO: Investigate better heuristics for cutoffs.
+ double L2Cap = 0.5; // 50% of L2 size
+ if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
+ trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::trsmKernelL(
+ size, cols, _tri, triStride, _other, 1, otherStride);
+ return;
+ }
+ }
+#endif
+
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
@@ -76,15 +205,12 @@
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
- conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
// the goal here is to subdivise the Rhs panels such that we keep some cache
// coherence when accessing the rhs elements
- std::ptrdiff_t l1, l2, l3;
- manage_caching_sizes(GetAction, &l1, &l2, &l3);
Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0;
subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
@@ -115,38 +241,19 @@
{
Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
// tr solve
- for (Index k=0; k<actualPanelWidth; ++k)
{
- // TODO write a small kernel handling this (can be shared with trsv)
- Index i = IsLower ? k2+k1+k : k2-k1-k-1;
- Index rs = actualPanelWidth - k - 1; // remaining size
- Index s = TriStorageOrder==RowMajor ? (IsLower ? k2+k1 : i+1)
- : IsLower ? i+1 : i-rs;
-
- Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
- for (Index j=j2; j<j2+actual_cols; ++j)
- {
- if (TriStorageOrder==RowMajor)
- {
- Scalar b(0);
- const Scalar* l = &tri(i,s);
- typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
- for (Index i3=0; i3<k; ++i3)
- b += conj(l[i3]) * r(i3);
-
- other(i,j) = (other(i,j) - b)*a;
- }
- else
- {
- Scalar& otherij = other(i,j);
- otherij *= a;
- Scalar b = otherij;
- typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
- typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
- for (Index i3=0;i3<rs;++i3)
- r(i3) -= b * conj(l(i3));
- }
+ Index i = IsLower ? k2+k1 : k2-k1;
+#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+ EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
+ (std::is_same<Scalar,float>::value ||
+ std::is_same<Scalar,double>::value)) ) {
+ i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
}
+#endif
+ trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
+ actualPanelWidth, actual_cols,
+ _tri + i + (i)*triStride, triStride,
+ _other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
}
Index lengthTarget = actual_kc-k1-actualPanelWidth;
@@ -168,7 +275,7 @@
}
}
}
-
+
// R2 -= A21 * B => GEPP
{
Index start = IsLower ? k2+kc : 0;
@@ -198,6 +305,7 @@
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
+
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -206,7 +314,22 @@
level3_blocking<Scalar,Scalar>& blocking)
{
Index rows = otherSize;
- typedef typename NumTraits<Scalar>::Real RealScalar;
+
+#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
+ EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
+ (std::is_same<Scalar,float>::value ||
+ std::is_same<Scalar,double>::value)) ) {
+ // TODO: Investigate better heuristics for cutoffs.
+ std::ptrdiff_t l1, l2, l3;
+ manage_caching_sizes(GetAction, &l1, &l2, &l3);
+ double L2Cap = 0.5; // 50% of L2 size
+ if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
+ trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
+ trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride);
+ return;
+ }
+ }
+#endif
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
@@ -229,7 +352,6 @@
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
- conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
@@ -296,27 +418,13 @@
panelOffset, panelOffset); // offsets
}
- // unblocked triangular solve
- for (Index k=0; k<actualPanelWidth; ++k)
{
- Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
-
- typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
- for (Index k3=0; k3<k; ++k3)
- {
- Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
- typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
- for (Index i=0; i<actual_mc; ++i)
- r(i) -= a(i) * b;
- }
- if((Mode & UnitDiag)==0)
- {
- Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
- for (Index i=0; i<actual_mc; ++i)
- r(i) *= inv_rjj;
- }
+ // unblocked triangular solve
+ trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
+ trsmKernelR(actualPanelWidth, actual_mc,
+ _tri + absolute_j2 + absolute_j2*triStride, triStride,
+ _other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
}
-
// pack the just computed part of lhs to A
pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc,
@@ -331,7 +439,6 @@
}
}
}
-
} // end namespace internal
} // end namespace Eigen
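
Note on the cutoff checks added above: avx512_trsm_cutoff is defined in the new AVX512 trsm kernel
headers, which are not part of this hunk, so the helper below is only a hypothetical sketch of that
kind of cache-capacity heuristic (its name and formula are assumptions, not Eigen's implementation).
The intent is to take the no-copy kernel only while the triangular panel plus the right-hand side
panel fits in a fraction (L2Cap) of the L2 cache:

    #include <cstddef>

    // Hypothetical illustration only -- NOT Eigen's avx512_trsm_cutoff.
    // Largest `size` such that size*size/2 (triangle) + size*cols (RHS) scalars
    // fit within l2Cap of the L2 cache.
    template <typename Scalar>
    std::ptrdiff_t hypothetical_trsm_cutoff(std::ptrdiff_t l2, std::ptrdiff_t cols, double l2Cap) {
      const std::ptrdiff_t budget =
          static_cast<std::ptrdiff_t>(l2Cap * static_cast<double>(l2)) / std::ptrdiff_t(sizeof(Scalar));
      std::ptrdiff_t size = 1;
      while (size * size / 2 + size * cols <= budget) ++size;
      return size - 1;
    }
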
diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h
index c1c9e4c..57ade28 100644
--- a/Eigen/src/Core/products/TriangularSolverVector.h
+++ b/Eigen/src/Core/products/TriangularSolverVector.h
@@ -43,11 +43,10 @@
typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
- typename internal::conditional<
- Conjugate,
- const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
- const LhsMap&>
- ::type cjLhs(lhs);
+ std::conditional_t<
+ Conjugate,
+ const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
+ const LhsMap&> cjLhs(lhs);
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
for(Index pi=IsLower ? 0 : size;
IsLower ? pi<size : pi>0;
@@ -99,10 +98,10 @@
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
- typename internal::conditional<Conjugate,
- const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
- const LhsMap&
- >::type cjLhs(lhs);
+ std::conditional_t<Conjugate,
+ const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
+ const LhsMap&
+ > cjLhs(lhs);
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
for(Index pi=IsLower ? 0 : size;
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 8b35bcc..f45665e 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -213,6 +213,11 @@
return ploadt<PacketT, AlignmentT>(&operator()(i, j));
}
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType &p) const {
+ pstoret<Scalar, PacketType, AlignmentType>(&operator()(i, j), p);
+ }
+
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
@@ -311,6 +316,11 @@
return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value());
}
+ template<typename PacketType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType &p) const {
+ pscatter<Scalar, PacketType>(&operator()(i, j), p, m_incr.value());
+ }
+
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
@@ -428,10 +438,10 @@
) ? 1 : 0,
HasScalarFactor = false
};
- typedef typename conditional<bool(HasUsableDirectAccess),
+ typedef std::conditional_t<bool(HasUsableDirectAccess),
ExtractType,
typename ExtractType_::PlainObject
- >::type DirectLinearAccessType;
+ > DirectLinearAccessType;
static inline EIGEN_DEVICE_FUNC ExtractType extract(const XprType& x) { return x; }
static inline EIGEN_DEVICE_FUNC const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
};
@@ -514,10 +524,10 @@
typedef Transpose<NestedXpr> XprType;
typedef Transpose<const typename Base::ExtractType_> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
typedef Transpose<const typename Base::ExtractType_> ExtractType_;
- typedef typename conditional<bool(Base::HasUsableDirectAccess),
+ typedef std::conditional_t<bool(Base::HasUsableDirectAccess),
ExtractType,
typename ExtractType::PlainObject
- >::type DirectLinearAccessType;
+ > DirectLinearAccessType;
enum {
IsTransposed = Base::IsTransposed ? 0 : 1
};
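
The two storePacket overloads added above are the store-side counterparts of the existing
loadPacket/gatherPacket accessors: the plain blas_data_mapper writes the packet contiguously
(pstoret), while the inner-stride variant scatters it element by element (pscatter). A minimal
non-Eigen sketch of that distinction, with a plain array standing in for a SIMD packet:

    #include <cstddef>

    // Illustration only: contiguous store vs. strided scatter of a "packet".
    template <typename Scalar, int PacketSize>
    void store_contiguous(Scalar* dst, const Scalar (&p)[PacketSize]) {
      for (int k = 0; k < PacketSize; ++k) dst[k] = p[k];         // pstoret-like
    }
    template <typename Scalar, int PacketSize>
    void store_strided(Scalar* dst, const Scalar (&p)[PacketSize], std::ptrdiff_t incr) {
      for (int k = 0; k < PacketSize; ++k) dst[k * incr] = p[k];  // pscatter-like
    }
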
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index aa4b642..a696961 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -106,7 +106,7 @@
template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = Stride<0,0> > class Map;
template<typename Derived> class RefBase;
template<typename PlainObjectType, int Options = 0,
- typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
+ typename StrideType = typename std::conditional_t<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> > > class Ref;
template<typename ViewOp, typename MatrixType, typename StrideType = Stride<0,0>> class CwiseUnaryView;
template<typename Derived> class TriangularBase;
diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h
index 3e5fc09..19fa45d 100644
--- a/Eigen/src/Core/util/IndexedViewHelper.h
+++ b/Eigen/src/Core/util/IndexedViewHelper.h
@@ -99,7 +99,7 @@
// Turn a single index into something that looks like an array (i.e., that exposes a .size(), and operator[](int) methods)
template<typename T, int XprSize>
-struct IndexedViewCompatibleType<T,XprSize,typename internal::enable_if<internal::is_integral<T>::value>::type> {
+struct IndexedViewCompatibleType<T,XprSize,std::enable_if_t<internal::is_integral<T>::value>> {
// Here we could simply use Array, but maybe it's less work for the compiler to use
// a simpler wrapper as SingleRange
//typedef Eigen::Array<Index,1,1> type;
@@ -107,13 +107,13 @@
};
template<typename T, int XprSize>
-struct IndexedViewCompatibleType<T, XprSize, typename enable_if<symbolic::is_symbolic<T>::value>::type> {
+struct IndexedViewCompatibleType<T, XprSize, std::enable_if_t<symbolic::is_symbolic<T>::value>> {
typedef SingleRange type;
};
template<typename T>
-typename enable_if<symbolic::is_symbolic<T>::value,SingleRange>::type
+std::enable_if_t<symbolic::is_symbolic<T>::value,SingleRange>
makeIndexedViewCompatible(const T& id, Index size, SpecializedType) {
return eval_expr_given_size(id,size);
}
diff --git a/Eigen/src/Core/util/IntegralConstant.h b/Eigen/src/Core/util/IntegralConstant.h
index da9cccb..ea275bd 100644
--- a/Eigen/src/Core/util/IntegralConstant.h
+++ b/Eigen/src/Core/util/IntegralConstant.h
@@ -169,7 +169,7 @@
template<typename T, int DynamicKey=Dynamic, typename EnableIf=void> struct cleanup_index_type { typedef T type; };
// Convert any integral type (e.g., short, int, unsigned int, etc.) to Eigen::Index
-template<typename T, int DynamicKey> struct cleanup_index_type<T,DynamicKey,typename internal::enable_if<internal::is_integral<T>::value>::type> { typedef Index type; };
+template<typename T, int DynamicKey> struct cleanup_index_type<T,DynamicKey,std::enable_if_t<internal::is_integral<T>::value>> { typedef Index type; };
// If VariableAndFixedInt does not match DynamicKey, then we turn it to a pure compile-time value:
template<int N, int DynamicKey> struct cleanup_index_type<VariableAndFixedInt<N>, DynamicKey> { typedef FixedInt<N> type; };
diff --git a/Eigen/src/Core/util/Macros.h b/Eigen/src/Core/util/Macros.h
index 281ffb7..cb583ee 100644
--- a/Eigen/src/Core/util/Macros.h
+++ b/Eigen/src/Core/util/Macros.h
@@ -876,11 +876,6 @@
#define eigen_plain_assert(x)
#endif
#else
- namespace Eigen {
- namespace internal {
- inline bool copy_bool(bool b) { return b; }
- }
- }
#define eigen_plain_assert(x) assert(x)
#endif
@@ -922,7 +917,7 @@
// Suppresses 'unused variable' warnings.
namespace Eigen {
namespace internal {
- template<typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ignore_unused_variable(const T&) {}
+ template<typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE constexpr void ignore_unused_variable(const T&) {}
}
}
#define EIGEN_UNUSED_VARIABLE(var) Eigen::internal::ignore_unused_variable(var);
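
The constexpr added to ignore_unused_variable above is what allows EIGEN_UNUSED_VARIABLE to be used
inside the constexpr helpers introduced later in this patch (see compute_default_alignment_helper in
XprHelper.h, which calls it in the static-alignment-disabled branch). A minimal standalone sketch of
why the qualifier matters, assuming a C++14 compiler; the names below are illustrative only:

    // If ignore() were not constexpr, pick_alignment(64) would not be a constant
    // expression and the static_assert below would fail to compile.
    template <typename T>
    constexpr void ignore(const T&) {}

    constexpr int pick_alignment(int bytes) {
      ignore(bytes);  // suppress the unused-parameter warning without leaving constant evaluation
      return 0;
    }
    static_assert(pick_alignment(64) == 0, "usable at compile time");
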
diff --git a/Eigen/src/Core/util/Memory.h b/Eigen/src/Core/util/Memory.h
index 8e724fc..7657ead 100644
--- a/Eigen/src/Core/util/Memory.h
+++ b/Eigen/src/Core/util/Memory.h
@@ -634,7 +634,7 @@
>
struct local_nested_eval_wrapper
{
- static const bool NeedExternalBuffer = false;
+ static constexpr bool NeedExternalBuffer = false;
typedef typename Xpr::Scalar Scalar;
typedef typename nested_eval<Xpr,NbEvaluations>::type ObjectType;
ObjectType object;
@@ -650,7 +650,7 @@
template<typename Xpr, int NbEvaluations>
struct local_nested_eval_wrapper<Xpr,NbEvaluations,true>
{
- static const bool NeedExternalBuffer = true;
+ static constexpr bool NeedExternalBuffer = true;
typedef typename Xpr::Scalar Scalar;
typedef typename plain_object_eval<Xpr>::type PlainObject;
typedef Map<PlainObject,EIGEN_DEFAULT_ALIGN_BYTES> ObjectType;
@@ -1150,6 +1150,38 @@
return (std::max)(l2,l3);
}
+
+
+/** \internal
+ * This wraps C++20's std::construct_at, using placement new instead if it is not available.
+ */
+
+#if EIGEN_COMP_CXXVER >= 20
+using std::construct_at;
+#else
+template<class T, class... Args>
+EIGEN_DEVICE_FUNC T* construct_at( T* p, Args&&... args )
+{
+ return ::new (const_cast<void*>(static_cast<const volatile void*>(p)))
+ T(std::forward<Args>(args)...);
+}
+#endif
+
+/** \internal
+ * This wraps C++17's std::destroy_at. If it's not available it calls the destructor.
+ * The wrapper is not a full replacement for C++20's std::destroy_at as it cannot
+ * be applied to std::array.
+ */
+#if EIGEN_COMP_CXXVER >= 17
+using std::destroy_at;
+#else
+template<class T>
+EIGEN_DEVICE_FUNC void destroy_at(T* p)
+{
+ p->~T();
+}
+#endif
+
} // end namespace internal
} // end namespace Eigen
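
These wrappers back the "Add construct_at, destroy_at wrappers. Use throughout." entry (cd2ba9d03)
in the changelog. A self-contained sketch of what the pre-C++20/pre-C++17 fallback paths do, using
local stand-ins rather than the Eigen::internal versions:

    #include <new>      // placement new
    #include <string>
    #include <utility>  // std::forward

    // Stand-alone copies of the fallback paths shown in the hunk (illustration only).
    template <class T, class... Args>
    T* construct_at(T* p, Args&&... args) {
      return ::new (const_cast<void*>(static_cast<const volatile void*>(p)))
          T(std::forward<Args>(args)...);
    }
    template <class T>
    void destroy_at(T* p) { p->~T(); }

    int main() {
      alignas(std::string) unsigned char buf[sizeof(std::string)];
      std::string* s = construct_at(reinterpret_cast<std::string*>(buf), "hello");
      destroy_at(s);  // run the destructor; the stack buffer itself is not freed
    }
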
diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h
index 0646a9a..8d5c2c3 100755
--- a/Eigen/src/Core/util/Meta.h
+++ b/Eigen/src/Core/util/Meta.h
@@ -90,23 +90,11 @@
template<>
struct bool_constant<false> : false_type {};
-template<bool Condition, typename Then, typename Else>
-struct conditional { typedef Then type; };
-
-template<typename Then, typename Else>
-struct conditional <false, Then, Else> { typedef Else type; };
-
-template<typename T> struct remove_reference { typedef T type; };
-template<typename T> struct remove_reference<T&> { typedef T type; };
-
-template<typename T> struct remove_pointer { typedef T type; };
-template<typename T> struct remove_pointer<T*> { typedef T type; };
-template<typename T> struct remove_pointer<T*const> { typedef T type; };
-
-template <class T> struct remove_const { typedef T type; };
-template <class T> struct remove_const<const T> { typedef T type; };
-template <class T> struct remove_const<const T[]> { typedef T type[]; };
-template <class T, unsigned int Size> struct remove_const<const T[Size]> { typedef T type[Size]; };
+// Third-party libraries rely on these.
+using std::conditional;
+using std::remove_reference;
+using std::remove_pointer;
+using std::remove_const;
template<typename T> struct remove_all { typedef T type; };
template<typename T> struct remove_all<const T> { typedef typename remove_all<T>::type type; };
@@ -115,6 +103,9 @@
template<typename T> struct remove_all<T const*> { typedef typename remove_all<T>::type type; };
template<typename T> struct remove_all<T*> { typedef typename remove_all<T>::type type; };
+template<typename T>
+using remove_all_t = typename remove_all<T>::type;
+
template<typename T> struct is_arithmetic { enum { value = false }; };
template<> struct is_arithmetic<float> { enum { value = true }; };
template<> struct is_arithmetic<double> { enum { value = true }; };
@@ -134,7 +125,7 @@
template<typename T> struct is_same<T,T> { enum { value = 1 }; };
template< class T >
-struct is_void : is_same<void, typename remove_const<T>::type> {};
+struct is_void : is_same<void, std::remove_const_t<T>> {};
template<> struct is_arithmetic<signed long long> { enum { value = true }; };
template<> struct is_arithmetic<unsigned long long> { enum { value = true }; };
@@ -142,9 +133,6 @@
using std::make_unsigned;
-template <typename T> struct add_const { typedef const T type; };
-template <typename T> struct add_const<T&> { typedef T& type; };
-
template <typename T> struct is_const { enum { value = 0 }; };
template <typename T> struct is_const<T const> { enum { value = 1 }; };
@@ -154,16 +142,11 @@
template<typename T> struct add_const_on_value_type<T* const> { typedef T const* const type; };
template<typename T> struct add_const_on_value_type<T const* const> { typedef T const* const type; };
+template<typename T>
+using add_const_on_value_type_t = typename add_const_on_value_type<T>::type;
+
using std::is_convertible;
-/** \internal Allows to enable/disable an overload
- * according to a compile time condition.
- */
-template<bool Condition, typename T=void> struct enable_if;
-
-template<typename T> struct enable_if<true,T>
-{ typedef T type; };
-
/** \internal
* A base class do disable default copy ctor and copy assignment operator.
*/
@@ -194,7 +177,7 @@
enum { value = Dynamic };
};
-template<typename T> struct array_size<T,typename internal::enable_if<((T::SizeAtCompileTime&0)==0)>::type> {
+template<typename T> struct array_size<T, std::enable_if_t<((T::SizeAtCompileTime&0)==0)>> {
enum { value = T::SizeAtCompileTime };
};
@@ -256,24 +239,24 @@
template<typename F, typename... ArgTypes>
struct result_of<F(ArgTypes...)> {
typedef typename std::invoke_result<F, ArgTypes...>::type type1;
- typedef typename remove_all<type1>::type type;
+ typedef remove_all_t<type1> type;
};
template<typename F, typename... ArgTypes>
struct invoke_result {
typedef typename std::invoke_result<F, ArgTypes...>::type type1;
- typedef typename remove_all<type1>::type type;
+ typedef remove_all_t<type1> type;
};
#else
template<typename T> struct result_of {
typedef typename std::result_of<T>::type type1;
- typedef typename remove_all<type1>::type type;
+ typedef remove_all_t<type1> type;
};
template<typename F, typename... ArgTypes>
struct invoke_result {
typedef typename result_of<F(ArgTypes...)>::type type1;
- typedef typename remove_all<type1>::type type;
+ typedef remove_all_t<type1> type;
};
#endif
@@ -305,7 +288,7 @@
template <typename T, typename IndexType=Index>
struct has_nullary_operator
{
- template <typename C> static meta_yes testFunctor(C const *,typename enable_if<(sizeof(return_ptr<C>()->operator()())>0)>::type * = 0);
+ template <typename C> static meta_yes testFunctor(C const *,std::enable_if_t<(sizeof(return_ptr<C>()->operator()())>0)> * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
@@ -314,7 +297,7 @@
template <typename T, typename IndexType=Index>
struct has_unary_operator
{
- template <typename C> static meta_yes testFunctor(C const *,typename enable_if<(sizeof(return_ptr<C>()->operator()(IndexType(0)))>0)>::type * = 0);
+ template <typename C> static meta_yes testFunctor(C const *,std::enable_if_t<(sizeof(return_ptr<C>()->operator()(IndexType(0)))>0)> * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
@@ -323,7 +306,7 @@
template <typename T, typename IndexType=Index>
struct has_binary_operator
{
- template <typename C> static meta_yes testFunctor(C const *,typename enable_if<(sizeof(return_ptr<C>()->operator()(IndexType(0),IndexType(0)))>0)>::type * = 0);
+ template <typename C> static meta_yes testFunctor(C const *,std::enable_if_t<(sizeof(return_ptr<C>()->operator()(IndexType(0),IndexType(0)))>0)> * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
@@ -335,8 +318,7 @@
template<int Y,
int InfX = 0,
int SupX = ((Y==1) ? 1 : Y/2),
- bool Done = ((SupX-InfX)<=1 ? true : ((SupX*SupX <= Y) && ((SupX+1)*(SupX+1) > Y))) >
- // use ?: instead of || just to shut up a stupid gcc 4.3 warning
+ bool Done = ((SupX - InfX) <= 1 || ((SupX * SupX <= Y) && ((SupX + 1) * (SupX + 1) > Y)))>
class meta_sqrt
{
enum {
@@ -382,7 +364,7 @@
// FIXME quick workaround around current limitation of result_of
// template<typename Scalar, typename ArgType0, typename ArgType1>
// struct result_of<scalar_product_op<Scalar>(ArgType0,ArgType1)> {
-// typedef typename scalar_product_traits<typename remove_all<ArgType0>::type, typename remove_all<ArgType1>::type>::ReturnType type;
+// typedef typename scalar_product_traits<remove_all_t<ArgType0>, remove_all_t<ArgType1>>::ReturnType type;
// };
/** \internal Obtains a POD type suitable to use as storage for an object of a size
@@ -532,6 +514,14 @@
inline constexpr bool check_implication(bool a, bool b) {
return !a || b;
}
+
+/// \internal Provide fallback for std::is_constant_evaluated for pre-C++20.
+#if EIGEN_COMP_CXXVER >= 20
+using std::is_constant_evaluated;
+#else
+constexpr bool is_constant_evaluated() { return false; }
+#endif
+
} // end namespace internal
} // end namespace Eigen
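
The fallback above always answers "not constant-evaluated" before C++20, which is the conservative
choice for the alignment checks mentioned in the changelog (e22d58e81 "Add is_constant_evaluated,
update alignment checks"). A hedged sketch of the intended usage pattern; alignment_ok is a
hypothetical name, not Eigen API:

    #include <cstddef>
    #include <cstdint>
    #include <type_traits>

    #if defined(__cpp_lib_is_constant_evaluated)
    using std::is_constant_evaluated;
    #else
    constexpr bool is_constant_evaluated() { return false; }
    #endif

    // Skip the pointer-based check during constant evaluation, where casting a
    // pointer to an integer is not a constant expression anyway.
    constexpr bool alignment_ok(const void* p, std::size_t align) {
      if (is_constant_evaluated())
        return true;
      return reinterpret_cast<std::uintptr_t>(p) % align == 0;
    }
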
diff --git a/Eigen/src/Core/util/ReshapedHelper.h b/Eigen/src/Core/util/ReshapedHelper.h
index 6d949eb..6daea03 100644
--- a/Eigen/src/Core/util/ReshapedHelper.h
+++ b/Eigen/src/Core/util/ReshapedHelper.h
@@ -41,10 +41,9 @@
return total/other;
}
-template<int Flags, int Order>
-struct get_compiletime_reshape_order {
- enum { value = Order == AutoOrder ? Flags & RowMajorBit : Order };
-};
+constexpr inline int get_compiletime_reshape_order(int flags, int order) {
+ return order == AutoOrder ? flags & RowMajorBit : order;
+}
}
diff --git a/Eigen/src/Core/util/Serializer.h b/Eigen/src/Core/util/Serializer.h
index b77c5de..cbfc04a 100644
--- a/Eigen/src/Core/util/Serializer.h
+++ b/Eigen/src/Core/util/Serializer.h
@@ -28,9 +28,9 @@
// Specialization for POD types.
template<typename T>
-class Serializer<T, typename std::enable_if<
+class Serializer<T, typename std::enable_if_t<
std::is_trivial<T>::value
- && std::is_standard_layout<T>::value>::type > {
+ && std::is_standard_layout<T>::value>> {
public:
/**
diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h
index 73bad6c..ba8a7f1 100644
--- a/Eigen/src/Core/util/XprHelper.h
+++ b/Eigen/src/Core/util/XprHelper.h
@@ -11,17 +11,6 @@
#ifndef EIGEN_XPRHELPER_H
#define EIGEN_XPRHELPER_H
-// just a workaround because GCC seems to not really like empty structs
-// FIXME: gcc 4.3 generates bad code when strict-aliasing is enabled
-// so currently we simply disable this optimization for gcc 4.3
-#if EIGEN_COMP_GNUC
- #define EIGEN_EMPTY_STRUCT_CTOR(X) \
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE X() {} \
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE X(const X& ) {}
-#else
- #define EIGEN_EMPTY_STRUCT_CTOR(X)
-#endif
-
#include "../InternalHeaderCheck.h"
namespace Eigen {
@@ -113,7 +102,7 @@
template<typename I1, typename I2>
struct promote_index_type
{
- typedef typename conditional<(sizeof(I1)<sizeof(I2)), I2, I1>::type type;
+ typedef std::conditional_t<(sizeof(I1)<sizeof(I2)), I2, I1> type;
};
/** \internal If the template parameter Value is Dynamic, this class is just a wrapper around a T variable that
@@ -148,7 +137,6 @@
template<typename T, int Value> class variable_if_dynamicindex
{
public:
- EIGEN_EMPTY_STRUCT_CTOR(variable_if_dynamicindex)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit variable_if_dynamicindex(T v) { EIGEN_ONLY_USED_FOR_DEBUG(v); eigen_assert(v == T(Value)); }
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE EIGEN_CONSTEXPR
T value() { return T(Value); }
@@ -203,38 +191,27 @@
};
#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
-template<int ArrayBytes, int AlignmentBytes,
- bool Match = bool((ArrayBytes%AlignmentBytes)==0),
- bool TryHalf = bool(EIGEN_MIN_ALIGN_BYTES<AlignmentBytes) >
-struct compute_default_alignment_helper
-{
- enum { value = 0 };
-};
-
-template<int ArrayBytes, int AlignmentBytes, bool TryHalf>
-struct compute_default_alignment_helper<ArrayBytes, AlignmentBytes, true, TryHalf> // Match
-{
- enum { value = AlignmentBytes };
-};
-
-template<int ArrayBytes, int AlignmentBytes>
-struct compute_default_alignment_helper<ArrayBytes, AlignmentBytes, false, true> // Try-half
-{
- // current packet too large, try with an half-packet
- enum { value = compute_default_alignment_helper<ArrayBytes, AlignmentBytes/2>::value };
-};
+constexpr inline int compute_default_alignment_helper(int ArrayBytes, int AlignmentBytes) {
+ if((ArrayBytes % AlignmentBytes) == 0) {
+ return AlignmentBytes;
+ } else if (EIGEN_MIN_ALIGN_BYTES<AlignmentBytes) {
+ return compute_default_alignment_helper(ArrayBytes, AlignmentBytes/2);
+ } else {
+ return 0;
+ }
+}
#else
// If static alignment is disabled, no need to bother.
-// This also avoids a division by zero in "bool Match = bool((ArrayBytes%AlignmentBytes)==0)"
-template<int ArrayBytes, int AlignmentBytes>
-struct compute_default_alignment_helper
-{
- enum { value = 0 };
-};
+// This also avoids a division by zero
+constexpr inline int compute_default_alignment_helper(int ArrayBytes, int AlignmentBytes) {
+ EIGEN_UNUSED_VARIABLE(ArrayBytes);
+ EIGEN_UNUSED_VARIABLE(AlignmentBytes);
+ return 0;
+}
#endif
template<typename T, int Size> struct compute_default_alignment {
- enum { value = compute_default_alignment_helper<Size*sizeof(T),EIGEN_MAX_STATIC_ALIGN_BYTES>::value };
+ enum { value = compute_default_alignment_helper(Size*sizeof(T), EIGEN_MAX_STATIC_ALIGN_BYTES) };
};
template<typename T> struct compute_default_alignment<T,Dynamic> {
@@ -261,25 +238,21 @@
typedef Matrix<Scalar_, Rows_, Cols_, Options, MaxRows_, MaxCols_> type;
};
-template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
-class compute_matrix_flags
-{
- enum { row_major_bit = Options&RowMajor ? RowMajorBit : 0 };
- public:
- // FIXME currently we still have to handle DirectAccessBit at the expression level to handle DenseCoeffsBase<>
- // and then propagate this information to the evaluator's flags.
- // However, I (Gael) think that DirectAccessBit should only matter at the evaluation stage.
- enum { ret = DirectAccessBit | LvalueBit | NestByRefBit | row_major_bit };
-};
+constexpr inline unsigned compute_matrix_flags(int Options) {
+ unsigned row_major_bit = Options&RowMajor ? RowMajorBit : 0;
+ // FIXME currently we still have to handle DirectAccessBit at the expression level to handle DenseCoeffsBase<>
+ // and then propagate this information to the evaluator's flags.
+ // However, I (Gael) think that DirectAccessBit should only matter at the evaluation stage.
+ return DirectAccessBit | LvalueBit | NestByRefBit | row_major_bit;
+}
-template<int Rows_, int Cols_> struct size_at_compile_time
-{
- enum { ret = (Rows_==Dynamic || Cols_==Dynamic) ? Dynamic : Rows_ * Cols_ };
-};
+constexpr inline int size_at_compile_time(int rows, int cols) {
+ return (rows==Dynamic || cols==Dynamic) ? Dynamic : rows * cols;
+}
template<typename XprType> struct size_of_xpr_at_compile_time
{
- enum { ret = size_at_compile_time<traits<XprType>::RowsAtCompileTime,traits<XprType>::ColsAtCompileTime>::ret };
+ enum { ret = size_at_compile_time(traits<XprType>::RowsAtCompileTime, traits<XprType>::ColsAtCompileTime) };
};
/* plain_matrix_type : the difference from eval is that plain_matrix_type is always a plain matrix type,
@@ -409,28 +382,28 @@
template <typename T>
struct ref_selector
{
- typedef typename conditional<
+ typedef std::conditional_t<
bool(traits<T>::Flags & NestByRefBit),
T const&,
const T
- >::type type;
+ > type;
- typedef typename conditional<
+ typedef std::conditional_t<
bool(traits<T>::Flags & NestByRefBit),
T &,
T
- >::type non_const_type;
+ > non_const_type;
};
/** \internal Adds the const qualifier on the value-type of T2 if and only if T1 is a const type */
template<typename T1, typename T2>
struct transfer_constness
{
- typedef typename conditional<
+ typedef std::conditional_t<
bool(internal::is_const<T1>::value),
- typename internal::add_const_on_value_type<T2>::type,
+ add_const_on_value_type_t<T2>,
T2
- >::type type;
+ > type;
};
@@ -463,7 +436,7 @@
Evaluate = (int(evaluator<T>::Flags) & EvalBeforeNestingBit) || (int(CostEval) < int(CostNoEval))
};
- typedef typename conditional<Evaluate, PlainObject, typename ref_selector<T>::type>::type type;
+ typedef std::conditional_t<Evaluate, PlainObject, typename ref_selector<T>::type> type;
};
template<typename T>
@@ -503,10 +476,10 @@
template<typename XprType, typename CastType> struct cast_return_type
{
typedef typename XprType::Scalar CurrentScalarType;
- typedef typename remove_all<CastType>::type CastType_;
+ typedef remove_all_t<CastType> CastType_;
typedef typename CastType_::Scalar NewScalarType;
- typedef typename conditional<is_same<CurrentScalarType,NewScalarType>::value,
- const XprType&,CastType>::type type;
+ typedef std::conditional_t<is_same<CurrentScalarType,NewScalarType>::value,
+ const XprType&,CastType> type;
};
template <typename A, typename B> struct promote_storage_type;
@@ -597,11 +570,11 @@
typedef Array<Scalar, 1, ExpressionType::ColsAtCompileTime,
int(ExpressionType::PlainObject::Options) | int(RowMajor), 1, ExpressionType::MaxColsAtCompileTime> ArrayRowType;
- typedef typename conditional<
+ typedef std::conditional_t<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixRowType,
ArrayRowType
- >::type type;
+ > type;
};
template<typename ExpressionType, typename Scalar = typename ExpressionType::Scalar>
@@ -612,11 +585,11 @@
typedef Array<Scalar, ExpressionType::RowsAtCompileTime, 1,
ExpressionType::PlainObject::Options & ~RowMajor, ExpressionType::MaxRowsAtCompileTime, 1> ArrayColType;
- typedef typename conditional<
+ typedef std::conditional_t<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixColType,
ArrayColType
- >::type type;
+ > type;
};
template<typename ExpressionType, typename Scalar = typename ExpressionType::Scalar>
@@ -629,11 +602,11 @@
typedef Matrix<Scalar, diag_size, 1, ExpressionType::PlainObject::Options & ~RowMajor, max_diag_size, 1> MatrixDiagType;
typedef Array<Scalar, diag_size, 1, ExpressionType::PlainObject::Options & ~RowMajor, max_diag_size, 1> ArrayDiagType;
- typedef typename conditional<
+ typedef std::conditional_t<
is_same< typename traits<ExpressionType>::XprKind, MatrixXpr >::value,
MatrixDiagType,
ArrayDiagType
- >::type type;
+ > type;
};
template<typename Expr,typename Scalar = typename Expr::Scalar>
@@ -647,7 +620,7 @@
typedef Matrix<Scalar, traits<Expr>::RowsAtCompileTime, traits<Expr>::ColsAtCompileTime,
Options, traits<Expr>::MaxRowsAtCompileTime,traits<Expr>::MaxColsAtCompileTime> matrix_type;
- typedef CwiseNullaryOp<scalar_constant_op<Scalar>, const typename conditional<is_same< typename traits<Expr>::XprKind, MatrixXpr >::value, matrix_type, array_type>::type > type;
+ typedef CwiseNullaryOp<scalar_constant_op<Scalar>, const std::conditional_t<is_same< typename traits<Expr>::XprKind, MatrixXpr >::value, matrix_type, array_type> > type;
};
template<typename ExpressionType>
@@ -687,14 +660,14 @@
template<typename T1, typename T2>
EIGEN_DEVICE_FUNC
-bool is_same_dense(const T1 &mat1, const T2 &mat2, typename enable_if<possibly_same_dense<T1,T2>::value>::type * = 0)
+bool is_same_dense(const T1 &mat1, const T2 &mat2, std::enable_if_t<possibly_same_dense<T1,T2>::value> * = 0)
{
return (mat1.data()==mat2.data()) && (mat1.innerStride()==mat2.innerStride()) && (mat1.outerStride()==mat2.outerStride());
}
template<typename T1, typename T2>
EIGEN_DEVICE_FUNC
-bool is_same_dense(const T1 &, const T2 &, typename enable_if<!possibly_same_dense<T1,T2>::value>::type * = 0)
+bool is_same_dense(const T1 &, const T2 &, std::enable_if_t<!possibly_same_dense<T1,T2>::value> * = 0)
{
return false;
}
@@ -716,9 +689,9 @@
template<bool Vectorized>
-struct scalar_div_cost<signed long,Vectorized,typename conditional<sizeof(long)==8,void,false_type>::type> { enum { value = 24 }; };
+struct scalar_div_cost<signed long,Vectorized, std::conditional_t<sizeof(long)==8,void,false_type>> { enum { value = 24 }; };
template<bool Vectorized>
-struct scalar_div_cost<unsigned long,Vectorized,typename conditional<sizeof(long)==8,void,false_type>::type> { enum { value = 21 }; };
+struct scalar_div_cost<unsigned long,Vectorized, std::conditional_t<sizeof(long)==8,void,false_type>> { enum { value = 21 }; };
#ifdef EIGEN_DEBUG_ASSIGN
@@ -807,12 +780,12 @@
};
template <typename T, typename BinaryOp>
-struct ScalarBinaryOpTraits<T, typename NumTraits<typename internal::enable_if<NumTraits<T>::IsComplex,T>::type>::Real, BinaryOp>
+struct ScalarBinaryOpTraits<T, typename NumTraits<std::enable_if_t<NumTraits<T>::IsComplex,T>>::Real, BinaryOp>
{
typedef T ReturnType;
};
template <typename T, typename BinaryOp>
-struct ScalarBinaryOpTraits<typename NumTraits<typename internal::enable_if<NumTraits<T>::IsComplex,T>::type>::Real, T, BinaryOp>
+struct ScalarBinaryOpTraits<typename NumTraits<std::enable_if_t<NumTraits<T>::IsComplex,T>>::Real, T, BinaryOp>
{
typedef T ReturnType;
};