Update Eigen to commit:d791d48859c6fc7850c9fd5270d2b236c818068d
CHANGELOG
=========
d791d4885 - Fix AVX512FP16 build failure
2fae4d7a7 - Revert "fix scalar pselect"
b430eb31e - AVX512F double->int64_t cast
02bcf9b59 - fix scalar pselect
392b95bdf - allow pointer_based_stl_iterator to conform to the contiguous_iterator concept if we are in c++20
27f817625 - fixing warning C5054: operator '==': deprecated between enumerations of different types
PiperOrigin-RevId: 645467923
Change-Id: Icf1fd74b44ed4aea311a3d57854c5d5814e89e42
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 19c2560..fa4d038 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -810,7 +810,7 @@
: (Derived::MaxColsAtCompileTime == 1 && Derived::MaxRowsAtCompileTime != 1) ? ColMajor
: MatrixFlags & RowMajorBit ? RowMajor
: ColMajor,
- SameStorageOrder_ = StorageOrder_ == (MatrixFlags & RowMajorBit ? RowMajor : ColMajor),
+ SameStorageOrder_ = int(StorageOrder_) == ((MatrixFlags & RowMajorBit) ? RowMajor : ColMajor),
ScalarAccessOnDiag_ = !((int(StorageOrder_) == ColMajor && int(ProductOrder) == OnTheLeft) ||
(int(StorageOrder_) == RowMajor && int(ProductOrder) == OnTheRight)),
diff --git a/Eigen/src/Core/StlIterators.h b/Eigen/src/Core/StlIterators.h
index 3ab7d21..bb897f8 100644
--- a/Eigen/src/Core/StlIterators.h
+++ b/Eigen/src/Core/StlIterators.h
@@ -325,7 +325,13 @@
public:
typedef Index difference_type;
typedef typename XprType::Scalar value_type;
+#if __cplusplus >= 202002L
+ typedef std::conditional_t<XprType::InnerStrideAtCompileTime == 1, std::contiguous_iterator_tag,
+ std::random_access_iterator_tag>
+ iterator_category;
+#else
typedef std::random_access_iterator_tag iterator_category;
+#endif
typedef std::conditional_t<bool(is_lvalue), value_type*, const value_type*> pointer;
typedef std::conditional_t<bool(is_lvalue), value_type&, const value_type&> reference;
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index b16e9f6..9508ac6 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -42,12 +42,10 @@
template <>
struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
-#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
template <>
struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
-#endif
template <>
struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
@@ -82,14 +80,34 @@
template <>
EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
-#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
+#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
return _mm512_cvttpd_epi64(a);
#else
- EIGEN_ALIGN16 double aux[8];
- pstore(aux, a);
- return _mm512_set_epi64(static_cast<int64_t>(aux[7]), static_cast<int64_t>(aux[6]), static_cast<int64_t>(aux[5]),
- static_cast<int64_t>(aux[4]), static_cast<int64_t>(aux[3]), static_cast<int64_t>(aux[2]),
- static_cast<int64_t>(aux[1]), static_cast<int64_t>(aux[0]));
+ constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
+ kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
+
+ const __m512i cst_one = _mm512_set1_epi64(1);
+ const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
+ const __m512i cst_bias = _mm512_set1_epi64(kBias);
+
+ __m512i a_bits = _mm512_castpd_si512(a);
+ // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
+ __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
+ __m512i e = _mm512_sub_epi64(biased_e, cst_bias);
+
+ // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
+ __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
+ // shift to the right by kTotalBits - e to convert the significand to an integer
+ __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
+
+ // add the implied bit
+ __m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
+ // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
+ __m512i result = _mm512_add_epi64(result_significand, result_exponent);
+ // handle negative arguments
+ __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
+ result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
+ return result;
#endif
}
@@ -110,10 +128,10 @@
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
-#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
+#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
return _mm512_cvtepi64_pd(a);
#else
- EIGEN_ALIGN16 int64_t aux[8];
+ EIGEN_ALIGN64 int64_t aux[8];
pstore(aux, a);
return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
@@ -229,8 +247,6 @@
return _mm256_castsi256_si128(a);
}
-#ifndef EIGEN_VECTORIZE_AVX512FP16
-
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
@@ -241,8 +257,6 @@
return float2half(a);
}
-#endif
-
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);