BEGIN_PUBLIC
Update Eigen to: https://gitlab.com/libeigen/eigen/-/commit/2ce2f5198929caab4b41a6ad1b9c93f67d8b9a69
This adds AVX acceleration for bfloat16 (AVX512 already supports it).
END_PUBLIC
.
PiperOrigin-RevId: 322917791
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index c539443..4616961 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -58,15 +58,15 @@
// Hyperbolic Tangent function.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-ptanh<Packet8f>(const Packet8f& x) {
- return internal::generic_fast_tanh_float(x);
+ptanh<Packet8f>(const Packet8f& _x) {
+ return internal::generic_fast_tanh_float(_x);
}
// Exponential function for doubles.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
-pexp<Packet4d>(const Packet4d& x) {
- return pexp_double(x);
+pexp<Packet4d>(const Packet4d& _x) {
+ return pexp_double(_x);
}
// Functions for sqrt.
@@ -96,13 +96,13 @@
}
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f psqrt<Packet8f>(const Packet8f& x) {
- return _mm256_sqrt_ps(x);
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ return _mm256_sqrt_ps(_x);
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d psqrt<Packet4d>(const Packet4d& x) {
- return _mm256_sqrt_pd(x);
+Packet4d psqrt<Packet4d>(const Packet4d& _x) {
+ return _mm256_sqrt_pd(_x);
}
#if EIGEN_FAST_MATH
@@ -140,18 +140,27 @@
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f prsqrt<Packet8f>(const Packet8f& x) {
+Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x));
+ return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x));
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d prsqrt<Packet4d>(const Packet4d& x) {
+Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
- return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x));
+ return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
}
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index b17c340..cf7146c 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -32,11 +32,13 @@
typedef __m256i Packet8i;
typedef __m256d Packet4d;
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
+typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
template<> struct is_arithmetic<__m256> { enum { value = true }; };
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
@@ -134,6 +136,40 @@
HasBlend = 0
};
};
+
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ // There is no half-size packet for current Packet8bf.
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
+ };
+};
#endif
template<> struct scalar_div_cost<float,true> { enum { value = 14 }; };
@@ -165,6 +201,14 @@
enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
+template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
+
+// Helper function for bit packing snippet of low precision comparison.
+// It packs the flags from 16x16 to 8x16.
+EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
+ return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
+ _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+}
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
@@ -1032,6 +1076,307 @@
kernel.packet[3] = pload<Packet8h>(out[3]);
}
+EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ __m256i extend = _mm256_cvtepu16_epi32(a);
+ return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16));
+#else
+ __m128i lo = _mm_cvtepu16_epi32(a);
+ __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8));
+ __m128i lo_shift = _mm_slli_epi32(lo, 16);
+ __m128i hi_shift = _mm_slli_epi32(hi, 16);
+ return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1));
+#endif
+}
+
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
+ Packet8bf r;
+
+ // Flush input denormals value to zero with hardware capability.
+ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+ __m256 flush = _mm256_and_ps(a, a);
+ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
+
+ __m256i input = _mm256_castps_si256(flush);
+
+#ifdef EIGEN_VECTORIZE_AVX2
+ // uint32_t lsb = (input >> 16);
+ __m256i t = _mm256_srli_epi32(input, 16);
+ // uint32_t lsb = lsb & 1;
+ t = _mm256_and_si256(t, _mm256_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ t = _mm256_add_epi32(t, input);
+ // input = input >> 16;
+ t = _mm256_srli_epi32(t, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
+ __m256i nan = _mm256_set1_epi32(0x7fc0);
+ t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
+ // output.value = static_cast<uint16_t>(input);
+ return _mm_packus_epi32(_mm256_extractf128_si256(t, 0),
+ _mm256_extractf128_si256(t, 1));
+#else
+ // uint32_t lsb = (input >> 16);
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16);
+ // uint32_t lsb = lsb & 1;
+ lo = _mm_and_si128(lo, _mm_set1_epi32(1));
+ hi = _mm_and_si128(hi, _mm_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff));
+ hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0));
+ hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1));
+ // input = input >> 16;
+ lo = _mm_srli_epi32(lo, 16);
+ hi = _mm_srli_epi32(hi, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
+ __m128i nan = _mm_set1_epi32(0x7fc0);
+ lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
+ hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
+ // output.value = static_cast<uint16_t>(input);
+ return _mm_packus_epi32(lo, hi);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return _mm_set1_epi16(from.value);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) {
+ return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<unsigned short>(_mm_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploaddup<Packet8bf>(const bfloat16* from) {
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ unsigned short c = from[2].value;
+ unsigned short d = from[3].value;
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploadquad<Packet8bf>(const bfloat16* from) {
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a)
+{
+ return F32ToBf16(pround<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(print<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) {
+ Packet8bf sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ EIGEN_ALIGN32 bfloat16 aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+
+ __m128i ab_03 = _mm_unpacklo_epi16(a, b);
+ __m128i cd_03 = _mm_unpacklo_epi16(c, d);
+ __m128i ab_47 = _mm_unpackhi_epi16(a, b);
+ __m128i cd_47 = _mm_unpackhi_epi16(c, d);
+
+ kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03);
+ kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03);
+ kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47);
+ kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index 1810435..c669a7f 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -76,12 +76,38 @@
};
};
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
+ return F32ToBf16(a);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index b86afce..83af5f5 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -135,10 +135,7 @@
p16f_minus_inf);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -264,10 +261,7 @@
return pmax(pmul(x, e), _x);
}*/
-template <>
-EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -325,10 +319,7 @@
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@@ -377,10 +368,7 @@
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
// prsqrt for double.
#if EIGEN_FAST_MATH
@@ -435,20 +423,14 @@
return generic_plog1p(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
#endif
@@ -461,31 +443,20 @@
}
template <>
-EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
-EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 879145d..2b6693e 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1633,13 +1633,13 @@
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
// There is no half-size packet for current Packet16bf.
- // TODO: support as SSE/AVX path.
- typedef Packet16bf half;
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
@@ -1668,7 +1668,7 @@
{
typedef bfloat16 type;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
- typedef Packet16bf half;
+ typedef Packet8bf half;
};
template <>
@@ -1741,7 +1741,7 @@
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
}
-// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
@@ -1776,7 +1776,7 @@
// output.value = static_cast<uint16_t>(input);
r.i = _mm512_cvtepi32_epi16(t);
-#endif
+#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
}
@@ -1916,6 +1916,13 @@
}
template <>
+EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
+ Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
+ return padd<Packet8bf>(lane0, lane1);
+}
+
+template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
@@ -1944,7 +1951,7 @@
// Swap hi and lo first because shuffle is in 128-bit lanes.
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
- res.i = _mm256_shuffle_epi8(a.i, m);
+ res.i = _mm256_shuffle_epi8(res.i, m);
return res;
}
@@ -2056,38 +2063,22 @@
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
- kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index 3607640..30c9982 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -23,10 +23,31 @@
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
#endif
+#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
+ template <> \
+ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
+ PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
+ return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
+ }
+
namespace Eigen {
struct bfloat16;
+// explicit conversion operators are no available before C++11 so we first cast
+// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly
+#if !EIGEN_HAS_CXX11
+namespace internal {
+template <typename RealScalar>
+struct cast_impl<bfloat16, std::complex<RealScalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const bfloat16 &x)
+ {
+ return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
+ }
+};
+} // namespace internal
+#endif // EIGEN_HAS_CXX11
+
namespace bfloat16_impl {
// Make our own __bfloat16_raw definition.
@@ -37,7 +58,14 @@
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
+template <bool AssumeArgumentIsNormalOrInfinityOrZero>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
+// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
+// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
struct bfloat16_base : public __bfloat16_raw {
@@ -60,14 +88,14 @@
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
template<class T>
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
- : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val))) {}
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
- : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
// Following the convention of numpy, converting between complex and
// float will lead to loss of imag value.
template<typename RealScalar>
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
- : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
+ : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
// +0.0 and -0.0 become false, everything else becomes true.
@@ -291,7 +319,7 @@
return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
-#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
output.value = p[0];
#else
output.value = p[1];
@@ -305,19 +333,13 @@
return h;
}
-union float32_bits {
- unsigned int u;
- float f;
-};
-
-EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) {
+// float_to_bfloat16_rtne template specialization that does not make any
+// assumption about the value of its function argument (ff).
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
// Nothing to do here
#else
- unsigned int input;
- float32_bits f;
- f.f = ff;
- input = f.u;
__bfloat16_raw output;
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
@@ -478,22 +500,39 @@
// Sign | Exp (8 bit) | Frac (first 7 bit)
// S E E E E E E E E F F F F F F L
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
- //
- //
+
+ // At this point, ff must be either a normal float, or +/-infinity.
+ output = float_to_bfloat16_rtne<true>(ff);
+ }
+ return output;
+#endif
+}
+
+// float_to_bfloat16_rtne template specialization that assumes that its function
+// argument (ff) is either a normal floating point number, or +/-infinity, or
+// zero. Used to improve the runtime performance of conversion from an integer
+// type to bfloat16.
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
+#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
+ // Nothing to do here
+#else
+ unsigned int input = numext::as_uint(ff);
+ __bfloat16_raw output;
+
// Least significant bit of resulting bfloat.
unsigned int lsb = (input >> 16) & 1;
unsigned int rounding_bias = 0x7fff + lsb;
input += rounding_bias;
output.value = static_cast<unsigned short>(input >> 16);
- }
- return output;
+ return output;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
float result = 0;
unsigned short* q = reinterpret_cast<unsigned short*>(&result);
-#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
q[0] = h.value;
#else
q[1] = h.value;
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index b84cfc7..bbc15d4 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -36,7 +36,7 @@
#ifndef EIGEN_HALF_H
#define EIGEN_HALF_H
-#if __cplusplus > 199711L
+#if EIGEN_HAS_CXX11
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
#else
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
@@ -48,6 +48,20 @@
struct half;
+// explicit conversion operators are no available before C++11 so we first cast
+// half to RealScalar rather than to std::complex<RealScalar> directly
+#if !EIGEN_HAS_CXX11
+namespace internal {
+template <typename RealScalar>
+struct cast_impl<half, std::complex<RealScalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const half &x)
+ {
+ return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
+ }
+};
+} // namespace internal
+#endif // EIGEN_HAS_CXX11
+
namespace half_impl {
#if !defined(EIGEN_HAS_GPU_FP16)
@@ -737,7 +751,6 @@
}
#endif
-
#if defined(EIGEN_GPU_COMPILE_PHASE)
namespace Eigen {
namespace numext {
diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h
index e9e2f18..c892e49 100755
--- a/Eigen/src/Core/util/Meta.h
+++ b/Eigen/src/Core/util/Meta.h
@@ -681,10 +681,11 @@
#endif
/** \internal extract the bits of the float \a x */
-inline unsigned int as_uint(float x)
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC unsigned int as_uint(float x)
{
unsigned int ret;
- std::memcpy(&ret, &x, sizeof(float));
+ EIGEN_USING_STD(memcpy);
+ memcpy(&ret, &x, sizeof(float));
return ret;
}
diff --git a/Eigen/src/Geometry/Quaternion.h b/Eigen/src/Geometry/Quaternion.h
index 8e02161..61f6d0a 100644
--- a/Eigen/src/Geometry/Quaternion.h
+++ b/Eigen/src/Geometry/Quaternion.h
@@ -197,14 +197,14 @@
#else
template<typename NewScalarType>
- EIGEN_DEVICE_FUNC inline
+ EIGEN_DEVICE_FUNC inline
typename internal::enable_if<internal::is_same<Scalar,NewScalarType>::value,const Derived&>::type cast() const
{
return derived();
}
template<typename NewScalarType>
- EIGEN_DEVICE_FUNC inline
+ EIGEN_DEVICE_FUNC inline
typename internal::enable_if<!internal::is_same<Scalar,NewScalarType>::value,Quaternion<NewScalarType> >::type cast() const
{
return Quaternion<NewScalarType>(coeffs().template cast<NewScalarType>());
diff --git a/Eigen/src/Geometry/Transform.h b/Eigen/src/Geometry/Transform.h
index c87b5fe..13b1de8 100644
--- a/Eigen/src/Geometry/Transform.h
+++ b/Eigen/src/Geometry/Transform.h
@@ -262,12 +262,6 @@
internal::transform_make_affine<(int(Mode)==Affine || int(Mode)==Isometry) ? Affine : AffineCompact>::run(m_matrix);
}
- EIGEN_DEVICE_FUNC inline Transform(const Transform& other)
- {
- check_template_params();
- m_matrix = other.m_matrix;
- }
-
EIGEN_DEVICE_FUNC inline explicit Transform(const TranslationType& t)
{
check_template_params();
@@ -285,9 +279,6 @@
*this = r;
}
- EIGEN_DEVICE_FUNC inline Transform& operator=(const Transform& other)
- { m_matrix = other.m_matrix; return *this; }
-
typedef internal::transform_take_affine_part<Transform> take_affine_part;
/** Constructs and initializes a transformation from a Dim^2 or a (Dim+1)^2 matrix. */
diff --git a/Eigen/src/SparseCore/SparseBlock.h b/Eigen/src/SparseCore/SparseBlock.h
index db50902..5b4f6cc 100644
--- a/Eigen/src/SparseCore/SparseBlock.h
+++ b/Eigen/src/SparseCore/SparseBlock.h
@@ -446,9 +446,13 @@
{}
inline Index nonZerosEstimate() const {
- Index nnz = m_block.nonZeros();
- if(nnz<0)
- return m_argImpl.nonZerosEstimate() * m_block.size() / m_block.nestedExpression().size();
+ const Index nnz = m_block.nonZeros();
+ if(nnz < 0) {
+ // Scale the non-zero estimate for the underlying expression linearly with block size.
+ // Return zero if the underlying block is empty.
+ const Index nested_sz = m_block.nestedExpression().size();
+ return nested_sz == 0 ? 0 : m_argImpl.nonZerosEstimate() * m_block.size() / nested_sz;
+ }
return nnz;
}
diff --git a/cmake/EigenTesting.cmake b/cmake/EigenTesting.cmake
index 5242237..c98393e 100644
--- a/cmake/EigenTesting.cmake
+++ b/cmake/EigenTesting.cmake
@@ -310,6 +310,12 @@
message(STATUS "AVX: Using architecture defaults")
endif()
+ if(EIGEN_TEST_AVX2)
+ message(STATUS "AVX2: ON")
+ else()
+ message(STATUS "AVX2: Using architecture defaults")
+ endif()
+
if(EIGEN_TEST_FMA)
message(STATUS "FMA: ON")
else()
@@ -322,6 +328,12 @@
message(STATUS "AVX512: Using architecture defaults")
endif()
+ if(EIGEN_TEST_AVX512DQ)
+ message(STATUS "AVX512DQ: ON")
+ else()
+ message(STATUS "AVX512DQ: Using architecture defaults")
+ endif()
+
if(EIGEN_TEST_ALTIVEC)
message(STATUS "Altivec: ON")
else()
diff --git a/doc/eigendoxy_footer.html.in b/doc/eigendoxy_footer.html.in
index 94f2bab..1266535 100644
--- a/doc/eigendoxy_footer.html.in
+++ b/doc/eigendoxy_footer.html.in
@@ -17,23 +17,6 @@
</small></address>
<!--END !GENERATE_TREEVIEW-->
-<!-- Matomo -->
-<script type="text/javascript">
- var _paq = _paq || [];
- /* tracker methods like "setCustomDimension" should be called before "trackPageView" */
- _paq.push(['trackPageView']);
- _paq.push(['enableLinkTracking']);
- (function() {
- var u="//stats.sylphide-consulting.com/matomo/";
- _paq.push(['setTrackerUrl', u+'piwik.php']);
- _paq.push(['setSiteId', '20']);
- var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
- g.type='text/javascript'; g.async=true; g.defer=true; g.src=u+'piwik.js'; s.parentNode.insertBefore(g,s);
- })();
-</script>
-<noscript><p><img src="//stats.sylphide-consulting.com/matomo/piwik.php?idsite=20&rec=1" style="border:0;" alt="" /></p></noscript>
-<!-- End Matomo Code -->
-
</body>
</html>
diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp
index 80fc8a0..f9044a2 100644
--- a/test/basicstuff.cpp
+++ b/test/basicstuff.cpp
@@ -195,41 +195,73 @@
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
}
+template<typename SrcScalar, typename TgtScalar, bool SrcIsHalfOrBF16 = (internal::is_same<SrcScalar, half>::value || internal::is_same<SrcScalar, bfloat16>::value)> struct casting_test;
+
+
template<typename SrcScalar, typename TgtScalar>
-void casting_test()
-{
- Matrix<SrcScalar,4,4> m;
- for (int i=0; i<m.rows(); ++i) {
- for (int j=0; j<m.cols(); ++j) {
- m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
+struct casting_test<SrcScalar, TgtScalar, false> {
+ static void run() {
+ Matrix<SrcScalar,4,4> m;
+ for (int i=0; i<m.rows(); ++i) {
+ for (int j=0; j<m.cols(); ++j) {
+ m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
+ }
+ }
+ Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
+ for (int i=0; i<m.rows(); ++i) {
+ for (int j=0; j<m.cols(); ++j) {
+ VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
+ }
}
}
- Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
- for (int i=0; i<m.rows(); ++i) {
- for (int j=0; j<m.cols(); ++j) {
- VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
+};
+
+template<typename SrcScalar, typename TgtScalar>
+struct casting_test<SrcScalar, TgtScalar, true> {
+ static void run() {
+ casting_test<SrcScalar, TgtScalar, false>::run();
+ }
+};
+
+template<typename SrcScalar, typename RealScalar>
+struct casting_test<SrcScalar, std::complex<RealScalar>, true> {
+ static void run() {
+ typedef std::complex<RealScalar> TgtScalar;
+ Matrix<SrcScalar,4,4> m;
+ for (int i=0; i<m.rows(); ++i) {
+ for (int j=0; j<m.cols(); ++j) {
+ m(i, j) = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
+ }
+ }
+ Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
+ for (int i=0; i<m.rows(); ++i) {
+ for (int j=0; j<m.cols(); ++j) {
+ VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(static_cast<RealScalar>(m(i, j))));
+ }
}
}
-}
+};
template<typename SrcScalar, typename EnableIf = void>
struct casting_test_runner {
static void run() {
- casting_test<SrcScalar, bool>();
- casting_test<SrcScalar, int8_t>();
- casting_test<SrcScalar, uint8_t>();
- casting_test<SrcScalar, int16_t>();
- casting_test<SrcScalar, uint16_t>();
- casting_test<SrcScalar, int32_t>();
- casting_test<SrcScalar, uint32_t>();
- casting_test<SrcScalar, int64_t>();
- casting_test<SrcScalar, uint64_t>();
- casting_test<SrcScalar, half>();
- casting_test<SrcScalar, bfloat16>();
- casting_test<SrcScalar, float>();
- casting_test<SrcScalar, double>();
- casting_test<SrcScalar, std::complex<float>>();
- casting_test<SrcScalar, std::complex<double>>();
+ casting_test<SrcScalar, bool>::run();
+ casting_test<SrcScalar, int8_t>::run();
+ casting_test<SrcScalar, uint8_t>::run();
+ casting_test<SrcScalar, int16_t>::run();
+ casting_test<SrcScalar, uint16_t>::run();
+ casting_test<SrcScalar, int32_t>::run();
+ casting_test<SrcScalar, uint32_t>::run();
+#if EIGEN_HAS_CXX11
+ casting_test<SrcScalar, int64_t>::run();
+ casting_test<SrcScalar, uint64_t>::run();
+#endif
+ casting_test<SrcScalar, half>::run();
+ casting_test<SrcScalar, bfloat16>::run();
+ casting_test<SrcScalar, float>::run();
+ casting_test<SrcScalar, double>::run();
+ casting_test<SrcScalar, std::complex<float> >::run();
+ casting_test<SrcScalar, std::complex<double> >::run();
}
};
@@ -238,10 +270,10 @@
{
static void run() {
// Only a few casts from std::complex<T> are defined.
- casting_test<SrcScalar, half>();
- casting_test<SrcScalar, bfloat16>();
- casting_test<SrcScalar, std::complex<float>>();
- casting_test<SrcScalar, std::complex<double>>();
+ casting_test<SrcScalar, half>::run();
+ casting_test<SrcScalar, bfloat16>::run();
+ casting_test<SrcScalar, std::complex<float> >::run();
+ casting_test<SrcScalar, std::complex<double> >::run();
}
};
@@ -253,14 +285,16 @@
casting_test_runner<uint16_t>::run();
casting_test_runner<int32_t>::run();
casting_test_runner<uint32_t>::run();
+#if EIGEN_HAS_CXX11
casting_test_runner<int64_t>::run();
casting_test_runner<uint64_t>::run();
+#endif
casting_test_runner<half>::run();
casting_test_runner<bfloat16>::run();
casting_test_runner<float>::run();
casting_test_runner<double>::run();
- casting_test_runner<std::complex<float>>::run();
- casting_test_runner<std::complex<double>>::run();
+ casting_test_runner<std::complex<float> >::run();
+ casting_test_runner<std::complex<double> >::run();
}
template <typename Scalar>
diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp
index eb55f7d..11fc313 100644
--- a/test/bfloat16_float.cpp
+++ b/test/bfloat16_float.cpp
@@ -31,7 +31,7 @@
void test_truncate(float input, float expected_truncation, float expected_rounding){
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
- bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input);
+ bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(input);
if ((numext::isnan)(input)){
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
@@ -41,6 +41,19 @@
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
}
+template<typename T>
+ void test_roundtrip() {
+ // Representable T round trip via bfloat16
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(-std::numeric_limits<T>::infinity())), -std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(std::numeric_limits<T>::infinity())), std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-1.0))), T(-1.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.5))), T(-0.5));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.0))), T(-0.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(1.0))), T(1.0));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.5))), T(0.5));
+ VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.0))), T(0.0));
+}
+
void test_conversion()
{
using Eigen::bfloat16_impl::__bfloat16_raw;
@@ -93,7 +106,7 @@
} else {
VERIFY_IS_EQUAL(bf_trunc.value, 0x0000);
}
- bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm);
+ bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(denorm);
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
if (std::signbit(denorm)) {
VERIFY_IS_EQUAL(bf_round.value, 0x8000);
@@ -106,14 +119,10 @@
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
// Representable floats round trip via bfloat16
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity());
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity());
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f);
- VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f);
+ test_roundtrip<float>();
+ test_roundtrip<double>();
+ test_roundtrip<std::complex<float> >();
+ test_roundtrip<std::complex<double> >();
// Truncate test
test_truncate(
diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp
index f1ac83b..c46c4c9 100644
--- a/unsupported/test/cxx11_tensor_reduction.cpp
+++ b/unsupported/test/cxx11_tensor_reduction.cpp
@@ -204,7 +204,7 @@
++count;
}
}
- VERIFY_IS_APPROX(result(i, j), sum / count);
+ VERIFY_IS_APPROX(result(i, j), sum / Scalar(count));
}
}