Update Eigen to commit:b378014fef017a829fb42c7fad15f3764bfb8ef9
CHANGELOG
=========
b378014fe - Make sure we return +/-1 above the clamping point for Erf().
e2bbf496f - Use select ternary op in tensor select evaulator
2b954be66 - fix typo in sse packetmath
25685c90a - Fix incorrect packet type for unsigned int version of pfirst() in MSVC workaround in PacketMath.h.
1e223a956 - Add missing 'f' in float literal in SpecialFunctionsImpl.h that triggers implicit conversion warning.
3f3ce214e - New BF16 pcast functions and move type casting to TypeCasting.h
17b5b4de5 - Add `Packet4ui`, `Packet8ui`, and `Packet4ul` to the `SSE`/`AVX` `PacketMath.h` headers
87300c93c - Refactor IndexedView
1148f0a9e - Add dynamic dispatch to BF16 GEMM (Power) and new VSX version
3026fc0d3 - Improve accuracy of erf().
PiperOrigin-RevId: 525284849
Change-Id: I17dc0105ec012f44f39b4b26d9b6c677a86d5bc6
diff --git a/Eigen/Core b/Eigen/Core
index 4d23920..bf0b9c7 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -220,6 +220,7 @@
#include "src/Core/arch/SSE/Complex.h"
#elif defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
#include "src/Core/arch/AltiVec/PacketMath.h"
+ #include "src/Core/arch/AltiVec/TypeCasting.h"
#include "src/Core/arch/AltiVec/MathFunctions.h"
#include "src/Core/arch/AltiVec/Complex.h"
#elif defined EIGEN_VECTORIZE_NEON
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index af4742b..e4aac9e 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -37,22 +37,32 @@
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
#endif
typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
+typedef eigen_packet_wrapper<__m256i, 4> Packet8ui;
#ifdef EIGEN_VECTORIZE_AVX2
// Start from 3 to be compatible with AVX512
typedef eigen_packet_wrapper<__m256i, 3> Packet4l;
+typedef eigen_packet_wrapper<__m256i, 5> Packet4ul;
#endif
template<> struct is_arithmetic<__m256> { enum { value = true }; };
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
template<> struct is_arithmetic<Packet8i> { enum { value = true }; };
+// Note that `Packet8ui` uses the underlying type `__m256i`, which is
+// interpreted as a vector of _signed_ `int32`s, which breaks some arithmetic
+// operations used in `GenericPacketMath.h`.
+template<> struct is_arithmetic<Packet8ui> { enum { value = false }; };
#ifndef EIGEN_VECTORIZE_AVX512FP16
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
#endif
template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
#ifdef EIGEN_VECTORIZE_AVX2
template<> struct is_arithmetic<Packet4l> { enum { value = true }; };
+// Note that `Packet4ul` uses the underlying type `__m256i`, which is
+// interpreted as a vector of _signed_ `int32`s, which breaks some arithmetic
+// operations used in `GenericPacketMath.h`.
+template<> struct is_arithmetic<Packet4ul> { enum { value = false }; };
#endif
// Use the packet_traits defined in AVX512/PacketMath.h instead if we're going
@@ -214,6 +224,25 @@
size=8
};
};
+template<> struct packet_traits<uint32_t> : default_packet_traits
+{
+ typedef Packet8ui type;
+ typedef Packet4ui half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+
+ HasDiv = 0,
+ HasNegate = 0,
+ HasSqrt = 0,
+
+ HasCmp = 1,
+ HasMin = 1,
+ HasMax = 1,
+ HasShift = 1
+ };
+};
#ifdef EIGEN_VECTORIZE_AVX2
template<> struct packet_traits<int64_t> : default_packet_traits
@@ -229,6 +258,29 @@
size=4
};
};
+template<> struct packet_traits<uint64_t> : default_packet_traits
+{
+ typedef Packet4ul type;
+ // There is no half-size packet for current Packet4ul.
+ // TODO: support as SSE path.
+ typedef Packet4ul half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+
+ // HasMin = 0,
+ // HasMax = 0,
+ HasDiv = 0,
+ HasBlend = 0,
+ HasTranspose = 0,
+ HasNegate = 0,
+ HasSqrt = 0,
+
+ HasCmp = 1,
+ HasShift = 1
+ };
+};
#endif
#endif
@@ -257,12 +309,22 @@
typedef Packet4i half;
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
+template<> struct unpacket_traits<Packet8ui> {
+ typedef uint32_t type;
+ typedef Packet4ui half;
+ enum {size = 8, alignment = Aligned32, vectorizable = true, masked_load_available = false, masked_store_available = false};
+};
#ifdef EIGEN_VECTORIZE_AVX2
template<> struct unpacket_traits<Packet4l> {
typedef int64_t type;
typedef Packet4l half;
enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
+template<> struct unpacket_traits<Packet4ul> {
+ typedef uint64_t type;
+ typedef Packet4ul half;
+ enum {size = 4, alignment = Aligned32, vectorizable = true, masked_load_available = false, masked_store_available = false};
+};
#endif
template<> struct unpacket_traits<Packet8bf> {
typedef bfloat16 type;
@@ -283,30 +345,58 @@
return _mm256_set1_epi64x(from);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pset1<Packet4ul>(const uint64_t& from) {
+ return _mm256_set1_epi64x(numext::bit_cast<uint64_t>(from));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pzero(const Packet4l& /*a*/) {
return _mm256_setzero_si256();
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pzero(const Packet4ul& /*a*/) {
+ return _mm256_setzero_si256();
+}
+template <>
EIGEN_STRONG_INLINE Packet4l peven_mask(const Packet4l& /*a*/) {
return _mm256_set_epi64x(0ll, -1ll, 0ll, -1ll);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul peven_mask(const Packet4ul& /*a*/) {
+ return _mm256_set_epi64x(0ll, -1ll, 0ll, -1ll);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pload1<Packet4l>(const int64_t* from) {
return _mm256_set1_epi64x(*from);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pload1<Packet4ul>(const uint64_t* from) {
+ return _mm256_set1_epi64x(*from);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l padd<Packet4l>(const Packet4l& a, const Packet4l& b) {
return _mm256_add_epi64(a, b);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul padd<Packet4ul>(const Packet4ul& a, const Packet4ul& b) {
+ return _mm256_add_epi64(a, b);
+}
+template<>
EIGEN_STRONG_INLINE Packet4l plset<Packet4l>(const int64_t& a) {
return padd(pset1<Packet4l>(a), Packet4l(_mm256_set_epi64x(3ll, 2ll, 1ll, 0ll)));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul plset<Packet4ul>(const uint64_t& a) {
+ return padd(pset1<Packet4ul>(a), Packet4ul(_mm256_set_epi64x(3ll, 2ll, 1ll, 0ll)));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l psub<Packet4l>(const Packet4l& a, const Packet4l& b) {
return _mm256_sub_epi64(a, b);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul psub<Packet4ul>(const Packet4ul& a, const Packet4ul& b) {
+ return _mm256_sub_epi64(a, b);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pnegate(const Packet4l& a) {
return psub(pzero(a), a);
}
@@ -319,18 +409,36 @@
return _mm256_xor_si256(_mm256_cmpgt_epi64(a, b), _mm256_set1_epi32(-1));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pcmp_le(const Packet4ul& a, const Packet4ul& b) {
+ return (Packet4ul)pcmp_le((Packet4l)psub(a, pset1<Packet4ul>(0x8000000000000000UL)),
+ (Packet4l)psub(b, pset1<Packet4ul>(0x8000000000000000UL)));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pcmp_lt(const Packet4l& a, const Packet4l& b) {
return _mm256_cmpgt_epi64(b, a);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pcmp_lt(const Packet4ul& a, const Packet4ul& b) {
+ return (Packet4ul)pcmp_lt((Packet4l)psub(a, pset1<Packet4ul>(0x8000000000000000UL)),
+ (Packet4l)psub(b, pset1<Packet4ul>(0x8000000000000000UL)));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pcmp_eq(const Packet4l& a, const Packet4l& b) {
return _mm256_cmpeq_epi64(a, b);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pcmp_eq(const Packet4ul& a, const Packet4ul& b) {
+ return _mm256_cmpeq_epi64(a, b);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l ptrue<Packet4l>(const Packet4l& a) {
return _mm256_cmpeq_epi64(a, a);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul ptrue<Packet4ul>(const Packet4ul& a) {
+ return _mm256_cmpeq_epi64(a, a);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pand<Packet4l>(const Packet4l& a, const Packet4l& b) {
return _mm256_and_si256(a, b);
}
@@ -343,6 +451,10 @@
return _mm256_xor_si256(a, b);
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pxor<Packet4ul>(const Packet4ul& a, const Packet4ul& b) {
+ return _mm256_xor_si256(a, b);
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pandnot<Packet4l>(const Packet4l& a, const Packet4l& b) {
return _mm256_andnot_si256(b, a);
}
@@ -388,28 +500,54 @@
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pload<Packet4ul>(const uint64_t* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l ploadu<Packet4l>(const int64_t* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}
+template <>
+EIGEN_STRONG_INLINE Packet4ul ploadu<Packet4ul>(const uint64_t* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+}
// Loads 2 int64_ts from memory a returns the packet {a0, a0, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet4l ploaddup<Packet4l>(const int64_t* from) {
const Packet4l a = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i*>(from)));
return _mm256_permutevar8x32_epi32(a, _mm256_setr_epi32(0, 1, 0, 1, 2, 3, 2, 3));
}
+// Loads 2 uint64_ts from memory a returns the packet {a0, a0, a1, a1}
template <>
+EIGEN_STRONG_INLINE Packet4ul ploaddup<Packet4ul>(const uint64_t* from) {
+ const Packet4ul a = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i*>(from)));
+ return _mm256_permutevar8x32_epi32(a, _mm256_setr_epi32(0, 1, 0, 1, 2, 3, 2, 3));
+}
+template<>
EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet4l& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
+EIGEN_STRONG_INLINE void pstore<uint64_t>(uint64_t* to, const Packet4ul& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
+}
+template <>
EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet4l& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
+EIGEN_STRONG_INLINE void pstoreu<uint64_t>(uint64_t* to, const Packet4ul& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
+}
+template <>
EIGEN_DEVICE_FUNC inline Packet4l pgather<int64_t, Packet4l>(const int64_t* from, Index stride) {
return _mm256_set_epi64x(from[3 * stride], from[2 * stride], from[1 * stride], from[0 * stride]);
}
template <>
+EIGEN_DEVICE_FUNC inline Packet4ul pgather<uint64_t, Packet4ul>(const uint64_t* from, Index stride) {
+ return _mm256_set_epi64x(from[3 * stride], from[2 * stride], from[1 * stride], from[0 * stride]);
+}
+template <>
EIGEN_DEVICE_FUNC inline void pscatter<int64_t, Packet4l>(int64_t* to, const Packet4l& from, Index stride) {
__m128i low = _mm256_extractf128_si256(from, 0);
to[stride * 0] = _mm_extract_epi64(low, 0);
@@ -420,19 +558,43 @@
to[stride * 3] = _mm_extract_epi64(high, 1);
}
template <>
+EIGEN_DEVICE_FUNC inline void pscatter<uint64_t, Packet4ul>(uint64_t* to, const Packet4ul& from, Index stride) {
+ __m128i low = _mm256_extractf128_si256(from, 0);
+ to[stride * 0] = _mm_extract_epi64(low, 0);
+ to[stride * 1] = _mm_extract_epi64(low, 1);
+
+ __m128i high = _mm256_extractf128_si256(from, 1);
+ to[stride * 2] = _mm_extract_epi64(high, 0);
+ to[stride * 3] = _mm_extract_epi64(high, 1);
+}
+template <>
EIGEN_STRONG_INLINE void pstore1<Packet4l>(int64_t* to, const int64_t& a) {
Packet4l pa = pset1<Packet4l>(a);
pstore(to, pa);
}
template <>
+EIGEN_STRONG_INLINE void pstore1<Packet4ul>(uint64_t* to, const uint64_t& a) {
+ Packet4ul pa = pset1<Packet4ul>(a);
+ pstore(to, pa);
+}
+template<>
EIGEN_STRONG_INLINE int64_t pfirst<Packet4l>(const Packet4l& a) {
return _mm_cvtsi128_si64(_mm256_castsi256_si128(a));
}
template <>
+EIGEN_STRONG_INLINE uint64_t pfirst<Packet4ul>(const Packet4ul& a) {
+ return _mm_cvtsi128_si64(_mm256_castsi256_si128(a));
+}
+template <>
EIGEN_STRONG_INLINE int64_t predux<Packet4l>(const Packet4l& a) {
__m128i r = _mm_add_epi64(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
return _mm_extract_epi64(r, 0) + _mm_extract_epi64(r, 1);
}
+template <>
+EIGEN_STRONG_INLINE uint64_t predux<Packet4ul>(const Packet4ul& a) {
+ __m128i r = _mm_add_epi64(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ return numext::bit_cast<uint64_t>(_mm_extract_epi64(r, 0) + _mm_extract_epi64(r, 1));
+}
#define MM256_SHUFFLE_EPI64(A, B, M) _mm256_shuffle_pd(_mm256_castsi256_pd(A), _mm256_castsi256_pd(B), M)
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4l, 4>& kernel) {
__m256d T0 = MM256_SHUFFLE_EPI64(kernel.packet[0], kernel.packet[1], 15);
@@ -445,6 +607,9 @@
kernel.packet[0] = _mm256_castpd_si256(_mm256_permute2f128_pd(T1, T3, 32));
kernel.packet[2] = _mm256_castpd_si256(_mm256_permute2f128_pd(T1, T3, 49));
}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4ul, 4>& kernel) {
+ ptranspose((PacketBlock<Packet4l, 4>&)kernel);
+}
template <>
EIGEN_STRONG_INLINE Packet4l pmin<Packet4l>(const Packet4l& a, const Packet4l& b) {
__m256i cmp = _mm256_cmpgt_epi64(a, b);
@@ -453,6 +618,12 @@
return Packet4l(_mm256_or_si256(a_min, b_min));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pmin<Packet4ul>(const Packet4ul& a, const Packet4ul& b) {
+ return padd((Packet4ul)pmin((Packet4l)psub(a, pset1<Packet4ul>(0x8000000000000000UL)),
+ (Packet4l)psub(b, pset1<Packet4ul>(0x8000000000000000UL))),
+ pset1<Packet4ul>(0x8000000000000000UL));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pmax<Packet4l>(const Packet4l& a, const Packet4l& b) {
__m256i cmp = _mm256_cmpgt_epi64(a, b);
__m256i a_min = _mm256_and_si256(cmp, a);
@@ -460,12 +631,22 @@
return Packet4l(_mm256_or_si256(a_min, b_min));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pmax<Packet4ul>(const Packet4ul& a, const Packet4ul& b) {
+ return padd((Packet4ul)pmax((Packet4l)psub(a, pset1<Packet4ul>(0x8000000000000000UL)),
+ (Packet4l)psub(b, pset1<Packet4ul>(0x8000000000000000UL))),
+ pset1<Packet4ul>(0x8000000000000000UL));
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pabs<Packet4l>(const Packet4l& a) {
Packet4l pz = pzero<Packet4l>(a);
Packet4l cmp = _mm256_cmpgt_epi64(a, pz);
return psub(cmp, pxor(a, cmp));
}
template <>
+EIGEN_STRONG_INLINE Packet4ul pabs<Packet4ul>(const Packet4ul& a) {
+ return a;
+}
+template <>
EIGEN_STRONG_INLINE Packet4l pmul<Packet4l>(const Packet4l& a, const Packet4l& b) {
// 64-bit mul requires avx512, so do this with 32-bit multiplication
__m256i upper32_a = _mm256_srli_epi64(a, 32);
@@ -485,6 +666,7 @@
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); }
+template<> EIGEN_STRONG_INLINE Packet8ui pset1<Packet8ui>(const uint32_t& from) { return _mm256_set1_epi32(from); }
template<> EIGEN_STRONG_INLINE Packet8f pset1frombits<Packet8f>(unsigned int from) { return _mm256_castsi256_ps(pset1<Packet8i>(from)); }
template<> EIGEN_STRONG_INLINE Packet4d pset1frombits<Packet4d>(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); }
@@ -492,10 +674,12 @@
template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); }
template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); }
template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); }
+template<> EIGEN_STRONG_INLINE Packet8ui pzero(const Packet8ui& /*a*/) { return _mm256_setzero_si256(); }
template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return _mm256_castsi256_ps(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet8ui peven_mask(const Packet8ui& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); }
template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return _mm256_castsi256_pd(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); }
template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { return _mm256_broadcast_ss(from); }
@@ -522,10 +706,21 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui padd<Packet8ui>(const Packet8ui& a, const Packet8ui& b)
+{
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_add_epi32(a, b);
+#else
+ __m128i lo = _mm_add_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_add_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f plset<Packet8f>(const float& a) { return padd(pset1<Packet8f>(a), _mm256_set_ps(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)); }
template<> EIGEN_STRONG_INLINE Packet4d plset<Packet4d>(const double& a) { return padd(pset1<Packet4d>(a), _mm256_set_pd(3.0,2.0,1.0,0.0)); }
template<> EIGEN_STRONG_INLINE Packet8i plset<Packet8i>(const int& a) { return padd(pset1<Packet8i>(a), (Packet8i)_mm256_set_epi32(7,6,5,4,3,2,1,0)); }
+template<> EIGEN_STRONG_INLINE Packet8ui plset<Packet8ui>(const uint32_t& a) { return padd(pset1<Packet8ui>(a), (Packet8ui)_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)); }
template<> EIGEN_STRONG_INLINE Packet8f psub<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d psub<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); }
@@ -538,6 +733,16 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui psub<Packet8ui>(const Packet8ui& a, const Packet8ui& b)
+{
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_sub_epi32(a, b);
+#else
+ __m128i lo = _mm_sub_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_sub_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pnegate(const Packet8f& a)
{
@@ -569,6 +774,16 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pmul<Packet8ui>(const Packet8ui& a, const Packet8ui& b)
+{
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_mullo_epi32(a, b);
+#else
+ const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); }
@@ -577,7 +792,7 @@
{
#ifdef EIGEN_VECTORIZE_AVX512
return _mm512_cvttpd_epi32(_mm512_div_pd(_mm512_cvtepi32_pd(a), _mm512_cvtepi32_pd(b)));
-#else
+#else
Packet4i lo = pdiv<Packet4i>(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
Packet4i hi = pdiv<Packet4i>(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
@@ -666,6 +881,15 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pcmp_eq(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_cmpeq_epi32(a, b);
+#else
+ __m128i lo = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
@@ -701,6 +925,15 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pmin<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_min_epu32(a, b);
+#else
+ __m128i lo = _mm_min_epu32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_min_epu32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
@@ -733,6 +966,15 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pmax<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_max_epu32(a, b);
+#else
+ __m128i lo = _mm_max_epu32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0));
+ __m128i hi = _mm_max_epu32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1));
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
@@ -823,6 +1065,13 @@
return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pand<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_and_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f por<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_or_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d por<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_or_pd(a,b); }
@@ -833,6 +1082,13 @@
return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui por<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_or_si256(a,b);
+#else
+ return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pxor<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_xor_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4d pxor<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_xor_pd(a,b); }
@@ -843,6 +1099,13 @@
return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b)));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pxor<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_xor_si256(a, b);
+#else
+ return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b)));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(b,a); }
template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(b,a); }
@@ -853,6 +1116,20 @@
return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a)));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pandnot<Packet8ui>(const Packet8ui& a, const Packet8ui& b) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_andnot_si256(b,a);
+#else
+ return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a)));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8ui pcmp_lt(const Packet8ui& a, const Packet8ui& b) {
+ return pxor(pcmp_eq(a, pmax(a, b)), ptrue(a));
+}
+template<> EIGEN_STRONG_INLINE Packet8ui pcmp_le(const Packet8ui& a, const Packet8ui& b) {
+ return pcmp_eq(a, pmin(a, b));
+}
template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a)
{
@@ -871,6 +1148,8 @@
{ return _mm256_blendv_ps(b,a,mask); }
template<> EIGEN_STRONG_INLINE Packet8i pselect<Packet8i>(const Packet8i& mask, const Packet8i& a, const Packet8i& b)
{ return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(b), _mm256_castsi256_ps(a), _mm256_castsi256_ps(mask))); }
+template<> EIGEN_STRONG_INLINE Packet8ui pselect<Packet8ui>(const Packet8ui& mask, const Packet8ui& a, const Packet8ui& b)
+{ return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(b), _mm256_castsi256_ps(a), _mm256_castsi256_ps(mask))); }
template<> EIGEN_STRONG_INLINE Packet4d pselect<Packet4d>(const Packet4d& mask, const Packet4d& a, const Packet4d& b)
{ return _mm256_blendv_pd(b,a,mask); }
@@ -905,13 +1184,25 @@
#endif
}
+template<int N> EIGEN_STRONG_INLINE Packet8ui parithmetic_shift_right(Packet8ui a) {
+ return (Packet8ui)plogical_shift_right<N>((Packet8i)a);
+}
+template<int N> EIGEN_STRONG_INLINE Packet8ui plogical_shift_right(Packet8ui a) {
+ return (Packet8ui)plogical_shift_right<N>((Packet8i)a);
+}
+template<int N> EIGEN_STRONG_INLINE Packet8ui plogical_shift_left(Packet8ui a) {
+ return (Packet8ui)plogical_shift_left<N>((Packet8i)a);
+}
+
template<> EIGEN_STRONG_INLINE Packet8f pload<Packet8f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload<Packet4d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i pload<Packet8i>(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet8ui pload<Packet8ui>(const uint32_t* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from)); }
template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); }
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet8ui ploadu<Packet8ui>(const uint32_t* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
@@ -930,7 +1221,7 @@
template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from)
{
// TODO try to find a way to avoid the need of a temporary register
-// Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from));
+ // Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from));
// tmp = _mm256_insertf128_ps(tmp, _mm_movehl_ps(_mm256_castps256_ps128(tmp),_mm256_castps256_ps128(tmp)), 1);
// return _mm256_unpacklo_ps(tmp,tmp);
@@ -961,6 +1252,20 @@
return _mm256_castps_si256(_mm256_permute_ps(tmp, _MM_SHUFFLE(3,3,2,2)));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui ploaddup<Packet8ui>(const uint32_t* from) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ const Packet8ui a = _mm256_castsi128_si256(ploadu<Packet4ui>(from));
+ return _mm256_permutevar8x32_epi32(a, _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3));
+#else
+ __m256 tmp = _mm256_broadcast_ps((const __m128*)(const void*)from);
+ // mimic an "inplace" permutation of the lower 128bits using a blend
+ tmp = _mm256_blend_ps(
+ tmp, _mm256_castps128_ps256(_mm_permute_ps(_mm256_castps256_ps128(tmp), _MM_SHUFFLE(1, 0, 1, 0))), 15);
+ // then we can perform a consistent permutation on the global register to get
+ // everything in shape:
+ return _mm256_castps_si256(_mm256_permute_ps(tmp, _MM_SHUFFLE(3, 3, 2, 2)));
+#endif
+}
// Loads 2 floats from memory a returns the packet {a0, a0 a0, a0, a1, a1, a1, a1}
template<> EIGEN_STRONG_INLINE Packet8f ploadquad<Packet8f>(const float* from)
@@ -972,14 +1277,19 @@
{
return _mm256_insertf128_si256(_mm256_set1_epi32(*from), _mm_set1_epi32(*(from+1)), 1);
}
+template<> EIGEN_STRONG_INLINE Packet8ui ploadquad<Packet8ui>(const uint32_t* from) {
+ return _mm256_insertf128_si256(_mm256_set1_epi32(*from), _mm_set1_epi32(*(from + 1)), 1);
+}
template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet8f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet8ui& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet8ui& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
@@ -1017,6 +1327,9 @@
return _mm256_set_epi32(from[7*stride], from[6*stride], from[5*stride], from[4*stride],
from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
}
+template<> EIGEN_DEVICE_FUNC inline Packet8ui pgather<uint32_t, Packet8ui>(const uint32_t* from, Index stride) {
+ return (Packet8ui)pgather<int, Packet8i>((int*)from, stride);
+}
template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet8f>(float* to, const Packet8f& from, Index stride)
{
@@ -1055,6 +1368,9 @@
to[stride*6] = _mm_extract_epi32(high, 2);
to[stride*7] = _mm_extract_epi32(high, 3);
}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<uint32_t, Packet8ui>(uint32_t* to, const Packet8ui& from, Index stride) {
+ pscatter<int, Packet8i>((int*)to, (Packet8i)from, stride);
+}
template<> EIGEN_STRONG_INLINE void pstore1<Packet8f>(float* to, const float& a)
{
@@ -1076,6 +1392,7 @@
template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint32_t>(const uint32_t* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
#endif
template<> EIGEN_STRONG_INLINE float pfirst<Packet8f>(const Packet8f& a) {
@@ -1087,6 +1404,9 @@
template<> EIGEN_STRONG_INLINE int pfirst<Packet8i>(const Packet8i& a) {
return _mm_cvtsi128_si32(_mm256_castsi256_si128(a));
}
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet8ui>(const Packet8ui& a) {
+ return numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(_mm256_castsi256_si128(a)));
+}
template<> EIGEN_STRONG_INLINE Packet8f preverse(const Packet8f& a)
@@ -1098,21 +1418,27 @@
{
__m256d tmp = _mm256_shuffle_pd(a,a,5);
return _mm256_permute2f128_pd(tmp, tmp, 1);
- #if 0
+#if 0
// This version is unlikely to be faster as _mm256_shuffle_ps and _mm256_permute_pd
// exhibit the same latency/throughput, but it is here for future reference/benchmarking...
__m256d swap_halves = _mm256_permute2f128_pd(a,a,1);
return _mm256_permute_pd(swap_halves,5);
- #endif
+#endif
}
template<> EIGEN_STRONG_INLINE Packet8i preverse(const Packet8i& a)
{
return _mm256_castps_si256(preverse(_mm256_castsi256_ps(a)));
}
+template<> EIGEN_STRONG_INLINE Packet8ui preverse(const Packet8ui& a) {
+ return _mm256_castps_si256(preverse(_mm256_castsi256_ps(a)));
+}
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet4l preverse(const Packet4l& a)
-{
+ {
+ return _mm256_castpd_si256(preverse(_mm256_castsi256_pd(a)));
+}
+template<> EIGEN_STRONG_INLINE Packet4ul preverse(const Packet4ul& a) {
return _mm256_castpd_si256(preverse(_mm256_castsi256_pd(a)));
}
#endif
@@ -1138,12 +1464,15 @@
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8ui pabs(const Packet8ui& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return _mm_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return _mm_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet8f psignbit(const Packet8f& a) { return _mm256_castsi256_ps(parithmetic_shift_right<31>((Packet8i)_mm256_castps_si256(a))); }
+template<> EIGEN_STRONG_INLINE Packet8ui psignbit(const Packet8ui& a) { return pzero(a); }
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet4d psignbit(const Packet4d& a) { return _mm256_castsi256_pd(parithmetic_shift_right<63>((Packet4l)_mm256_castpd_si256(a))); }
+template<> EIGEN_STRONG_INLINE Packet4ul psignbit(const Packet4ul& a) { return pzero(a); }
#endif
template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
@@ -1186,18 +1515,18 @@
// Clamp exponent to [-2099, 2099]
const Packet4d max_exponent = pset1<Packet4d>(2099.0);
const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
-
+
// Split 2^e into four factors and multiply.
const Packet4i bias = pset1<Packet4i>(1023);
Packet4i b = parithmetic_shift_right<2>(e); // floor(e/4)
-
+
// 2^b
Packet4i hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
Packet4i lo = _mm_slli_epi64(hi, 52);
hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52);
Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1));
Packet4d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
-
+
// 2^(e - 3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3);
@@ -1220,6 +1549,9 @@
{
return predux(Packet4i(_mm_add_epi32(_mm256_castsi256_si128(a),_mm256_extractf128_si256(a,1))));
}
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet8ui>(const Packet8ui& a) {
+ return predux(Packet4ui(_mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1))));
+}
template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4<Packet8f>(const Packet8f& a)
{
@@ -1229,6 +1561,9 @@
{
return _mm_add_epi32(_mm256_castsi256_si128(a),_mm256_extractf128_si256(a,1));
}
+template<> EIGEN_STRONG_INLINE Packet4ui predux_half_dowto4<Packet8ui>(const Packet8ui& a) {
+ return _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+}
template<> EIGEN_STRONG_INLINE float predux_mul<Packet8f>(const Packet8f& a)
{
@@ -1284,6 +1619,10 @@
{
return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0;
}
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& x)
+{
+ return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8f,8>& kernel) {
@@ -1372,6 +1711,9 @@
kernel.packet[6] = _mm256_permute2f128_si256(S2, S6, 0x31);
kernel.packet[7] = _mm256_permute2f128_si256(S3, S7, 0x31);
}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8ui, 8>& kernel) {
+ ptranspose((PacketBlock<Packet8i, 8>&)kernel);
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8i,4>& kernel) {
@@ -1390,6 +1732,9 @@
kernel.packet[2] = _mm256_permute2f128_si256(S0, S1, 0x31);
kernel.packet[3] = _mm256_permute2f128_si256(S2, S3, 0x31);
}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8ui, 4>& kernel) {
+ ptranspose((PacketBlock<Packet8i, 4>&)kernel);
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4d,4>& kernel) {
@@ -1405,7 +1750,7 @@
}
template<> EIGEN_STRONG_INLINE Packet8f pblend(const Selector<8>& ifPacket, const Packet8f& thenPacket, const Packet8f& elsePacket) {
-#ifdef EIGEN_VECTORIZE_AVX2
+#ifdef EIGEN_VECTORIZE_AVX2
const __m256i zero = _mm256_setzero_si256();
const __m256i select = _mm256_set_epi32(ifPacket.select[7], ifPacket.select[6], ifPacket.select[5], ifPacket.select[4], ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
__m256i false_mask = _mm256_cmpeq_epi32(zero, select);
@@ -1419,7 +1764,7 @@
}
template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, const Packet4d& thenPacket, const Packet4d& elsePacket) {
-#ifdef EIGEN_VECTORIZE_AVX2
+#ifdef EIGEN_VECTORIZE_AVX2
const __m256i zero = _mm256_setzero_si256();
const __m256i select = _mm256_set_epi64x(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
__m256i false_mask = _mm256_cmpeq_epi64(select, zero);
@@ -1478,7 +1823,7 @@
}
template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) {
- return _mm_cmpeq_epi32(a, a);
+ return _mm_cmpeq_epi32(a, a);
}
template <>
@@ -1850,7 +2195,7 @@
}
template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
- return _mm_cmpeq_epi32(a, a);
+ return _mm_cmpeq_epi32(a, a);
}
template <>
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 5027060..f4e7e8d 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -841,7 +841,6 @@
}
};
-#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
// General template for lhs packing, bfloat16 specialization.
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
@@ -900,42 +899,60 @@
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);
- Packet2ul v1[8], v2[8];
+ Packet4ui v1[8], v2[8];
- v1[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
- v1[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
- v1[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
- v1[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
- v1[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
- v1[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
- v1[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
- v1[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
- v2[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val)));
- v2[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val)));
- v2[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val)));
- v2[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val)));
- v2[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val)));
- v2[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val)));
- v2[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val)));
- v2[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val)));
+ v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
+ v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
+ v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
+ v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
+ v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
+ v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
+ v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
+ v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
+ v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
+ v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
+ v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
+ v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
+ v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
+ v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
+ v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
+ v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
- block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v1[0],v1[2]));
- block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v1[0],v1[2]));
- block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v1[1],v1[3]));
- block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v1[1],v1[3]));
- block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v1[4],v1[6]));
- block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v1[4],v1[6]));
- block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v1[5],v1[7]));
- block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v1[5],v1[7]));
- block2.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v2[0],v2[2]));
- block2.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v2[0],v2[2]));
- block2.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v2[1],v2[3]));
- block2.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v2[1],v2[3]));
- block2.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v2[4],v2[6]));
- block2.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v2[4],v2[6]));
- block2.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v2[5],v2[7]));
- block2.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v2[5],v2[7]));
-
+#ifdef EIGEN_VECTORIZE_VSX
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]),reinterpret_cast<Packet2ul>(v1[2])));
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[0]),reinterpret_cast<Packet2ul>(v1[2])));
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]),reinterpret_cast<Packet2ul>(v1[3])));
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[1]),reinterpret_cast<Packet2ul>(v1[3])));
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]),reinterpret_cast<Packet2ul>(v1[6])));
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[4]),reinterpret_cast<Packet2ul>(v1[6])));
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]),reinterpret_cast<Packet2ul>(v1[7])));
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[5]),reinterpret_cast<Packet2ul>(v1[7])));
+ block2.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v2[0]),reinterpret_cast<Packet2ul>(v2[2])));
+ block2.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v2[0]),reinterpret_cast<Packet2ul>(v2[2])));
+ block2.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v2[1]),reinterpret_cast<Packet2ul>(v2[3])));
+ block2.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v2[1]),reinterpret_cast<Packet2ul>(v2[3])));
+ block2.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v2[4]),reinterpret_cast<Packet2ul>(v2[6])));
+ block2.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v2[4]),reinterpret_cast<Packet2ul>(v2[6])));
+ block2.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v2[5]),reinterpret_cast<Packet2ul>(v2[7])));
+ block2.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v2[5]),reinterpret_cast<Packet2ul>(v2[7])));
+#else
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_HI));
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_LO));
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_HI));
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_LO));
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_HI));
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_LO));
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_HI));
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_LO));
+ block2.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v2[0],v2[2],p16uc_TRANSPOSE64_HI));
+ block2.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v2[0],v2[2],p16uc_TRANSPOSE64_LO));
+ block2.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v2[1],v2[3],p16uc_TRANSPOSE64_HI));
+ block2.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v2[1],v2[3],p16uc_TRANSPOSE64_LO));
+ block2.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v2[4],v2[6],p16uc_TRANSPOSE64_HI));
+ block2.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v2[4],v2[6],p16uc_TRANSPOSE64_LO));
+ block2.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v2[5],v2[7],p16uc_TRANSPOSE64_HI));
+ block2.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v2[5],v2[7],p16uc_TRANSPOSE64_LO));
+#endif
for(Index M = 0; M < 8; M+=2) {
pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2*vectorSize * M), block1.packet[M+0]);
@@ -1005,26 +1022,37 @@
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
- Packet2ul v1[8];
+ Packet4ui v1[8];
// This is transposing and interleaving data
- v1[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
- v1[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
- v1[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
- v1[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
- v1[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
- v1[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
- v1[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
- v1[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
+ v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
+ v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
+ v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
+ v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
+ v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
+ v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
+ v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
+ v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
- block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v1[0],v1[2]));
- block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v1[0],v1[2]));
- block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v1[1],v1[3]));
- block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v1[1],v1[3]));
- block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v1[4],v1[6]));
- block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v1[4],v1[6]));
- block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v1[5],v1[7]));
- block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v1[5],v1[7]));
+#ifdef EIGEN_VECTORIZE_VSX
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]),reinterpret_cast<Packet2ul>(v1[2])));
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[0]),reinterpret_cast<Packet2ul>(v1[2])));
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]),reinterpret_cast<Packet2ul>(v1[3])));
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[1]),reinterpret_cast<Packet2ul>(v1[3])));
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]),reinterpret_cast<Packet2ul>(v1[6])));
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[4]),reinterpret_cast<Packet2ul>(v1[6])));
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]),reinterpret_cast<Packet2ul>(v1[7])));
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(v1[5]),reinterpret_cast<Packet2ul>(v1[7])));
+#else
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_HI));
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_LO));
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_HI));
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_LO));
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_HI));
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_LO));
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_HI));
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_LO));
+#endif
for(Index M = 0; M < 8; M++) {
pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
@@ -1157,16 +1185,24 @@
bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);
- Packet2ul t0, t1, t2, t3;
- t0 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
- t1 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
- t2 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
- t3 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
+ Packet4ui t0, t1, t2, t3;
- block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(t0, t1));
- block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(t0, t1));
- block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(t2, t3));
- block.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(t2, t3));
+ t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val));
+ t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val));
+ t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val));
+ t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val));
+
+#ifdef EIGEN_VECTORIZE_VSX
+ block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t0),reinterpret_cast<Packet2ul>(t2)));
+ block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t0),reinterpret_cast<Packet2ul>(t2)));
+ block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t1),reinterpret_cast<Packet2ul>(t3)));
+ block.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t1),reinterpret_cast<Packet2ul>(t3)));
+#else
+ block.packet[0] = reinterpret_cast<Packet8us>(vec_perm(t0,t2,p16uc_TRANSPOSE64_HI));
+ block.packet[1] = reinterpret_cast<Packet8us>(vec_perm(t0,t2,p16uc_TRANSPOSE64_LO));
+ block.packet[2] = reinterpret_cast<Packet8us>(vec_perm(t1,t3,p16uc_TRANSPOSE64_HI));
+ block.packet[3] = reinterpret_cast<Packet8us>(vec_perm(t1,t3,p16uc_TRANSPOSE64_LO));
+#endif
storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
} else {
@@ -1254,7 +1290,6 @@
}
}
};
-#endif
// General template for lhs complex packing, float64 specialization.
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
@@ -2674,8 +2709,596 @@
#undef advanceCols
#undef advanceRows
+EIGEN_ALWAYS_INLINE bool supportsMMA()
+{
+#if defined(EIGEN_ALTIVEC_MMA_ONLY)
+ return true;
+#else
+#if EIGEN_COMP_LLVM
+ return false; // No dynamic dispatch for LLVM
+#else
+ return __builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma");
+#endif
+#endif
+}
+
+EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
+{
+ Packet4f result_block = ploadu<Packet4f>(result);
+ return pmadd(acc, pAlpha, result_block);
+}
+
+template<bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
+{
+ if (lhsExtraRows) {
+ pstoreu_partial(result, result_block, extra_rows);
+ } else {
+ pstoreu(result, result_block);
+ }
+ result += rows;
+}
+
+template<bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
+{
+ Index x = 0;
+ if (rhsExtraCols) {
+ do{
+ Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
+ storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
+ } while (++x < extra_cols);
+ } else {
+ Packet4f result_block[4];
+ float *result2 = result;
+ do{
+ result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
+ result += rows;
+ } while (++x < 4);
+ x = 0;
+ do{
+ storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
+ } while (++x < 4);
+ }
+}
+
+EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data)
+{
+ Packet8us z = pset1<Packet8us>(0);
+#ifdef _BIG_ENDIAN
+ return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
+#else
+ return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
+#endif
+}
+
+EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data)
+{
+ Packet8us z = pset1<Packet8us>(0);
+#ifdef _BIG_ENDIAN
+ return reinterpret_cast<Packet4f>(vec_mergel(data, z));
+#else
+ return reinterpret_cast<Packet4f>(vec_mergel(z, data));
+#endif
+}
+
+template<Index N, Index M>
+EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf,(N+7)/8>& block, Index extra = 0)
+{
+ if (N < 4) {
+ pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra);
+ } else if (N >= (M*8+4)) {
+ pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
+ if (N >= 8) {
+ pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
+ }
+ }
+}
+
+template<Index N>
+EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+7)/8>& block, Index extra)
+{
+ storeConvertTwoBF16<N, 0>(to + 0, block, extra);
+ if (N >= 16) {
+ storeConvertTwoBF16<N, 1>(to + 8, block);
+ }
+ if (N >= 32) {
+ storeConvertTwoBF16<N, 2>(to + 16, block);
+ storeConvertTwoBF16<N, 3>(to + 24, block);
+ }
+}
+
+template<bool non_unit_stride, Index delta>
+EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
+{
+ if (non_unit_stride) {
+ return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
+ } else {
+ return ploadu<Packet8bf>(src + delta);
+ }
+}
+
+static Packet16uc p16uc_MERGE16_32_1 = { 0, 1, 16,17, 2, 3, 18,19, 0, 1, 16,17, 2, 3, 18,19 };
+static Packet16uc p16uc_MERGE16_32_2 = { 4, 5, 20,21, 6, 7, 22,23, 4, 5, 20,21, 6, 7, 22,23 };
+static Packet16uc p16uc_MERGE16_32_3 = { 8, 9, 24,25, 10,11, 26,27, 8, 9, 24,25, 10,11, 26,27 };
+static Packet16uc p16uc_MERGE16_32_4 = { 12,13, 28,29, 14,15, 30,31, 12,13, 28,29, 14,15, 30,31 };
+
+static Packet16uc p16uc_MERGE16_32_5 = { 0,1, 16,17, 16,17, 16,17, 0,1, 16,17, 16,17, 16,17 };
+static Packet16uc p16uc_MERGE16_32_6 = { 2,3, 18,19, 18,19, 18,19, 2,3, 18,19, 18,19, 18,19 };
+static Packet16uc p16uc_MERGE16_32_7 = { 4,5, 20,21, 20,21, 20,21, 4,5, 20,21, 20,21, 20,21 };
+static Packet16uc p16uc_MERGE16_32_8 = { 6,7, 22,23, 22,23, 22,23, 6,7, 22,23, 22,23, 22,23 };
+
+EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
+{
+ Packet8us z = pset1<Packet8us>(0);
+#ifdef _BIG_ENDIAN
+ return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
+#else
+ return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
+#endif
+}
+
+template<bool lhsExtraRows, bool odd, Index size>
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index col, Index rows, const bfloat16* src, Index extra_rows)
+{
+ Packet4f dup[4*4];
+ Packet8bf data[4];
+
+ for (Index i = 0; i < size; i++) {
+ data[i] = ploadu<Packet8bf>(src + col + rows*i);
+ }
+
+ for (Index i = 0, j = 0; i < size; i++, j += 4) {
+ dup[j+0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1);
+ dup[j+1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2);
+ dup[j+2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3);
+ dup[j+3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4);
+ }
+
+ for (Index j = 0; j < 4*size; j += 4) {
+ if (lhsExtraRows) {
+ Packet4f z = pset1<Packet4f>(float(0));
+ Index i = 0;
+ do {
+ pstoreu(result + (j+i)*4, dup[j+i]);
+ } while (++i < extra_rows);
+ do {
+ pstoreu(result + (j+i)*4, z);
+ } while (++i < 4);
+ } else {
+ for (Index i = 0; i < 4; i++) {
+ pstoreu(result + (j+i)*4, dup[j+i]);
+ }
+ }
+ }
+}
+
+template<bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
+{
+ Index col2 = 0, col = 0;
+ for(; col + 4*2 <= cols; col += 4*2, col2 += 4*rows, result += 4*4*4) {
+ convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, col2 + delta*2, rows, src, extra_rows);
+ }
+ for(; col + 2 <= cols; col += 2, col2 += rows, result += 4*4) {
+ convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, col2 + delta*2, rows, src, extra_rows);
+ }
+ if (cols & 1) {
+ convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, col2 + delta, rows, src, extra_rows);
+ }
+}
+
+template<const Index size, bool non_unit_stride>
+EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
+{
+ constexpr Index extra = ((size < 4) ? 4 : size);
+ for(; i + size <= rows; i += extra, src += extra*resInc){
+ PacketBlock<Packet8bf,(size+7)/8> r32;
+ r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
+ if (size >= 16) {
+ r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
+ }
+ if (size >= 32) {
+ r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
+ r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
+ }
+ storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+ }
+}
+
+template<bool non_unit_stride>
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc)
+{
+ for(Index col = 0; col < cols; col++, src += (rows*resInc), result += rows) {
+ Index i = 0;
+ bfloat16* src2 = src;
+ convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
+ convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
+ convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
+ convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
+ convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
+ }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][4])
+{
+ Packet4f z = pset1<Packet4f>(float(0));
+
+ for(Index k = 0; k < num_acc; k++) {
+ for(Index j = 0; j < 4; j++) {
+ acc[k][j] = z;
+ }
+ }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4])
+{
+ for(Index i = 0; i < num_acc; i++) {
+ Packet4ui t0, t1, t2, t3;
+ t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
+ t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
+ t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
+ t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
+ acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
+ acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
+ acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
+ acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
+ }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4])
+{
+ for(Index i = 0, j = 0; j < num_acc; i++, j += 2) {
+ for(Index x = 0, y = 0; x < 2; x++, y += 2) {
+ for(Index w = 0, z = 0; w < 2; w++, z += 2) {
+ acc[i][y+w] = acc[j+x][z+0] + acc[j+x][z+1];
+ }
+ }
+ }
+}
+
+template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
+EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
+{
+ tranposeResults<num_acc>(acc);
+ addResults<num_acc>(acc);
+
+ constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
+ Index k = 0;
+ for(Index i = 0; i < real_rhs; i++, result += 4*rows, k++){
+ storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
+ }
+ if(rhsExtraCols) {
+ storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
+ }
+}
+
+template<bool zero>
+EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f &dhs1)
+{
+ dhs0 = ploadu<Packet4f>(block + strideB*i + 0);
+ if (zero) {
+ Packet4f dhs2 = pset1<Packet4f>(float(0));
+ dhs1 = vec_mergel(dhs0, dhs2);
+ dhs0 = vec_mergeh(dhs0, dhs2);
+ } else {
+ dhs1 = ploadu<Packet4f>(block + strideB*i + 4);
+ }
+}
+
+template<Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
+EIGEN_ALWAYS_INLINE void KLoop
+(
+ const float* indexA,
+ const float* indexB,
+ Packet4f (&acc)[num_acc][4],
+ Index strideB,
+ Index k,
+ Index offsetB,
+ Index extra_cols
+)
+{
+ constexpr Index num_lhs = 4;
+ Packet4f lhs[num_lhs], rhs[num_rhs];
+
+ constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
+ for(Index i = 0; i < real_rhs; i += 2){
+ loadTwoRhsFloat32<zero>(indexB + k*4, strideB, i, rhs[i + 0], rhs[i + 1]);
+ }
+ if(rhsExtraCols) {
+ loadTwoRhsFloat32<zero>(indexB + k*extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
+ }
+
+ indexA += 2*k*4;
+ for(Index j = 0; j < num_lhs; j++) {
+ lhs[j] = ploadu<Packet4f>(indexA + j*4);
+ }
+
+ for(Index j = 0; j < num_rhs; j++) {
+ for(Index i = 0; i < num_lhs; i++) {
+ acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
+ }
+ }
+}
+
+template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
+{
+ constexpr Index num_rhs = num_acc;
+
+ Packet4f acc[num_acc][4];
+
+ zeroAccumulators<num_acc>(acc);
+
+ Index k;
+ for(k = 0; k + 2 <= depth; k += 2){
+ KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
+ }
+ if(depth&1){
+ KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
+ }
+
+ outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
+}
+
+// No more than 4 (uses 2X the accumulators or 8X the number of VSX registers)
+#define MAX_BFLOAT16_ACC_VSX 4
+
+template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
+void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result)
+{
+ constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
+ const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
+ const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
+ constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);
+
+ do{
+ colVSXLoopBodyIter<num_acc*2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
+
+ indexB += strideB*(num_acc * 2);
+ result += rows*step;
+ } while(multiIters && (step <= cols - (col += step)));
+}
+
+template<const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result)
+{
+ if (MAX_BFLOAT16_ACC_VSX > num_acc) {
+ colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
+ }
+}
+
+template<bool rhsExtraCols, bool lhsExtraRows>
+void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result)
+{
+ switch ((cols - col) >> 2) {
+ case 3:
+ colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
+ break;
+ case 2:
+ colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
+ break;
+ case 1:
+ colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
+ break;
+ default:
+ if (rhsExtraCols) {
+ colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
+ }
+ break;
+ }
+}
+
+template<Index size, bool lhsExtraRows = false>
+EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const float* indexA2, const float* blockB2, Index strideA, Index strideB, Index offsetB, float* result2)
+{
+ Index delta_rows = 2*(lhsExtraRows ? (rows & 3) : size);
+ for (Index row = 0; row < size; row += 4) {
+ convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float *>(indexA2), strideA, delta_rows, indexA, row, rows & 3);
+
+ const float *blockB = blockB2;
+ float *result = result2 + row;
+
+ Index col = 0;
+ if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
+ colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
+ blockB += (strideB >> 1)*col;
+ result += rows*col;
+ }
+ if (cols & 3) {
+ colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB, result);
+ } else {
+ colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
+ }
+ }
+}
+
+template<Index size>
+EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexB, Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
+{
+ if ((size == 16) || (rows & size)) {
+ indexA += size*offsetA;
+ colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
+ row += size;
+ indexA += bigSuffix*size/16;
+ }
+}
+
+template<const Index size, typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
+{
+ constexpr Index extra = ((size < 4) ? 4 : size);
+ for(; i + size <= rows; i += extra){
+ PacketBlock<Packet8bf,(size+7)/8> r32;
+ r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
+ if (size >= 16) {
+ r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
+ }
+ if (size >= 32) {
+ r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
+ r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
+ }
+ storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+ }
+}
+
+template<typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
+{
+ typedef typename DataMapper::LinearMapper LinearMapper;
+ for(Index j = 0; j < cols; j++, result += rows){
+ const LinearMapper src2 = src.getLinearMapper(0, j);
+ Index i = 0;
+ convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
+ convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
+ convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
+ convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
+ convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
+ }
+}
+
+EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float *res)
+{
+ return F32ToBf16Both(ploadu<Packet4f>(res + 0), ploadu<Packet4f>(res + 4));
+}
+
+template<typename DataMapper, const Index size>
+EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, Index rows, const DataMapper& res)
+{
+ const DataMapper res2 = res.getSubMapper(0, col);
+ Index row;
+ float *result2 = result + col*rows;
+ for(row = 0; row + 8 <= rows; row += 8){
+ // get and save block
+ PacketBlock<Packet8bf,size> block;
+ for(Index j = 0; j < size; j++){
+ block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row);
+ }
+ res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
+ }
+ // extra rows
+ if(row < rows){
+ for(Index j = 0; j < size; j++){
+ Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row);
+ res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
+ }
+ }
+}
+
+template<typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Index rows, const DataMapper& res)
+{
+ Index col;
+ for(col = 0; col + 4 <= cols; col += 4){
+ convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
+ }
+ // extra cols
+ while(col < cols){
+ convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
+ col++;
+ }
+}
+
+template<typename DataMapper>
+void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+ const Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
+
+ ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
+ ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB*cols, 0);
+ ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2)*4*2, 0);
+
+ convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
+ convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16 *>(indexB));
+
+ Index bigSuffix = 2*8*(strideA-offsetA);
+ float* indexBF32 = indexB2 + 4*offsetB;
+ offsetB *= 3;
+ strideB *= 2;
+
+ Index row = 0;
+ // LHS (8x16) block
+ while(row + 16 <= rows){
+ calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
+ }
+ // LHS (8x8) block
+ calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
+ // LHS (8x4) block
+ calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result);
+ // extra rows
+ if(rows & 3){
+ // This index is the beginning of remaining block.
+ colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB, result + row);
+ }
+
+ // Convert back to bfloat16
+ convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
+}
+
+#undef MAX_BFLOAT16_ACC_VSX
+
#include "MatrixVectorProduct.h"
+template<const Index size, bool non_unit_stride, Index delta>
+EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra)
+{
+ if (non_unit_stride) {
+ if (size < 8) {
+ pscatter_partial(dst + delta*resInc, data, resInc, extra);
+ } else {
+ pscatter(dst + delta*resInc, data, resInc);
+ }
+ } else {
+ if (size < 8) {
+ pstoreu_partial(dst + delta, data, extra);
+ } else {
+ pstoreu(dst + delta, data);
+ }
+ }
+}
+
+template<const Index size, bool non_unit_stride = false>
+EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
+{
+ constexpr Index extra = ((size < 8) ? 8 : size);
+ for(; i + size <= rows; i += extra, dst += extra*resInc){
+ PacketBlock<Packet8bf,(size+7)/8> r32;
+ r32.packet[0] = convertF32toBF16VSX(result + i + 0);
+ if (size >= 16) {
+ r32.packet[1] = convertF32toBF16VSX(result + i + 8);
+ }
+ if (size >= 32) {
+ r32.packet[2] = convertF32toBF16VSX(result + i + 16);
+ r32.packet[3] = convertF32toBF16VSX(result + i + 24);
+ }
+ storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
+ if (size >= 16) {
+ storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
+ }
+ if (size >= 32) {
+ storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
+ storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
+ }
+ }
+}
+
+template<bool non_unit_stride = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1)
+{
+ Index i = 0;
+ convertPointerF32toBF16VSX<32,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16VSX<16,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16VSX<8,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16VSX<1,non_unit_stride>(i, result, rows, dst, resInc);
+}
+
/************************************
* ppc64le template specializations *
* **********************************/
@@ -2735,10 +3358,7 @@
dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset);
}
-#endif
-#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
-#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
@@ -2795,7 +3415,6 @@
dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
pack(blockA, lhs, depth, rows, stride, offset);
}
-#endif
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
@@ -2987,21 +3606,12 @@
{
const Index accRows = quad_traits<float>::rows;
const Index accCols = quad_traits<float>::size;
- void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
-
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
- }
- else{
- gemm_function = &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
- }
- #else
- gemm_function = &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
+ &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3025,22 +3635,13 @@
{
const Index accRows = quad_traits<float>::rows;
const Index accCols = quad_traits<float>::size;
- void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
- Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
-
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> :
#endif
+ &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3064,21 +3665,13 @@
{
const Index accRows = quad_traits<float>::rows;
const Index accCols = quad_traits<float>::size;
- void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
- Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> :
#endif
+ &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3102,21 +3695,13 @@
{
const Index accRows = quad_traits<float>::rows;
const Index accCols = quad_traits<float>::size;
- void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
- Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
+ Index, Index, Index, std::complex<float>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> :
#endif
+ &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3139,21 +3724,12 @@
{
const Index accRows = quad_traits<double>::rows;
const Index accCols = quad_traits<double>::size;
- void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
-
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
- }
- else{
- gemm_function = &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
- }
- #else
- gemm_function = &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
+ static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
+ &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3177,21 +3753,13 @@
{
const Index accRows = quad_traits<double>::rows;
const Index accCols = quad_traits<double>::size;
- void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
- Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
+ static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> :
#endif
+ &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3215,21 +3783,13 @@
{
const Index accRows = quad_traits<double>::rows;
const Index accCols = quad_traits<double>::size;
- void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
- Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
+ static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> :
#endif
+ &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
@@ -3253,25 +3813,16 @@
{
const Index accRows = quad_traits<double>::rows;
const Index accCols = quad_traits<double>::size;
- void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
- Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
- #if defined(EIGEN_ALTIVEC_MMA_ONLY)
- //generate with MMA only
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
- if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
- gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- else{
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
- }
- #else
- gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
+ static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
+ Index, Index, Index, std::complex<double>, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> :
#endif
+ &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
-#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
@@ -3289,9 +3840,14 @@
Index rows, Index depth, Index cols, bfloat16 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
- Eigen::internal::gemmMMAbfloat16<DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+ static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16, Index, Index, Index, Index) =
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
+ (supportsMMA()) ?
+ &Eigen::internal::gemmMMAbfloat16<DataMapper> :
+ #endif
+ &Eigen::internal::gemmbfloat16<DataMapper>;
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
-#endif
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index 1ac6629..e89b5e5 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -84,6 +84,18 @@
const Packet& pAlphaImag,
const Packet& pMask);
+template<typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src);
+
+template<const Index size, bool non_unit_stride, Index delta>
+EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0);
+
+template<bool non_unit_stride = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc = 1);
+
+template<bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows);
+
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
index 05d180c..e4013a7 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -28,9 +28,7 @@
#include "../../InternalHeaderCheck.h"
-#if !EIGEN_ALTIVEC_DISABLE_MMA
#include "MatrixProductMMAbfloat16.h"
-#endif
namespace Eigen {
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
index fe4906d..731fd9b 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
@@ -53,7 +53,7 @@
indexA += k*(lhsExtraRows ? extra_rows : num_packets);
for(Index j = 0; j < num_lhs; j++) {
- lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements
+ lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); // a packet of bfloat16 has 8 elements
}
BFLOAT16_UNROLL
@@ -65,46 +65,6 @@
}
}
-EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
-{
- Packet4f result_block = ploadu<Packet4f>(result);
- return pmadd(acc, pAlpha, result_block);
-}
-
-template<bool lhsExtraRows>
-EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
-{
- if (lhsExtraRows) {
- pstoreu_partial(result, result_block, extra_rows);
- } else {
- pstoreu(result, result_block);
- }
- result += rows;
-}
-
-template<bool rhsExtraCols, bool lhsExtraRows>
-EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
-{
- Index x = 0;
- if (rhsExtraCols) {
- do{
- Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
- storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
- } while (++x < extra_cols);
- } else {
- Packet4f result_block[4];
- float *result2 = result;
- do{
- result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
- result += rows;
- } while (++x < 4);
- x = 0;
- do{
- storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
- } while (++x < 4);
- }
-}
-
template<Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
{
@@ -165,17 +125,14 @@
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
{
- constexpr Index step = (num_acc * 4); //each accumulator has 4 elements
+ constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
+ constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);
do{
- if (multiIters && ((num_acc % (num_packets / 4)) == 0)) {
- colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, true>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
- } else {
- colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
- }
+ colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
indexB += strideB*num_acc;
result += rows*step;
@@ -239,104 +196,89 @@
}
}
-template<bool full = true>
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
{
- Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
- Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4))) : fp16_0;
- return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
+ Packet16uc fp16[2];
+#if EIGEN_COMP_LLVM
+ __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
+ __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
+ fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
+ fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
+#else
+ fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
+ fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
+#endif
+ return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
}
-template<int N>
-EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+4)/8>& block)
+template<typename DataMapper, const Index size>
+EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper& res)
{
- Packet8us z = pset1<Packet8us>(0);
- pstore(to + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[0].m_val)));
- if (N >= 8) {
- pstore(to + 4, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[0].m_val)));
+ const DataMapper res2 = res.getSubMapper(0, col);
+ Index row;
+ float *result2 = result + col*rows;
+ for(row = 0; row + 8 <= rows; row += 8){
+ // get and save block
+ PacketBlock<Packet8bf,size> block;
+ for(Index j = 0; j < size; j++){
+ block.packet[j] = convertF32toBF16(result2 + j*rows + row);
+ }
+ res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
}
- if (N >= 16) {
- pstore(to + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[1].m_val)));
- pstore(to + 12, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[1].m_val)));
- }
- if (N >= 32) {
- pstore(to + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[2].m_val)));
- pstore(to + 20, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[2].m_val)));
- pstore(to + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[3].m_val)));
- pstore(to + 28, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[3].m_val)));
+ // extra rows
+ if(row < rows){
+ for(Index j = 0; j < size; j++){
+ Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row);
+ res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
+ }
}
}
-template<const Index size, typename DataMapper>
-EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
+template<const Index size, bool non_unit_stride = false>
+EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
{
- for(; i + size <= rows; i += size){
- PacketBlock<Packet8bf,(size+4)/8> r32;
- r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
+ constexpr Index extra = ((size < 8) ? 8 : size);
+ for(; i + size <= rows; i += extra, dst += extra*resInc){
+ PacketBlock<Packet8bf,(size+7)/8> r32;
+ r32.packet[0] = convertF32toBF16(result + i + 0);
if (size >= 16) {
- r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
+ r32.packet[1] = convertF32toBF16(result + i + 8);
}
if (size >= 32) {
- r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
- r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
+ r32.packet[2] = convertF32toBF16(result + i + 16);
+ r32.packet[3] = convertF32toBF16(result + i + 24);
}
- storeConvertBlockBF16<size>(result + i, r32);
+ storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
+ if (size >= 16) {
+ storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
+ }
+ if (size >= 32) {
+ storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
+ storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
+ }
}
}
-template<typename DataMapper>
-EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
+template<bool non_unit_stride = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
{
- typedef typename DataMapper::LinearMapper LinearMapper;
- for(Index j = 0; j < cols; j++, result += rows){
- const LinearMapper src2 = src.getLinearMapper(0, j);
- Index i = 0;
- convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
- convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
- convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
- convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
- for(; i < rows; i++){
- result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i));
- }
- }
+ Index i = 0;
+ convertPointerF32toBF16<32,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16<16,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16<8,non_unit_stride>(i, result, rows, dst, resInc);
+ convertPointerF32toBF16<1,non_unit_stride>(i, result, rows, dst, resInc);
}
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
{
- typedef typename DataMapper::LinearMapper LinearMapper;
- Index col, row;
+ Index col;
for(col = 0; col + 4 <= cols; col += 4){
- const DataMapper res2 = res.getSubMapper(0, col);
- for(row = 0; row + 8 <= rows; row += 8){
- //get and save block
- PacketBlock<Packet8bf,4> block;
- for(Index j = 0; j < 4; j++){
- block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row);
- }
-
- res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
- }
- //extra rows
- while(row < rows){
- for(Index col_off = 0; col_off < 4; col_off++){
- res2(row, col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]);
- }
- row++;
- }
-
+ convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
}
- //extra cols
+ // extra cols
while(col < cols){
- const LinearMapper res2 = res.getLinearMapper(0, col);
- float *result2 = result + col*rows;
- for(row = 0; row + 8 <= rows; row += 8){
- Packet8bf fp16 = convertF32toBF16(result2 + row);
- res2.template storePacket<Packet8bf>(row, fp16);
- }
- for(; row < rows; row++){
- res2(row) = Eigen::bfloat16(result2[row]);
- }
+ convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
col++;
}
}
@@ -361,134 +303,42 @@
convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
- Index row = 0;
-
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
- //Packing is done in blocks.
- //There's 4 possible sizes of blocks
- //Blocks of 8 columns with 16 elements (8x16)
- //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
- //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
- //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
+ // Packing is done in blocks.
+ // There's 4 possible sizes of blocks
+ // Blocks of 8 columns with 16 elements (8x16)
+ // Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
+ // Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
+ // Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
- //Loop for LHS standard block (8x16)
+ // Loop for LHS standard block (8x16)
Index bigSuffix = (2*8) * (strideA-offsetA);
indexB += 4*offsetB;
strideB *= 4;
offsetB *= 3;
+
+ Index row = 0;
while(row + 16 <= rows){
calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
}
- //LHS (8x8) block
+ // LHS (8x8) block
calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
- //LHS (8x4) block
+ // LHS (8x4) block
calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
- //extra rows
+ // extra rows
if(rows & 3){
- //This index is the beginning of remaining block.
+ // This index is the beginning of remaining block.
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
}
- //Convert back to bfloat16
+ // Convert back to bfloat16
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}
-template<const Index size, bool inc, Index delta>
-EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc)
-{
- if (inc) {
- if (size == 4) {
- pscatter_partial(dst + delta*resInc, data, resInc, 4);
- } else {
- pscatter(dst + delta*resInc, data, resInc);
- }
- } else {
- if (size == 4) {
- pstoreu_partial(dst + delta, data, 4);
- } else {
- pstoreu(dst + delta, data);
- }
- }
-}
+#undef MAX_BFLOAT16_ACC
-template<const Index size, bool inc>
-EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
-{
- for(; i + size <= rows; i += size, dst += size*resInc){
- PacketBlock<Packet8bf,(size+4)/8> r32;
- r32.packet[0] = convertF32toBF16<size != 4>(result + i + 0);
- if (size >= 16) {
- r32.packet[1] = convertF32toBF16<true>(result + i + 8);
- }
- if (size >= 32) {
- r32.packet[2] = convertF32toBF16<true>(result + i + 16);
- r32.packet[3] = convertF32toBF16<true>(result + i + 24);
- }
- storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc);
- if (size >= 16) {
- storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
- }
- if (size >= 32) {
- storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
- storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
- }
- }
-}
-
-template<bool inc, Index delta>
-EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
-{
- if (inc) {
- return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
- } else {
- return ploadu<Packet8bf>(src + delta);
- }
-}
-
-template<const Index size, bool inc>
-EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
-{
- for(; i + size <= rows; i += size, src += size*resInc){
- PacketBlock<Packet8bf,(size+4)/8> r32;
- r32.packet[0] = loadBF16fromResult<inc, 0>(src, resInc);
- if (size >= 16) {
- r32.packet[1] = loadBF16fromResult<inc, 8>(src, resInc);
- }
- if (size >= 32) {
- r32.packet[2] = loadBF16fromResult<inc, 16>(src, resInc);
- r32.packet[3] = loadBF16fromResult<inc, 24>(src, resInc);
- }
- storeConvertBlockBF16<size>(result + i, r32);
- }
-}
-
-template<bool inc = false>
-EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1)
-{
- Index i = 0;
- convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc);
- convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc);
- convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc);
- convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc);
- for(; i < rows; i++, src += resInc){
- result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src);
- }
-}
-
-template<bool inc = false>
-EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
-{
- Index i = 0;
- convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc);
- convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc);
- convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc);
- convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc);
- for(; i < rows; i++, dst += resInc){
- *dst = Eigen::bfloat16(result[i]);
- }
-}
-
+#if !EIGEN_ALTIVEC_DISABLE_MMA
template<bool extraRows>
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
{
@@ -667,7 +517,7 @@
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
- convertArrayPointerBF16toF32(result, rows, res);
+ convertArrayPointerBF16toF32(result, 1, rows, res);
for (Index j2 = 0; j2 < cols; j2 += block_cols)
{
@@ -867,9 +717,9 @@
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
if (resIncr == 1) {
- convertArrayPointerBF16toF32(result, rows, res);
+ convertArrayPointerBF16toF32(result, 1, rows, res);
} else {
- convertArrayPointerBF16toF32<true>(result, rows, res, resIncr);
+ convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
}
calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
if (resIncr == 1) {
@@ -878,6 +728,10 @@
convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
}
}
+#endif
+
+#undef MAX_BFLOAT16_VEC_ACC
+#undef BFLOAT16_UNROLL
}
}
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index e107335..fa29b34 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -17,8 +17,8 @@
#define USE_GEMV_MMA
#endif
-#if !EIGEN_COMP_LLVM && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3)
-// Only allow one vector_pair in buggy gcc - gcc 10.3 has a bug
+#if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
+// Only allow one vector_pair in buggy gcc - gcc 10.x has a bug
#define GCC_ONE_VECTORPAIR_BUG
#endif
#endif
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index d477ab7..5f4ccda 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -35,6 +35,7 @@
typedef __vector __bool int Packet4bi;
typedef __vector short int Packet8s;
typedef __vector unsigned short int Packet8us;
+typedef __vector __bool short Packet8bi;
typedef __vector signed char Packet16c;
typedef __vector unsigned char Packet16uc;
typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
@@ -83,10 +84,7 @@
static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
static EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u);
static EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu);
-#ifndef __POWER8_VECTOR__
static EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1}
-static EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1);
-#endif
static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000}
#ifndef __VSX__
static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0}
@@ -116,6 +114,14 @@
static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
+static Packet16uc p16uc_MERGEE16 = { 0,1, 16,17, 4,5, 20,21, 8,9, 24,25, 12,13, 28,29 };
+static Packet16uc p16uc_MERGEO16 = { 2,3, 18,19, 6,7, 22,23, 10,11, 26,27, 14,15, 30,31 };
+#ifdef _BIG_ENDIAN
+static Packet16uc p16uc_MERGEH16 = { 0,1, 4,5, 8,9, 12,13, 16,17, 20,21, 24,25, 28,29 };
+#else
+static Packet16uc p16uc_MERGEL16 = { 2,3, 6,7, 10,11, 14,15, 18,19, 22,23, 26,27, 30,31 };
+#endif
+
// Handle endianness properly while loading constants
// Define global static constants:
#ifdef _BIG_ENDIAN
@@ -537,31 +543,20 @@
}
return load;
#else
- EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
- unsigned char* load2 = reinterpret_cast<unsigned char *>(load + offset);
- unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
- Index n2 = n * size;
- Index i = 0;
- if (16 <= n2) {
- pstoreu(load2, ploadu<Packet16uc>(from2));
- i += 16;
+ if (n) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
+ unsigned char* load2 = reinterpret_cast<unsigned char *>(load + offset);
+ unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
+ Index n2 = n * size;
+ if (16 <= n2) {
+ pstoreu(load2, ploadu<Packet16uc>(from2));
+ } else {
+ memcpy((void *)load2, (void *)from2, n2);
+ }
+ return pload_ignore<Packet>(load);
+ } else {
+ return Packet(pset1<Packet16uc>(0));
}
- if (i + 8 <= n2) {
- *reinterpret_cast<uint64_t *>(load2 + i) = *reinterpret_cast<uint64_t *>(from2 + i);
- i += 8;
- }
- if (i + 4 <= n2) {
- *reinterpret_cast<uint32_t *>(load2 + i) = *reinterpret_cast<uint32_t *>(from2 + i);
- i += 4;
- }
- if (i + 2 <= n2) {
- *reinterpret_cast<uint16_t *>(load2 + i) = *reinterpret_cast<uint16_t *>(from2 + i);
- i += 2;
- }
- if (i < n2) {
- *reinterpret_cast<uint8_t *>(load2 + i) = *reinterpret_cast<uint8_t *>(from2 + i);
- }
- return pload_ignore<Packet>(load);
#endif
}
@@ -635,7 +630,7 @@
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
{
- pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+ pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val);
}
template<> EIGEN_STRONG_INLINE void pstore<signed char>(signed char* to, const Packet16c& from)
@@ -670,30 +665,17 @@
}
vec_xst_len(store, to, n * size);
#else
- EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
- pstore(store, from);
- unsigned char* store2 = reinterpret_cast<unsigned char *>(store + offset);
- unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
- Index n2 = n * size;
- Index i = 0;
- if (16 <= n2) {
- pstore(to2, ploadu<Packet16uc>(store2));
- i += 16;
- }
- if (i + 8 <= n2) {
- *reinterpret_cast<uint64_t *>(to2 + i) = *reinterpret_cast<uint64_t *>(store2 + i);
- i += 8;
- }
- if (i + 4 <= n2) {
- *reinterpret_cast<uint32_t *>(to2 + i) = *reinterpret_cast<uint32_t *>(store2 + i);
- i += 4;
- }
- if (i + 2 <= n2) {
- *reinterpret_cast<uint16_t *>(to2 + i) = *reinterpret_cast<uint16_t *>(store2 + i);
- i += 2;
- }
- if (i < n2) {
- *reinterpret_cast<uint8_t *>(to2 + i) = *reinterpret_cast<uint8_t *>(store2 + i);
+ if (n) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
+ pstore(store, from);
+ unsigned char* store2 = reinterpret_cast<unsigned char *>(store + offset);
+ unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
+ Index n2 = n * size;
+ if (16 <= n2) {
+ pstore(to2, ploadu<Packet16uc>(store2));
+ } else {
+ memcpy((void *)to2, (void *)store2, n2);
+ }
}
#endif
}
@@ -720,7 +702,7 @@
template<> EIGEN_ALWAYS_INLINE void pstore_partial<bfloat16>(bfloat16* to, const Packet8bf& from, const Index n, const Index offset)
{
- pstore_partial_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from, n, offset);
+ pstore_partial_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val, n, offset);
}
template<> EIGEN_ALWAYS_INLINE void pstore_partial<signed char>(signed char* to, const Packet16c& from, const Index n, const Index offset)
@@ -1003,6 +985,22 @@
return vec_xor(a, p4f_MZERO);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a)
+{
+#ifdef __POWER8_VECTOR__
+ return vec_neg(a);
+#else
+ return reinterpret_cast<Packet16c>(p4i_ZERO) - a;
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a)
+{
+#ifdef __POWER8_VECTOR__
+ return vec_neg(a);
+#else
+ return reinterpret_cast<Packet8s>(p4i_ZERO) - a;
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a)
{
#ifdef __POWER8_VECTOR__
@@ -1102,7 +1100,7 @@
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
// To fix bug with vec_cmplt on older versions
-#if defined(__POWER8_VECTOR__) || EIGEN_COMP_LLVM
+#ifdef EIGEN_VECTORIZE_VSX
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
#endif
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
@@ -1256,31 +1254,20 @@
EIGEN_DEBUG_UNALIGNED_LOAD
return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size);
#else
- EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
- unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
- unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
- Index n2 = n * size;
- Index i = 0;
- if (16 <= n2) {
- pstore(load2, ploadu<Packet16uc>(from2));
- i += 16;
+ if (n) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
+ unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
+ unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
+ Index n2 = n * size;
+ if (16 <= n2) {
+ pstore(load2, ploadu<Packet16uc>(from2));
+ } else {
+ memcpy((void *)load2, (void *)from2, n2);
+ }
+ return pload_ignore<Packet>(load);
+ } else {
+ return Packet(pset1<Packet16uc>(0));
}
- if (i + 8 <= n2) {
- *reinterpret_cast<uint64_t *>(load2 + i) = *reinterpret_cast<uint64_t *>(from2 + i);
- i += 8;
- }
- if (i + 4 <= n2) {
- *reinterpret_cast<uint32_t *>(load2 + i) = *reinterpret_cast<uint32_t *>(from2 + i);
- i += 4;
- }
- if (i + 2 <= n2) {
- *reinterpret_cast<uint16_t *>(load2 + i) = *reinterpret_cast<uint16_t *>(from2 + i);
- i += 2;
- }
- if (i < n2) {
- *reinterpret_cast<uint8_t *>(load2 + i) = *reinterpret_cast<uint8_t *>(from2 + i);
- }
- return pload_ignore<Packet>(load);
#endif
}
@@ -1422,7 +1409,7 @@
}
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
{
- pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
+ pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val);
}
template<> EIGEN_STRONG_INLINE void pstoreu<signed char>(signed char* to, const Packet16c& from)
{
@@ -1443,30 +1430,17 @@
EIGEN_DEBUG_UNALIGNED_STORE
vec_xst_len(from, to, n * size);
#else
- EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
- pstore(store, from);
- unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
- unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
- Index n2 = n * size;
- Index i = 0;
- if (16 <= n2) {
- pstoreu(to2, pload<Packet16uc>(store2));
- i += 16;
- }
- if (i + 8 <= n2) {
- *reinterpret_cast<uint64_t *>(to2 + i) = *reinterpret_cast<uint64_t *>(store2 + i);
- i += 8;
- }
- if (i + 4 <= n2) {
- *reinterpret_cast<uint32_t *>(to2 + i) = *reinterpret_cast<uint32_t *>(store2 + i);
- i += 4;
- }
- if (i + 2 <= n2) {
- *reinterpret_cast<uint16_t *>(to2 + i) = *reinterpret_cast<uint16_t *>(store2 + i);
- i += 2;
- }
- if (i < n2) {
- *reinterpret_cast<uint8_t *>(to2 + i) = *reinterpret_cast<uint8_t *>(store2 + i);
+ if (n) {
+ EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
+ pstore(store, from);
+ unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
+ unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
+ Index n2 = n * size;
+ if (16 <= n2) {
+ pstoreu(to2, pload<Packet16uc>(store2));
+ } else {
+ memcpy((void *)to2, (void *)store2, n2);
+ }
}
#endif
}
@@ -1636,17 +1610,37 @@
);
}
+EIGEN_ALWAYS_INLINE Packet8us pmerge(Packet4ui even, Packet4ui odd) {
+#ifdef _BIG_ENDIAN
+ return vec_perm(reinterpret_cast<Packet8us>(odd), reinterpret_cast<Packet8us>(even), p16uc_MERGEO16);
+#else
+ return vec_perm(reinterpret_cast<Packet8us>(even), reinterpret_cast<Packet8us>(odd), p16uc_MERGEE16);
+#endif
+}
+
// Simple interleaving of bool masks, prevents true values from being
// converted to NaNs.
EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
- const EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
- Packet4f bf_odd, bf_even;
- bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
- bf_even = plogical_shift_right<16>(even);
- return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+ return pmerge(reinterpret_cast<Packet4ui>(even), reinterpret_cast<Packet4ui>(odd));
}
+//#define SUPPORT_BF16_SUBNORMALS
+
+#ifndef __VEC_CLASS_FP_NAN
+#define __VEC_CLASS_FP_NAN (1<<6)
+#endif
+
+#if defined(SUPPORT_BF16_SUBNORMALS) && !defined(__VEC_CLASS_FP_SUBNORMAL)
+#define __VEC_CLASS_FP_SUBNORMAL_P (1<<1)
+#define __VEC_CLASS_FP_SUBNORMAL_N (1<<0)
+
+#define __VEC_CLASS_FP_SUBNORMAL (__VEC_CLASS_FP_SUBNORMAL_P | __VEC_CLASS_FP_SUBNORMAL_N)
+#endif
+
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
+#ifdef _ARCH_PWR10
+ return reinterpret_cast<Packet8us>(__builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(p4f)));
+#else
Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
Packet4ui lsb = plogical_shift_right<16>(input);
lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
@@ -1655,43 +1649,202 @@
Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
input = padd<Packet4ui>(input, rounding_bias);
- //Test NaN and Subnormal - Begin
+ const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
+#ifdef _ARCH_PWR9
+ Packet4bi nan_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_NAN);
+ input = vec_sel(input, p4ui_nan, nan_selector);
+
+#ifdef SUPPORT_BF16_SUBNORMALS
+ Packet4bi subnormal_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_SUBNORMAL);
+ input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
+#endif
+#else
+#ifdef SUPPORT_BF16_SUBNORMALS
+ //Test NaN and Subnormal
const EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
const EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
- const EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
- Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
- Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
-
+ Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_exp_mask);
Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
+
Packet4ui nan_selector = pandnot<Packet4ui>(
reinterpret_cast<Packet4ui>(is_max_exp),
reinterpret_cast<Packet4ui>(is_mant_zero)
);
+ Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
+
Packet4ui subnormal_selector = pandnot<Packet4ui>(
reinterpret_cast<Packet4ui>(is_zero_exp),
reinterpret_cast<Packet4ui>(is_mant_zero)
);
- const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
input = vec_sel(input, p4ui_nan, nan_selector);
input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
- //Test NaN and Subnormal - End
+#else
+ //Test only NaN
+ Packet4bi nan_selector = vec_cmpeq(p4f, p4f);
+
+ input = vec_sel(p4ui_nan, input, nan_selector);
+#endif
+#endif
input = plogical_shift_right<16>(input);
return reinterpret_cast<Packet8us>(input);
+#endif
}
+#ifdef _BIG_ENDIAN
+/**
+ * Pack the high portion of two float Packets into one bfloat16 Packet
+ *
+ * @param lohi to expect either a low & high OR odd & even order
+ */
+template<bool lohi>
+EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f lo, Packet4f hi)
+{
+ if (lohi) {
+ return vec_perm(reinterpret_cast<Packet8us>(lo), reinterpret_cast<Packet8us>(hi), p16uc_MERGEH16);
+ } else {
+ return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEE16);
+ }
+}
+
+/**
+ * Pack the low portion of two float Packets into one bfloat16 Packet
+ *
+ * @param lohi to expect either a low & high OR odd & even order
+ */
+template<bool lohi>
+EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f lo, Packet4f hi)
+{
+ if (lohi) {
+ return vec_pack(reinterpret_cast<Packet4ui>(lo), reinterpret_cast<Packet4ui>(hi));
+ } else {
+ return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEO16);
+ }
+}
+#else
+template<bool lohi>
+EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f hi, Packet4f lo)
+{
+ if (lohi) {
+ return vec_pack(reinterpret_cast<Packet4ui>(hi), reinterpret_cast<Packet4ui>(lo));
+ } else {
+ return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEE16);
+ }
+}
+
+template<bool lohi>
+EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f hi, Packet4f lo)
+{
+ if (lohi) {
+ return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEL16);
+ } else {
+ return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEO16);
+ }
+}
+#endif
+
+/**
+ * Convert and pack two float Packets into one bfloat16 Packet
+ *
+ * @param lohi to expect either a low & high OR odd & even order
+ */
+template<bool lohi = true>
+EIGEN_ALWAYS_INLINE Packet8bf F32ToBf16Two(Packet4f lo, Packet4f hi)
+{
+ Packet8us p4f = Bf16PackHigh<lohi>(lo, hi);
+ Packet8us p4f2 = Bf16PackLow<lohi>(lo, hi);
+
+ Packet8us lsb = pand<Packet8us>(p4f, p8us_ONE);
+ EIGEN_DECLARE_CONST_FAST_Packet8us(BIAS,0x7FFFu);
+ lsb = padd<Packet8us>(lsb, p8us_BIAS);
+ lsb = padd<Packet8us>(lsb, p4f2);
+
+ Packet8bi rounding_bias = vec_cmplt(lsb, p4f2);
+ Packet8us input = psub<Packet8us>(p4f, reinterpret_cast<Packet8us>(rounding_bias));
+
+#ifdef _ARCH_PWR9
+ Packet4bi nan_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_NAN);
+ Packet4bi nan_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_NAN);
+ Packet8us nan_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(nan_selector_lo), reinterpret_cast<Packet4f>(nan_selector_hi));
+
+ input = vec_sel(input, p8us_BIAS, nan_selector);
+
+#ifdef SUPPORT_BF16_SUBNORMALS
+ Packet4bi subnormal_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_SUBNORMAL);
+ Packet4bi subnormal_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_SUBNORMAL);
+ Packet8us subnormal_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(subnormal_selector_lo), reinterpret_cast<Packet4f>(subnormal_selector_hi));
+
+ input = vec_sel(input, reinterpret_cast<Packet8us>(p4f), subnormal_selector);
+#endif
+#else
+#ifdef SUPPORT_BF16_SUBNORMALS
+ //Test NaN and Subnormal
+ const EIGEN_DECLARE_CONST_FAST_Packet8us(exp_mask, 0x7F80);
+ Packet8us exp = pand<Packet8us>(p8us_exp_mask, p4f);
+
+ const EIGEN_DECLARE_CONST_FAST_Packet8us(mantissa_mask, 0x7Fu);
+ Packet8us mantissa = pand<Packet8us>(p8us_mantissa_mask, p4f);
+
+ Packet8bi is_max_exp = vec_cmpeq(exp, p8us_exp_mask);
+ Packet8bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet8us>(p4i_ZERO));
+
+ Packet8us nan_selector = pandnot<Packet8us>(
+ reinterpret_cast<Packet8us>(is_max_exp),
+ reinterpret_cast<Packet8us>(is_mant_zero)
+ );
+
+ Packet8bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet8us>(p4i_ZERO));
+
+ Packet8us subnormal_selector = pandnot<Packet8us>(
+ reinterpret_cast<Packet8us>(is_zero_exp),
+ reinterpret_cast<Packet8us>(is_mant_zero)
+ );
+
+ // Using BIAS as NaN (since any or all of the last 7 bits can be set)
+ input = vec_sel(input, p8us_BIAS, nan_selector);
+ input = vec_sel(input, reinterpret_cast<Packet8us>(p4f), subnormal_selector);
+#else
+ //Test only NaN
+ Packet4bi nan_selector_lo = vec_cmpeq(lo, lo);
+ Packet4bi nan_selector_hi = vec_cmpeq(hi, hi);
+ Packet8us nan_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(nan_selector_lo), reinterpret_cast<Packet4f>(nan_selector_hi));
+
+ input = vec_sel(p8us_BIAS, input, nan_selector);
+#endif
+#endif
+
+ return input;
+}
+
+/**
+ * Convert and pack two float Packets into one bfloat16 Packet - low & high order
+ */
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Both(Packet4f lo, Packet4f hi)
+{
+#ifdef _ARCH_PWR10
+ Packet8bf fp16_0 = F32ToBf16(lo);
+ Packet8bf fp16_1 = F32ToBf16(hi);
+ return vec_pack(reinterpret_cast<Packet4ui>(fp16_0.m_val), reinterpret_cast<Packet4ui>(fp16_1.m_val));
+#else
+ return F32ToBf16Two(lo, hi);
+#endif
+}
+
+/**
+ * Convert and pack two float Packets into one bfloat16 Packet - odd & even order
+ */
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
- Packet4f bf_odd, bf_even;
- bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
- bf_odd = plogical_shift_left<16>(bf_odd);
- bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
- return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+#ifdef _ARCH_PWR10
+ return pmerge(reinterpret_cast<Packet4ui>(F32ToBf16(even).m_val), reinterpret_cast<Packet4ui>(F32ToBf16(odd).m_val));
+#else
+ return F32ToBf16Two<false>(even, odd);
+#endif
}
#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
Packet4f a_even = Bf16ToF32Even(A);\
@@ -2493,11 +2646,7 @@
template<typename Packet> EIGEN_STRONG_INLINE
Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) {
Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
-#ifdef __POWER8_VECTOR__
- Packet4ui mask = reinterpret_cast<Packet4ui>(vec_neg(reinterpret_cast<Packet4i>(select)));
-#else
- Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
-#endif
+ Packet4ui mask = reinterpret_cast<Packet4ui>(pnegate(reinterpret_cast<Packet4i>(select)));
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2512,11 +2661,7 @@
template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) {
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
-#ifdef __POWER8_VECTOR__
- Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
-#else
- Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(select, p8us_ONE));
-#endif
+ Packet8us mask = reinterpret_cast<Packet8us>(pnegate(reinterpret_cast<Packet8s>(select)));
Packet8s result = vec_sel(elsePacket, thenPacket, mask);
return result;
}
@@ -2524,11 +2669,7 @@
template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) {
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
-#ifdef __POWER8_VECTOR__
- Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
-#else
- Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(reinterpret_cast<Packet8us>(select), p8us_ONE));
-#endif
+ Packet8us mask = reinterpret_cast<Packet8us>(pnegate(reinterpret_cast<Packet8s>(select)));
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2542,11 +2683,7 @@
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
-#ifdef __POWER8_VECTOR__
- Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
-#else
- Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
-#endif
+ Packet16uc mask = reinterpret_cast<Packet16uc>(pnegate(reinterpret_cast<Packet16c>(select)));
return vec_sel(elsePacket, thenPacket, mask);
}
@@ -2556,113 +2693,10 @@
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
-#ifdef __POWER8_VECTOR__
- Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
-#else
- Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
-#endif
+ Packet16uc mask = reinterpret_cast<Packet16uc>(pnegate(reinterpret_cast<Packet16c>(select)));
return vec_sel(elsePacket, thenPacket, mask);
}
-template <>
-struct type_casting_traits<float, int> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template <>
-struct type_casting_traits<int, float> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template <>
-struct type_casting_traits<bfloat16, unsigned short int> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template <>
-struct type_casting_traits<unsigned short int, bfloat16> {
- enum {
- VectorizedCast = 1,
- SrcCoeffRatio = 1,
- TgtCoeffRatio = 1
- };
-};
-
-template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
- return vec_cts(a,0);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
- return vec_ctu(a,0);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
- return vec_ctf(a,0);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
- return vec_ctf(a,0);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
- Packet4f float_even = Bf16ToF32Even(a);
- Packet4f float_odd = Bf16ToF32Odd(a);
- Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
- Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
- const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
- Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
- Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
-
- //Check values that are bigger than USHRT_MAX (0xFFFF)
- Packet4bi overflow_selector;
- if(vec_any_gt(int_even, p4ui_low_mask)){
- overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
- low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
- }
- if(vec_any_gt(int_odd, p4ui_low_mask)){
- overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
- low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
- }
-
- low_odd = plogical_shift_left<16>(low_odd);
-
- Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
- return reinterpret_cast<Packet8us>(int_final);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
- //short -> int -> float -> bfloat16
- const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
- Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
- Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
- Packet4ui int_odd = plogical_shift_right<16>(int_cast);
- Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
- Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
- return F32ToBf16(float_even, float_odd);
-}
-
-
-template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
- return reinterpret_cast<Packet4i>(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
- return reinterpret_cast<Packet4f>(a);
-}
-
-
//---------- double ----------
#ifdef EIGEN_VECTORIZE_VSX
@@ -2675,7 +2709,6 @@
typedef __vector __bool long Packet2bl;
#endif
-static Packet2l p2l_ONE = { 1, 1 };
static Packet2l p2l_ZERO = reinterpret_cast<Packet2l>(p4i_ZERO);
static Packet2ul p2ul_SIGN = { 0x8000000000000000ull, 0x8000000000000000ull };
static Packet2ul p2ul_PREV0DOT5 = { 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull };
@@ -2937,35 +2970,25 @@
return vec_sld(a, a, 8);
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
+#ifdef __POWER8_VECTOR__
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats((unsigned long long)(63))); }
-// VSX support varies between different compilers and even different
-// versions of the same compiler. For gcc version >= 4.9.3, we can use
-// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
-// a slow version that works with older compilers.
-// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
-// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
-template<>
-inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
-#if EIGEN_GNUC_STRICT_AT_LEAST(7,1,0)
- return vec_cts(x, 0); // TODO: check clang version.
#else
- double tmp[2];
- memcpy(tmp, &x, sizeof(tmp));
- Packet2l l = { static_cast<long long>(tmp[0]),
- static_cast<long long>(tmp[1]) };
- return l;
+#ifdef _BIG_ENDIAN
+static Packet16uc p16uc_DUPSIGN = { 0,0,0,0, 0,0,0,0, 8,8,8,8, 8,8,8,8 };
+#else
+static Packet16uc p16uc_DUPSIGN = { 7,7,7,7, 7,7,7,7, 15,15,15,15, 15,15,15,15 };
#endif
-}
-template<>
-inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
- unsigned long long tmp[2];
- memcpy(tmp, &x, sizeof(tmp));
- Packet2d d = { static_cast<double>(tmp[0]),
- static_cast<double>(tmp[1]) };
- return d;
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
+{
+ Packet16c tmp = vec_sra(reinterpret_cast<Packet16c>(a), vec_splats((unsigned char)(7)));
+ return reinterpret_cast<Packet2d>(vec_perm(tmp, tmp, p16uc_DUPSIGN));
}
+#endif
+template<> inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x);
+
+template<> inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x);
// Packet2l shifts.
// For POWER8 we simply use vec_sr/l.
@@ -3146,7 +3169,7 @@
template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
Packet2l select = { ifPacket.select[0], ifPacket.select[1] };
- Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
+ Packet2ul mask = reinterpret_cast<Packet2ul>(pnegate(reinterpret_cast<Packet2l>(select)));
return vec_sel(elsePacket, thenPacket, mask);
}
diff --git a/Eigen/src/Core/arch/AltiVec/TypeCasting.h b/Eigen/src/Core/arch/AltiVec/TypeCasting.h
new file mode 100644
index 0000000..bda63d8
--- /dev/null
+++ b/Eigen/src/Core/arch/AltiVec/TypeCasting.h
@@ -0,0 +1,178 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
+// Copyright (C) 2023 Chip Kerchner (chip.kerchner@ibm.com)
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TYPE_CASTING_ALTIVEC_H
+#define EIGEN_TYPE_CASTING_ALTIVEC_H
+
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+
+namespace internal {
+template <>
+struct type_casting_traits<float, int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<int, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<bfloat16, unsigned short int> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template <>
+struct type_casting_traits<unsigned short int, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
+ return vec_cts(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
+ return vec_ctu(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
+ return vec_ctf(a,0);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
+ Packet4f float_even = Bf16ToF32Even(a);
+ Packet4f float_odd = Bf16ToF32Odd(a);
+ Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
+ Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
+ const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
+ Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
+
+ //Check values that are bigger than USHRT_MAX (0xFFFF)
+ Packet4bi overflow_selector;
+ if(vec_any_gt(int_even, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
+ low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
+ }
+ if(vec_any_gt(int_odd, p4ui_low_mask)){
+ overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
+ low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
+ }
+
+ return pmerge(low_even, low_odd);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
+ //short -> int -> float -> bfloat16
+ const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
+ Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
+ Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
+ Packet4ui int_odd = plogical_shift_right<16>(int_cast);
+ Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
+ Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
+ return F32ToBf16(float_even, float_odd);
+}
+
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 2
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet8bf, Packet4f>(const Packet8bf& a) {
+ Packet8us z = pset1<Packet8us>(0);
+#ifdef _BIG_ENDIAN
+ return reinterpret_cast<Packet4f>(vec_mergeh(a.m_val, z));
+#else
+ return reinterpret_cast<Packet4f>(vec_mergeh(z, a.m_val));
+#endif
+}
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 2,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet4f, Packet8bf>(const Packet4f& a, const Packet4f &b) {
+ return F32ToBf16Both(a, b);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
+ return reinterpret_cast<Packet4i>(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
+ return reinterpret_cast<Packet4f>(a);
+}
+
+#ifdef EIGEN_VECTORIZE_VSX
+// VSX support varies between different compilers and even different
+// versions of the same compiler. For gcc version >= 4.9.3, we can use
+// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
+// a slow version that works with older compilers.
+// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
+// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
+template<>
+inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
+#if EIGEN_GNUC_STRICT_AT_LEAST(7,1,0)
+ return vec_cts(x, 0); // TODO: check clang version.
+#else
+ double tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2l l = { static_cast<long long>(tmp[0]),
+ static_cast<long long>(tmp[1]) };
+ return l;
+#endif
+}
+
+template<>
+inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
+ unsigned long long tmp[2];
+ memcpy(tmp, &x, sizeof(tmp));
+ Packet2d d = { static_cast<double>(tmp[0]),
+ static_cast<double>(tmp[1]) };
+ return d;
+}
+#endif
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TYPE_CASTING_ALTIVEC_H
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 499c16b..af085a0 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -10,6 +10,7 @@
#ifndef EIGEN_PACKET_MATH_SSE_H
#define EIGEN_PACKET_MATH_SSE_H
+#include <cstdint>
#include "../../InternalHeaderCheck.h"
namespace Eigen {
@@ -47,11 +48,16 @@
typedef eigen_packet_wrapper<__m128i, 0> Packet4i;
typedef eigen_packet_wrapper<__m128i, 1> Packet16b;
+typedef eigen_packet_wrapper<__m128i, 4> Packet4ui;
template<> struct is_arithmetic<__m128> { enum { value = true }; };
template<> struct is_arithmetic<__m128i> { enum { value = true }; };
template<> struct is_arithmetic<__m128d> { enum { value = true }; };
template<> struct is_arithmetic<Packet4i> { enum { value = true }; };
+// Note that `Packet4ui` uses the underlying type `__m128i`, which is
+// interpreted as a vector of _signed_ `int32`s, which breaks some arithmetic
+// operations used in `GenericPacketMath.h`.
+template<> struct is_arithmetic<Packet4ui> { enum { value = false }; };
template<> struct is_arithmetic<Packet16b> { enum { value = true }; };
template<int p, int q, int r, int s>
@@ -66,6 +72,9 @@
#define vec4i_swizzle1(v,p,q,r,s) \
Packet4i(_mm_shuffle_epi32( v, (shuffle_mask<p,q,r,s>::mask)))
+#define vec4ui_swizzle1(v, p, q, r, s) \
+ Packet4ui(vec4i_swizzle1(v,p,q,r,s))
+
#define vec2d_swizzle1(v,p,q) \
Packet2d(_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), (shuffle_mask<2*p,2*p+1,2*q,2*q+1>::mask))))
@@ -75,6 +84,9 @@
#define vec4i_swizzle2(a,b,p,q,r,s) \
Packet4i(_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), (shuffle_mask<p,q,r,s>::mask)))))
+#define vec4ui_swizzle2(a,b,p,q,r,s) \
+ Packet4i(vec4i_swizzle2(a,b,p,q,r,s))
+
EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b)
{
return Packet4f(_mm_movelh_ps(a,b));
@@ -120,6 +132,7 @@
#define EIGEN_DECLARE_CONST_Packet4i(NAME,X) \
const Packet4i p4i_##NAME = pset1<Packet4i>(X)
+#define EIGEN_DECLARE_CONST_Packet4ui(NAME, X) const Packet4ui p4ui_##NAME = pset1<Packet4ui>(X)
// Use the packet_traits defined in AVX/PacketMath.h instead if we're going
// to leverage AVX instructions.
@@ -202,6 +215,33 @@
HasBlend = 1
};
};
+template<> struct packet_traits<uint32_t> : default_packet_traits
+{
+ typedef Packet4ui type;
+ typedef Packet4ui half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 4,
+
+ HasDiv = 0,
+ HasNegate = 0,
+ HasSqrt = 0,
+
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ HasCmp = 1,
+ HasMin = 1,
+ HasMax = 1,
+#else
+ HasCmp = 0,
+ HasMin = 0,
+ HasMax = 0,
+#endif
+
+ HasShift = 1,
+ HasBlend = 1
+ };
+};
#endif
template<> struct packet_traits<bool> : default_packet_traits
{
@@ -211,7 +251,7 @@
Vectorizable = 1,
AlignedOnScalar = 1,
size=16,
-
+
HasAdd = 1,
HasSub = 1,
HasCmp = 1, // note -- only pcmp_eq is defined
@@ -244,6 +284,11 @@
typedef Packet4i half;
enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
+template<> struct unpacket_traits<Packet4ui> {
+ typedef uint32_t type;
+ typedef Packet4ui half;
+ enum {size = 4, alignment = Aligned16, vectorizable = true, masked_load_available = false, masked_store_available = false};
+};
template<> struct unpacket_traits<Packet16b> {
typedef bool type;
typedef Packet16b half;
@@ -258,6 +303,7 @@
template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from) { return _mm_set_ps1(from); }
template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) { return _mm_set1_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) { return _mm_set1_epi32(from); }
+template<> EIGEN_STRONG_INLINE Packet4ui pset1<Packet4ui>(const uint32_t& from) { return _mm_set1_epi32(numext::bit_cast<int32_t>(from)); }
template<> EIGEN_STRONG_INLINE Packet16b pset1<Packet16b>(const bool& from) { return _mm_set1_epi8(static_cast<char>(from)); }
template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) { return _mm_castsi128_ps(pset1<Packet4i>(from)); }
@@ -265,11 +311,13 @@
template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return _mm_castsi128_ps(_mm_set_epi32(0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return _mm_set_epi32(0, -1, 0, -1); }
+template<> EIGEN_STRONG_INLINE Packet4ui peven_mask(const Packet4ui& /*a*/) { return _mm_set_epi32(0, -1, 0, -1); }
template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return _mm_castsi128_pd(_mm_set_epi32(0, 0, -1, -1)); }
template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); }
template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); }
template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); }
+template<> EIGEN_STRONG_INLINE Packet4ui pzero(const Packet4ui& /*a*/) { return _mm_setzero_si128(); }
// GCC generates a shufps instruction for _mm_set1_ps/_mm_load1_ps instead of the more efficient pshufd instruction.
// However, using inrinsics for pset1 makes gcc to generate crappy code in some cases (see bug 203)
@@ -285,10 +333,12 @@
template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return _mm_add_ps(pset1<Packet4f>(a), _mm_set_ps(3,2,1,0)); }
template<> EIGEN_STRONG_INLINE Packet2d plset<Packet2d>(const double& a) { return _mm_add_pd(pset1<Packet2d>(a),_mm_set_pd(1,0)); }
template<> EIGEN_STRONG_INLINE Packet4i plset<Packet4i>(const int& a) { return _mm_add_epi32(pset1<Packet4i>(a),_mm_set_epi32(3,2,1,0)); }
+template<> EIGEN_STRONG_INLINE Packet4ui plset<Packet4ui>(const uint32_t& a) { return _mm_add_epi32(pset1<Packet4ui>(a), _mm_set_epi32(3, 2, 1, 0)); }
template<> EIGEN_STRONG_INLINE Packet4f padd<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_add_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d padd<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_add_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui padd<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_add_epi32(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
@@ -299,6 +349,7 @@
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui psub<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_sub_epi32(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b psub<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b);
@@ -315,7 +366,7 @@
template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& , const Packet2d& );
template<> EIGEN_STRONG_INLINE Packet2d paddsub<Packet2d>(const Packet2d& a, const Packet2d& b)
{
-#ifdef EIGEN_VECTORIZE_SSE3
+#ifdef EIGEN_VECTORIZE_SSE3
return _mm_addsub_pd(a,b);
#else
const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x80000000,0x0,0x0));
@@ -364,6 +415,21 @@
0,2,1,3);
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4ui pmul<Packet4ui>(const Packet4ui& a, const Packet4ui& b)
+{
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ return _mm_mullo_epi32(a,b);
+#else
+ // this version is slightly faster than 4 scalar products
+ return vec4ui_swizzle1(
+ vec4ui_swizzle2(
+ _mm_mul_epu32(a,b),
+ _mm_mul_epu32(vec4ui_swizzle1(a,1,0,3,2),
+ vec4ui_swizzle1(b,1,0,3,2)),
+ 0,2,0,2),
+ 0,2,1,3);
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet16b pmul<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
@@ -388,6 +454,7 @@
// for some weird raisons, it has to be overloaded for packet of integers
template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd(pmul(a,b), c); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmadd(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c) { return padd(pmul(a, b), c); }
#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
@@ -412,6 +479,10 @@
return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b),_mm_castsi128_ps(a),_mm_castsi128_ps(mask)));
}
+template<> EIGEN_DEVICE_FUNC inline Packet4ui pselect(const Packet4ui& mask, const Packet4ui& a, const Packet4ui& b) {
+ return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b),_mm_castsi128_ps(a),_mm_castsi128_ps(mask)));
+}
+
template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); }
template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) {
@@ -442,21 +513,25 @@
template<> EIGEN_STRONG_INLINE Packet4f pand<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pand<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pand<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_and_si128(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b pand<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f por<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d por<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i por<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui por<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_or_si128(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b por<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pxor<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pxor<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_xor_si128(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b pxor<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pandnot<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(b,a); }
template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(b,a); }
template<> EIGEN_STRONG_INLINE Packet4i pandnot<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(b,a); }
+template<> EIGEN_STRONG_INLINE Packet4ui pandnot<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return _mm_andnot_si128(b, a); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
@@ -470,22 +545,23 @@
template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_eq(const Packet4ui& a, const Packet4ui& b) { return _mm_cmpeq_epi32(a, b); }
template<> EIGEN_STRONG_INLINE Packet16b pcmp_eq(const Packet16b& a, const Packet16b& b) { return _mm_cmpeq_epi8(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return por(pcmp_lt(a,b), pcmp_eq(a,b)); }
template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
- // There appears to be a bug in GCC, by which the optimizer may
- // flip the argument order in calls to _mm_min_ps, so we have to
- // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
- // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
- #ifdef EIGEN_VECTORIZE_AVX
+// There appears to be a bug in GCC, by which the optimizer may
+// flip the argument order in calls to _mm_min_ps, so we have to
+// resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+// see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+#ifdef EIGEN_VECTORIZE_AVX
Packet4f res;
asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
- #else
+#else
Packet4f res = b;
asm("minps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
- #endif
+#endif
return res;
#else
// Arguments are reversed to match NaN propagation behavior of std::min.
@@ -494,17 +570,17 @@
}
template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
- // There appears to be a bug in GCC, by which the optimizer may
- // flip the argument order in calls to _mm_min_pd, so we have to
- // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
- // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
- #ifdef EIGEN_VECTORIZE_AVX
+// There appears to be a bug in GCC, by which the optimizer may
+// flip the argument order in calls to _mm_min_pd, so we have to
+// resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+// see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+#ifdef EIGEN_VECTORIZE_AVX
Packet2d res;
asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
- #else
+#else
Packet2d res = b;
asm("minpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
- #endif
+#endif
return res;
#else
// Arguments are reversed to match NaN propagation behavior of std::min.
@@ -521,21 +597,30 @@
return _mm_or_si128(_mm_and_si128(mask,a),_mm_andnot_si128(mask,b));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4ui pmin<Packet4ui>(const Packet4ui& a, const Packet4ui& b) {
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ return _mm_min_epu32(a, b);
+#else
+ return padd((Packet4ui)pmin((Packet4i)psub(a, pset1<Packet4ui>(0x80000000UL)),
+ (Packet4i)psub(b, pset1<Packet4ui>(0x80000000UL))),
+ pset1<Packet4ui>(0x80000000UL));
+#endif
+}
template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
- // There appears to be a bug in GCC, by which the optimizer may
- // flip the argument order in calls to _mm_max_ps, so we have to
- // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
- // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
- #ifdef EIGEN_VECTORIZE_AVX
+// There appears to be a bug in GCC, by which the optimizer may
+// flip the argument order in calls to _mm_max_ps, so we have to
+// resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+// see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+#ifdef EIGEN_VECTORIZE_AVX
Packet4f res;
asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
- #else
+#else
Packet4f res = b;
asm("maxps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
- #endif
+#endif
return res;
#else
// Arguments are reversed to match NaN propagation behavior of std::max.
@@ -544,17 +629,17 @@
}
template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) {
#if EIGEN_GNUC_STRICT_LESS_THAN(6,3,0)
- // There appears to be a bug in GCC, by which the optimizer may
- // flip the argument order in calls to _mm_max_pd, so we have to
- // resort to inline ASM here. This is supposed to be fixed in gcc6.3,
- // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
- #ifdef EIGEN_VECTORIZE_AVX
+// There appears to be a bug in GCC, by which the optimizer may
+// flip the argument order in calls to _mm_max_pd, so we have to
+// resort to inline ASM here. This is supposed to be fixed in gcc6.3,
+// see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867
+#ifdef EIGEN_VECTORIZE_AVX
Packet2d res;
asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b));
- #else
+#else
Packet2d res = b;
asm("maxpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a));
- #endif
+#endif
return res;
#else
// Arguments are reversed to match NaN propagation behavior of std::max.
@@ -571,6 +656,32 @@
return _mm_or_si128(_mm_and_si128(mask,a),_mm_andnot_si128(mask,b));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4ui pmax<Packet4ui>(const Packet4ui& a, const Packet4ui& b) {
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ return _mm_max_epu32(a, b);
+#else
+ return padd((Packet4ui)pmax((Packet4i)psub(a, pset1<Packet4ui>(0x80000000UL)),
+ (Packet4i)psub(b, pset1<Packet4ui>(0x80000000UL))),
+ pset1<Packet4ui>(0x80000000UL));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_lt(const Packet4ui& a, const Packet4ui& b) {
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ return pxor(pcmp_eq(a, pmax(a, b)), ptrue(a));
+#else
+ return (Packet4ui)pcmp_lt((Packet4i)psub(a, pset1<Packet4ui>(0x80000000UL)),
+ (Packet4i)psub(b, pset1<Packet4ui>(0x80000000UL)));
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4ui pcmp_le(const Packet4ui& a, const Packet4ui& b) {
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ return pcmp_eq(a, pmin(a, b));
+#else
+ return (Packet4ui)pcmp_le((Packet4i)psub(a, pset1<Packet4ui>(0x80000000UL)),
+ (Packet4i)psub(b, pset1<Packet4ui>(0x80000000UL)));
+#endif
+}
template <typename Packet, typename Op>
EIGEN_STRONG_INLINE Packet pminmax_propagate_numbers(const Packet& a, const Packet& b, Op op) {
@@ -628,6 +739,10 @@
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right (const Packet4i& a) { return _mm_srli_epi32(a,N); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left (const Packet4i& a) { return _mm_slli_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui parithmetic_shift_right(const Packet4ui& a) { return _mm_srli_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right (const Packet4ui& a) { return _mm_srli_epi32(a,N); }
+template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left (const Packet4ui& a) { return _mm_slli_epi32(a,N); }
+
template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
{
const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF));
@@ -640,24 +755,26 @@
}
template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
{
- #ifdef EIGEN_VECTORIZE_SSSE3
+#ifdef EIGEN_VECTORIZE_SSSE3
return _mm_abs_epi32(a);
- #else
+#else
Packet4i aux = _mm_srai_epi32(a,31);
return _mm_sub_epi32(_mm_xor_si128(a,aux),aux);
- #endif
+#endif
}
+template<> EIGEN_STRONG_INLINE Packet4ui pabs(const Packet4ui& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return _mm_castsi128_ps(_mm_srai_epi32(_mm_castps_si128(a), 31)); }
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
{
- Packet4f tmp = psignbit<Packet4f>(_mm_castpd_ps(a));
+ Packet4f tmp = psignbit<Packet4f>(_mm_castpd_ps(a));
#ifdef EIGEN_VECTORIZE_AVX
- return _mm_castps_pd(_mm_permute_ps(tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
+ return _mm_castps_pd(_mm_permute_ps(tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
#else
- return _mm_castps_pd(_mm_shuffle_ps(tmp, tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
+ return _mm_castps_pd(_mm_shuffle_ps(tmp, tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
#endif // EIGEN_VECTORIZE_AVX
}
+template<> EIGEN_STRONG_INLINE Packet4ui psignbit(const Packet4ui& a) { return pzero(a); }
#ifdef EIGEN_VECTORIZE_SSE4_1
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
@@ -756,13 +873,14 @@
template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ps(from); }
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_pd(from); }
template<> EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
+template<> EIGEN_STRONG_INLINE Packet4ui pload<Packet4ui>(const uint32_t* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
template<> EIGEN_STRONG_INLINE Packet16b pload<Packet16b>(const bool* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); }
#if EIGEN_COMP_MSVC
template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) {
- EIGEN_DEBUG_UNALIGNED_LOAD
- return _mm_loadu_ps(from);
- }
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return _mm_loadu_ps(from);
+}
#else
// NOTE: with the code below, MSVC's compiler crashes!
@@ -783,6 +901,11 @@
EIGEN_DEBUG_UNALIGNED_LOAD
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
}
+template<> EIGEN_STRONG_INLINE Packet4ui ploadu<Packet4ui>(const uint32_t* from)
+{
+ EIGEN_DEBUG_UNALIGNED_LOAD
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from) {
EIGEN_DEBUG_UNALIGNED_LOAD
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
@@ -810,6 +933,12 @@
tmp = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(from));
return vec4i_swizzle1(tmp, 0, 0, 1, 1);
}
+template<> EIGEN_STRONG_INLINE Packet4ui ploaddup<Packet4ui>(const uint32_t* from)
+{
+ Packet4ui tmp;
+ tmp = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(from));
+ return vec4ui_swizzle1(tmp, 0, 0, 1, 1);
+}
// Loads 8 bools from memory and returns the packet
// {b0, b0, b1, b1, b2, b2, b3, b3, b4, b4, b5, b5, b6, b6, b7, b7}
@@ -831,11 +960,13 @@
template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet4ui& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstore<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_pd(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ps(to, from); }
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstoreu<uint32_t>(uint32_t* to, const Packet4ui& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from);
@@ -858,6 +989,11 @@
{
return _mm_set_epi32(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
}
+template<> EIGEN_DEVICE_FUNC inline Packet4ui pgather<uint32_t, Packet4ui>(const uint32_t* from, Index stride)
+{
+ return _mm_set_epi32(numext::bit_cast<int32_t>(from[3 * stride]), numext::bit_cast<int32_t>(from[2 * stride]),
+ numext::bit_cast<int32_t>(from[1 * stride]), numext::bit_cast<int32_t>(from[0 * stride]));
+}
template<> EIGEN_DEVICE_FUNC inline Packet16b pgather<bool, Packet16b>(const bool* from, Index stride)
{
@@ -886,6 +1022,13 @@
to[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2));
to[stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3));
}
+template<> EIGEN_DEVICE_FUNC inline void pscatter<uint32_t, Packet4ui>(uint32_t* to, const Packet4ui& from, Index stride)
+{
+ to[stride * 0] = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(from));
+ to[stride * 1] = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(_mm_shuffle_epi32(from, 1)));
+ to[stride * 2] = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2)));
+ to[stride * 3] = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3)));
+}
template<> EIGEN_DEVICE_FUNC inline void pscatter<bool, Packet16b>(bool* to, const Packet16b& from, Index stride)
{
to[4*stride*0] = _mm_cvtsi128_si32(from);
@@ -918,6 +1061,7 @@
template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint32_t>(const uint32_t* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
#endif
#if EIGEN_COMP_MSVC_STRICT && EIGEN_OS_WIN64
@@ -926,21 +1070,25 @@
template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { return a.m128_f32[0]; }
template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return a.m128d_f64[0]; }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int x = _mm_cvtsi128_si32(a); return x; }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { uint32_t x = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(a)); return x; }
#elif EIGEN_COMP_MSVC_STRICT
// The temporary variable fixes an internal compilation error in vs <= 2008 and a wrong-result bug in vs 2010
template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { float x = _mm_cvtss_f32(a); return x; }
template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double x = _mm_cvtsd_f64(a); return x; }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { int x = _mm_cvtsi128_si32(a); return x; }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { uint32_t x = numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(a)); return x; }
#else
template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { return _mm_cvtss_f32(a); }
template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return _mm_cvtsd_f64(a); }
template<> EIGEN_STRONG_INLINE int pfirst<Packet4i>(const Packet4i& a) { return _mm_cvtsi128_si32(a); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { return numext::bit_cast<uint32_t>(_mm_cvtsi128_si32(a)); }
#endif
template<> EIGEN_STRONG_INLINE bool pfirst<Packet16b>(const Packet16b& a) { int x = _mm_cvtsi128_si32(a); return static_cast<bool>(x & 1); }
template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) { return _mm_shuffle_ps(a,a,0x1B); }
template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return _mm_shuffle_pd(a,a,0x1); }
template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) { return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet4ui preverse(const Packet4ui& a) { return _mm_shuffle_epi32(a, 0x1B); }
template<> EIGEN_STRONG_INLINE Packet16b preverse(const Packet16b& a) {
#ifdef EIGEN_VECTORIZE_SSSE3
__m128i mask = _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
@@ -979,10 +1127,10 @@
// Clamp exponent to [-2099, 2099]
const Packet2d max_exponent = pset1<Packet2d>(2099.0);
const Packet2d e = pmin(pmax(exponent, pnegate(max_exponent)), max_exponent);
-
+
// Convert e to integer and swizzle to low-order bits.
const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3);
-
+
// Split 2^e into four factors and multiply:
const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023);
Packet4i b = parithmetic_shift_right<2>(ei); // floor(e/4)
@@ -1038,24 +1186,24 @@
{
// Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures
// (from Nehalem to Haswell)
-// #ifdef EIGEN_VECTORIZE_SSE3
-// Packet4f tmp = _mm_add_ps(a, vec4f_swizzle1(a,2,3,2,3));
-// return pfirst<Packet4f>(_mm_hadd_ps(tmp, tmp));
-// #else
+ // #ifdef EIGEN_VECTORIZE_SSE3
+ // Packet4f tmp = _mm_add_ps(a, vec4f_swizzle1(a,2,3,2,3));
+ // return pfirst<Packet4f>(_mm_hadd_ps(tmp, tmp));
+ // #else
Packet4f tmp = _mm_add_ps(a, _mm_movehl_ps(a,a));
return pfirst<Packet4f>(_mm_add_ss(tmp, _mm_shuffle_ps(tmp,tmp, 1)));
-// #endif
+ // #endif
}
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
{
// Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures
// (from Nehalem to Haswell)
-// #ifdef EIGEN_VECTORIZE_SSE3
-// return pfirst<Packet2d>(_mm_hadd_pd(a, a));
-// #else
+ // #ifdef EIGEN_VECTORIZE_SSE3
+ // return pfirst<Packet2d>(_mm_hadd_pd(a, a));
+ // #else
return pfirst<Packet2d>(_mm_add_sd(a, _mm_unpackhi_pd(a,a)));
-// #endif
+ // #endif
}
#ifdef EIGEN_VECTORIZE_SSSE3
@@ -1064,6 +1212,11 @@
Packet4i tmp0 = _mm_hadd_epi32(a,a);
return pfirst<Packet4i>(_mm_hadd_epi32(tmp0,tmp0));
}
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a)
+{
+ Packet4ui tmp0 = _mm_hadd_epi32(a, a);
+ return pfirst<Packet4ui>(_mm_hadd_epi32(tmp0, tmp0));
+}
#else
template<> EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a)
@@ -1071,6 +1224,10 @@
Packet4i tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a,a));
return pfirst(tmp) + pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1));
}
+template<> EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a) {
+ Packet4ui tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a, a));
+ return pfirst(tmp) + pfirst<Packet4ui>(_mm_shuffle_epi32(tmp, 1));
+}
#endif
template<> EIGEN_STRONG_INLINE bool predux<Packet16b>(const Packet16b& a) {
@@ -1100,6 +1257,15 @@
pstore(aux, a);
return (aux[0] * aux[1]) * (aux[2] * aux[3]);
}
+template<> EIGEN_STRONG_INLINE uint32_t predux_mul<Packet4ui>(const Packet4ui& a)
+{
+ // after some experiments, it is seems this is the fastest way to implement it
+ // for GCC (eg., reusing pmul is very slow !)
+ // TODO try to call _mm_mul_epu32 directly
+ EIGEN_ALIGN16 uint32_t aux[4];
+ pstore(aux, a);
+ return (aux[0] * aux[1]) * (aux[2] * aux[3]);
+}
template<> EIGEN_STRONG_INLINE bool predux_mul<Packet16b>(const Packet16b& a) {
Packet4i tmp = _mm_and_si128(a, _mm_unpackhi_epi64(a,a));
@@ -1132,6 +1298,21 @@
return aux0<aux2 ? aux0 : aux2;
#endif // EIGEN_VECTORIZE_SSE4_1
}
+template<> EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a)
+{
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ Packet4ui tmp = _mm_min_epu32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0,0,3,2)));
+ return pfirst<Packet4ui>(_mm_min_epu32(tmp,_mm_shuffle_epi32(tmp, 1)));
+#else
+ // after some experiments, it is seems this is the fastest way to implement it
+ // for GCC (eg., it does not like using std::min after the pstore !!)
+ EIGEN_ALIGN16 uint32_t aux[4];
+ pstore(aux, a);
+ uint32_t aux0 = aux[0]<aux[1] ? aux[0] : aux[1];
+ uint32_t aux2 = aux[2]<aux[3] ? aux[2] : aux[3];
+ return aux0<aux2 ? aux0 : aux2;
+#endif // EIGEN_VECTORIZE_SSE4_1
+}
// max
template<> EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a)
@@ -1158,6 +1339,21 @@
return aux0>aux2 ? aux0 : aux2;
#endif // EIGEN_VECTORIZE_SSE4_1
}
+template<> EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a)
+{
+#ifdef EIGEN_VECTORIZE_SSE4_1
+ Packet4ui tmp = _mm_max_epu32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0,0,3,2)));
+ return pfirst<Packet4ui>(_mm_max_epu32(tmp,_mm_shuffle_epi32(tmp, 1)));
+#else
+ // after some experiments, it is seems this is the fastest way to implement it
+ // for GCC (eg., it does not like using std::min after the pstore !!)
+ EIGEN_ALIGN16 uint32_t aux[4];
+ pstore(aux, a);
+ uint32_t aux0 = aux[0]>aux[1] ? aux[0] : aux[1];
+ uint32_t aux2 = aux[2]>aux[3] ? aux[2] : aux[3];
+ return aux0>aux2 ? aux0 : aux2;
+#endif // EIGEN_VECTORIZE_SSE4_1
+}
// not needed yet
// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet4f& x)
@@ -1174,6 +1370,10 @@
{
return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
}
+template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& x)
+{
+ return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
@@ -1199,6 +1399,9 @@
kernel.packet[2] = _mm_unpacklo_epi64(T2, T3);
kernel.packet[3] = _mm_unpackhi_epi64(T2, T3);
}
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4ui, 4>& kernel) {
+ ptranspose((PacketBlock<Packet4i, 4>&)kernel);
+}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet16b,4>& kernel) {
@@ -1304,6 +1507,10 @@
return _mm_or_si128(_mm_andnot_si128(false_mask, thenPacket), _mm_and_si128(false_mask, elsePacket));
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4ui pblend(const Selector<4>& ifPacket, const Packet4ui& thenPacket,
+ const Packet4ui& elsePacket) {
+ return (Packet4ui)pblend(ifPacket, (Packet4i)thenPacket, (Packet4i)elsePacket);
+}
template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
const __m128 zero = _mm_setzero_ps();
const __m128 select = _mm_set_ps(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
@@ -1357,7 +1564,7 @@
// Helpers for half->float and float->half conversions.
// Currently only used by the AVX code.
EIGEN_STRONG_INLINE __m128i half2floatsse(__m128i h) {
- __m128i input = _mm_cvtepu16_epi32(h);
+ __m128i input = _mm_cvtepu16_epi32(h);
// Direct vectorization of half_to_float, C parts in the comments.
__m128i shifted_exp = _mm_set1_epi32(0x7c00 << 13);
diff --git a/Eigen/src/plugins/IndexedViewMethods.h b/Eigen/src/plugins/IndexedViewMethods.h
index b796b39..78f12fe 100644
--- a/Eigen/src/plugins/IndexedViewMethods.h
+++ b/Eigen/src/plugins/IndexedViewMethods.h
@@ -7,6 +7,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
#if !defined(EIGEN_PARSED_BY_DOXYGEN)
protected:
@@ -24,163 +25,278 @@
typedef typename internal::IndexedViewCompatibleType<Index, 1>::type IvcIndex;
template <typename Indices>
-IvcRowType<Indices> ivcRow(const Indices& indices) const {
+inline IvcRowType<Indices> ivcRow(const Indices& indices) const {
return internal::makeIndexedViewCompatible(
indices, internal::variable_if_dynamic<Index, RowsAtCompileTime>(derived().rows()), Specialized);
}
template <typename Indices>
-IvcColType<Indices> ivcCol(const Indices& indices) const {
+inline IvcColType<Indices> ivcCol(const Indices& indices) const {
return internal::makeIndexedViewCompatible(
indices, internal::variable_if_dynamic<Index, ColsAtCompileTime>(derived().cols()), Specialized);
}
template <typename Indices>
-IvcColType<Indices> ivcSize(const Indices& indices) const {
+inline IvcColType<Indices> ivcSize(const Indices& indices) const {
return internal::makeIndexedViewCompatible(
indices, internal::variable_if_dynamic<Index, SizeAtCompileTime>(derived().size()), Specialized);
}
+// this helper class assumes internal::valid_indexed_view_overload<RowIndices, ColIndices>::value == true
+template <typename RowIndices, typename ColIndices,
+ bool UseSymbolic = internal::traits<IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>>::ReturnAsScalar,
+ bool UseBlock = internal::traits<IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>>::ReturnAsBlock,
+ bool UseGeneric = internal::traits<IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>>::ReturnAsIndexedView>
+struct IndexedViewSelector;
+
+// Generic
+template <typename RowIndices, typename ColIndices>
+struct IndexedViewSelector<RowIndices, ColIndices, false, false, true> {
+ using ReturnType = IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+ using ConstReturnType = IndexedView<const Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+
+ static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) {
+ return ReturnType(derived, derived.ivcRow(rowIndices), derived.ivcCol(colIndices));
+ }
+ static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices,
+ const ColIndices& colIndices) {
+ return ConstReturnType(derived, derived.ivcRow(rowIndices), derived.ivcCol(colIndices));
+ }
+};
+
+// Block
+template <typename RowIndices, typename ColIndices>
+struct IndexedViewSelector<RowIndices, ColIndices, false, true, false> {
+ using IndexedViewType = IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+ using ConstIndexedViewType = IndexedView<const Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+ using ReturnType = typename internal::traits<IndexedViewType>::BlockType;
+ using ConstReturnType = typename internal::traits<ConstIndexedViewType>::BlockType;
+
+ static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) {
+ IvcRowType<RowIndices> actualRowIndices = derived.ivcRow(rowIndices);
+ IvcColType<ColIndices> actualColIndices = derived.ivcCol(colIndices);
+ return ReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices),
+ internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices));
+ }
+ static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices,
+ const ColIndices& colIndices) {
+ IvcRowType<RowIndices> actualRowIndices = derived.ivcRow(rowIndices);
+ IvcColType<ColIndices> actualColIndices = derived.ivcCol(colIndices);
+ return ConstReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices),
+ internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices));
+ }
+};
+
+// Symbolic
+template <typename RowIndices, typename ColIndices>
+struct IndexedViewSelector<RowIndices, ColIndices, true, false, false> {
+ using ReturnType = typename DenseBase<Derived>::Scalar&;
+ using ConstReturnType = typename DenseBase<Derived>::CoeffReturnType;
+
+ static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) {
+ return derived(internal::eval_expr_given_size(rowIndices, derived.rows()),
+ internal::eval_expr_given_size(colIndices, derived.cols()));
+ }
+ static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices,
+ const ColIndices& colIndices) {
+ return derived(internal::eval_expr_given_size(rowIndices, derived.rows()),
+ internal::eval_expr_given_size(colIndices, derived.cols()));
+ }
+};
+
+// this helper class assumes internal::is_valid_index_type<Indices>::value == false
+template <typename Indices,
+ bool UseSymbolic = symbolic::is_symbolic<Indices>::value,
+ bool UseBlock = !UseSymbolic && internal::get_compile_time_incr<IvcType<Indices>>::value == 1,
+ bool UseGeneric = !UseSymbolic && !UseBlock>
+struct VectorIndexedViewSelector;
+
+// Generic
+template <typename Indices>
+struct VectorIndexedViewSelector<Indices, false, false, true> {
+
+ static constexpr bool IsRowMajor = DenseBase<Derived>::IsRowMajor;
+
+ using RowMajorReturnType = IndexedView<Derived, IvcIndex, IvcType<Indices>>;
+ using ConstRowMajorReturnType = IndexedView<const Derived, IvcIndex, IvcType<Indices>>;
+
+ using ColMajorReturnType = IndexedView<Derived, IvcType<Indices>, IvcIndex>;
+ using ConstColMajorReturnType = IndexedView<const Derived, IvcType<Indices>, IvcIndex>;
+
+ using ReturnType = typename internal::conditional<IsRowMajor, RowMajorReturnType, ColMajorReturnType>::type;
+ using ConstReturnType =
+ typename internal::conditional<IsRowMajor, ConstRowMajorReturnType, ConstColMajorReturnType>::type;
+
+ template <bool UseRowMajor = IsRowMajor, std::enable_if_t<UseRowMajor, bool> = true>
+ static inline RowMajorReturnType run(Derived& derived, const Indices& indices) {
+ return RowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices));
+ }
+ template <bool UseRowMajor = IsRowMajor, std::enable_if_t<UseRowMajor, bool> = true>
+ static inline ConstRowMajorReturnType run(const Derived& derived, const Indices& indices) {
+ return ConstRowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices));
+ }
+ template <bool UseRowMajor = IsRowMajor, std::enable_if_t<!UseRowMajor, bool> = true>
+ static inline ColMajorReturnType run(Derived& derived, const Indices& indices) {
+ return ColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0));
+ }
+ template <bool UseRowMajor = IsRowMajor, std::enable_if_t<!UseRowMajor, bool> = true>
+ static inline ConstColMajorReturnType run(const Derived& derived, const Indices& indices) {
+ return ConstColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0));
+ }
+};
+
+// Block
+template <typename Indices>
+struct VectorIndexedViewSelector<Indices, false, true, false> {
+
+ using ReturnType = VectorBlock<Derived, internal::array_size<Indices>::value>;
+ using ConstReturnType = VectorBlock<const Derived, internal::array_size<Indices>::value>;
+
+ static inline ReturnType run(Derived& derived, const Indices& indices) {
+ IvcType<Indices> actualIndices = derived.ivcSize(indices);
+ return ReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices));
+ }
+ static inline ConstReturnType run(const Derived& derived, const Indices& indices) {
+ IvcType<Indices> actualIndices = derived.ivcSize(indices);
+ return ConstReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices));
+ }
+};
+
+// Symbolic
+template <typename Indices>
+struct VectorIndexedViewSelector<Indices, true, false, false> {
+
+ using ReturnType = typename DenseBase<Derived>::Scalar&;
+ using ConstReturnType = typename DenseBase<Derived>::CoeffReturnType;
+
+ static inline ReturnType run(Derived& derived, const Indices& id) {
+ return derived(internal::eval_expr_given_size(id, derived.size()));
+ }
+ static inline ConstReturnType run(const Derived& derived, const Indices& id) {
+ return derived(internal::eval_expr_given_size(id, derived.size()));
+ }
+};
+
+// SFINAE dummy types
+
+template <typename RowIndices, typename ColIndices>
+using EnableOverload = std::enable_if_t<
+ internal::valid_indexed_view_overload<RowIndices, ColIndices>::value && internal::is_lvalue<Derived>::value, bool>;
+
+template <typename RowIndices, typename ColIndices>
+using EnableConstOverload =
+ std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value, bool>;
+
+template <typename Indices>
+using EnableVectorOverload =
+ std::enable_if_t<!internal::is_valid_index_type<Indices>::value && internal::is_lvalue<Derived>::value, bool>;
+
+template <typename Indices>
+using EnableConstVectorOverload = std::enable_if_t<!internal::is_valid_index_type<Indices>::value, bool>;
+
public:
-template <typename RowIndices, typename ColIndices>
-using IndexedViewType = IndexedView<Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+// Public API for 2D matrices/arrays
+
+// non-const versions
template <typename RowIndices, typename ColIndices>
-using ConstIndexedViewType = IndexedView<const Derived, IvcRowType<RowIndices>, IvcColType<ColIndices>>;
+using IndexedViewType = typename IndexedViewSelector<RowIndices, ColIndices>::ReturnType;
-// This is the generic version
-
-template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<IndexedViewType<RowIndices, ColIndices>>::ReturnAsIndexedView,
- IndexedViewType<RowIndices, ColIndices>>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) {
- return IndexedViewType<RowIndices, ColIndices>(derived(), ivcRow(rowIndices), ivcCol(colIndices));
+template <typename RowIndices, typename ColIndices, EnableOverload<RowIndices, ColIndices> = true>
+IndexedViewType<RowIndices, ColIndices> operator()(const RowIndices& rowIndices, const ColIndices& colIndices) {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), rowIndices, colIndices);
}
-template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<ConstIndexedViewType<RowIndices, ColIndices>>::ReturnAsIndexedView,
- ConstIndexedViewType<RowIndices, ColIndices>>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const {
- return ConstIndexedViewType<RowIndices, ColIndices>(derived(), ivcRow(rowIndices), ivcCol(colIndices));
+template <typename RowType, size_t RowSize, typename ColIndices, typename RowIndices = Array<RowType, RowSize, 1>,
+ EnableOverload<RowIndices, ColIndices> = true>
+IndexedViewType<RowIndices, ColIndices> operator()(const RowType (&rowIndices)[RowSize], const ColIndices& colIndices) {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), RowIndices{rowIndices}, colIndices);
}
-// The following overload returns a Block<> object
-
-template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<IndexedViewType<RowIndices, ColIndices>>::ReturnAsBlock,
- typename internal::traits<IndexedViewType<RowIndices, ColIndices>>::BlockType>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) {
- typedef typename internal::traits<IndexedViewType<RowIndices, ColIndices>>::BlockType BlockType;
- IvcRowType<RowIndices> actualRowIndices = ivcRow(rowIndices);
- IvcColType<ColIndices> actualColIndices = ivcCol(colIndices);
- return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices),
- internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices));
+template <typename RowIndices, typename ColType, size_t ColSize, typename ColIndices = Array<ColType, ColSize, 1>,
+ EnableOverload<RowIndices, ColIndices> = true>
+IndexedViewType<RowIndices, ColIndices> operator()(const RowIndices& rowIndices, const ColType (&colIndices)[ColSize]) {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), rowIndices, ColIndices{colIndices});
}
-template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<ConstIndexedViewType<RowIndices, ColIndices>>::ReturnAsBlock,
- typename internal::traits<ConstIndexedViewType<RowIndices, ColIndices>>::BlockType>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const {
- typedef typename internal::traits<ConstIndexedViewType<RowIndices, ColIndices>>::BlockType BlockType;
- IvcRowType<RowIndices> actualRowIndices = ivcRow(rowIndices);
- IvcColType<ColIndices> actualColIndices = ivcCol(colIndices);
- return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices),
- internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices));
+template <typename RowType, size_t RowSize, typename ColType, size_t ColSize,
+ typename RowIndices = Array<RowType, RowSize, 1>, typename ColIndices = Array<ColType, ColSize, 1>,
+ EnableOverload<RowIndices, ColIndices> = true>
+IndexedViewType<RowIndices, ColIndices> operator()(const RowType (&rowIndices)[RowSize],
+ const ColType (&colIndices)[ColSize]) {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), RowIndices{rowIndices}, ColIndices{colIndices});
}
-// The following overload returns a Scalar
+// const versions
template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<IndexedViewType<RowIndices, ColIndices>>::ReturnAsScalar && internal::is_lvalue<Derived>::value,
- Scalar&>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) {
- return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()),
- internal::eval_expr_given_size(colIndices, cols()));
+using ConstIndexedViewType = typename IndexedViewSelector<RowIndices, ColIndices>::ConstReturnType;
+
+template <typename RowIndices, typename ColIndices, EnableConstOverload<RowIndices, ColIndices> = true>
+ConstIndexedViewType<RowIndices, ColIndices> operator()(const RowIndices& rowIndices,
+ const ColIndices& colIndices) const {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), rowIndices, colIndices);
}
-template <typename RowIndices, typename ColIndices>
-std::enable_if_t<internal::valid_indexed_view_overload<RowIndices, ColIndices>::value &&
- internal::traits<ConstIndexedViewType<RowIndices, ColIndices>>::ReturnAsScalar,
- CoeffReturnType>
-operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const {
- return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()),
- internal::eval_expr_given_size(colIndices, cols()));
+template <typename RowType, size_t RowSize, typename ColIndices, typename RowIndices = Array<RowType, RowSize, 1>,
+ EnableConstOverload<RowIndices, ColIndices> = true>
+ConstIndexedViewType<RowIndices, ColIndices> operator()(const RowType (&rowIndices)[RowSize],
+ const ColIndices& colIndices) const {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), RowIndices{rowIndices}, colIndices);
}
-// Overloads for 1D vectors/arrays
+template <typename RowIndices, typename ColType, size_t ColSize, typename ColIndices = Array<ColType, ColSize, 1>,
+ EnableConstOverload<RowIndices, ColIndices> = true>
+ConstIndexedViewType<RowIndices, ColIndices> operator()(const RowIndices& rowIndices,
+ const ColType (&colIndices)[ColSize]) const {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), rowIndices, ColIndices{colIndices});
+}
+
+template <typename RowType, size_t RowSize, typename ColType, size_t ColSize,
+ typename RowIndices = Array<RowType, RowSize, 1>, typename ColIndices = Array<ColType, ColSize, 1>,
+ EnableConstOverload<RowIndices, ColIndices> = true>
+ConstIndexedViewType<RowIndices, ColIndices> operator()(const RowType (&rowIndices)[RowSize],
+ const ColType (&colIndices)[ColSize]) const {
+ return IndexedViewSelector<RowIndices, ColIndices>::run(derived(), RowIndices{rowIndices}, ColIndices{colIndices});
+}
+
+// Public API for 1D vectors/arrays
+
+// non-const versions
template <typename Indices>
-std::enable_if_t<IsRowMajor && (!(internal::get_compile_time_incr<IvcType<Indices>>::value == 1 ||
- internal::is_valid_index_type<Indices>::value)),
- IndexedView<Derived, IvcIndex, IvcType<Indices>>>
-operator()(const Indices& indices) {
+using VectorIndexedViewType = typename VectorIndexedViewSelector<Indices>::ReturnType;
+
+template <typename Indices, EnableVectorOverload<Indices> = true>
+VectorIndexedViewType<Indices> operator()(const Indices& indices) {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return IndexedView<Derived, IvcIndex, IvcType<Indices>>(derived(), IvcIndex(0), ivcCol(indices));
+ return VectorIndexedViewSelector<Indices>::run(derived(), indices);
}
+template <typename IndexType, size_t Size, typename Indices = Array<IndexType, Size, 1>,
+ EnableVectorOverload<Indices> = true>
+VectorIndexedViewType<Indices> operator()(const IndexType (&indices)[Size]) {
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
+ return VectorIndexedViewSelector<Indices>::run(derived(), Indices{indices});
+}
+
+// const versions
+
template <typename Indices>
-std::enable_if_t<IsRowMajor && (!(internal::get_compile_time_incr<IvcType<Indices>>::value == 1 ||
- internal::is_valid_index_type<Indices>::value)),
- IndexedView<const Derived, IvcIndex, IvcType<Indices>>>
-operator()(const Indices& indices) const {
+using ConstVectorIndexedViewType = typename VectorIndexedViewSelector<Indices>::ConstReturnType;
+
+template <typename Indices, EnableConstVectorOverload<Indices> = true>
+ConstVectorIndexedViewType<Indices> operator()(const Indices& indices) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return IndexedView<const Derived, IvcIndex, IvcType<Indices>>(derived(), IvcIndex(0), ivcCol(indices));
+ return VectorIndexedViewSelector<Indices>::run(derived(), indices);
}
-template <typename Indices>
-std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr<IvcType<Indices>>::value == 1 ||
- internal::is_valid_index_type<Indices>::value)),
- IndexedView<Derived, IvcType<Indices>, IvcIndex>>
-operator()(const Indices& indices) {
+template <typename IndexType, size_t Size, typename Indices = Array<IndexType, Size, 1>,
+ EnableConstVectorOverload<Indices> = true>
+ConstVectorIndexedViewType<Indices> operator()(const IndexType (&indices)[Size]) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return IndexedView<Derived, IvcType<Indices>, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0));
-}
-
-template <typename Indices>
-std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr<IvcType<Indices>>::value == 1 ||
- internal::is_valid_index_type<Indices>::value)),
- IndexedView<const Derived, IvcType<Indices>, IvcIndex>>
-operator()(const Indices& indices) const {
- EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- return IndexedView<const Derived, IvcType<Indices>, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0));
-}
-
-template <typename Indices>
-std::enable_if_t<(internal::get_compile_time_incr<IvcType<Indices>>::value == 1) &&
- (!internal::is_valid_index_type<Indices>::value) && (!symbolic::is_symbolic<Indices>::value),
- VectorBlock<Derived, internal::array_size<Indices>::value>>
-operator()(const Indices& indices) {
- EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- IvcType<Indices> actualIndices = ivcSize(indices);
- return VectorBlock<Derived, internal::array_size<Indices>::value>(derived(), internal::first(actualIndices),
- internal::index_list_size(actualIndices));
-}
-
-template <typename Indices>
-std::enable_if_t<(internal::get_compile_time_incr<IvcType<Indices>>::value == 1) &&
- (!internal::is_valid_index_type<Indices>::value) && (!symbolic::is_symbolic<Indices>::value),
- VectorBlock<const Derived, internal::array_size<Indices>::value>>
-operator()(const Indices& indices) const {
- EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
- IvcType<Indices> actualIndices = ivcSize(indices);
- return VectorBlock<const Derived, internal::array_size<Indices>::value>(derived(), internal::first(actualIndices),
- internal::index_list_size(actualIndices));
-}
-
-template <typename IndexType>
-std::enable_if_t<symbolic::is_symbolic<IndexType>::value && internal::is_lvalue<Derived>::value, Scalar&> operator()(const IndexType& id) {
- return Base::operator()(internal::eval_expr_given_size(id, size()));
-}
-
-template <typename IndexType>
-std::enable_if_t<symbolic::is_symbolic<IndexType>::value, CoeffReturnType> operator()(const IndexType& id) const {
- return Base::operator()(internal::eval_expr_given_size(id, size()));
+ return VectorIndexedViewSelector<Indices>::run(derived(), Indices{indices});
}
#else // EIGEN_PARSED_BY_DOXYGEN
diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp
index 84a4767..41ba521 100644
--- a/test/indexed_view.cpp
+++ b/test/indexed_view.cpp
@@ -295,6 +295,69 @@
VERIFY_IS_EQUAL( a(std::array<int,3>{1,3,5}).SizeAtCompileTime, 3 );
VERIFY_IS_EQUAL( b(std::array<int,3>{1,3,5}).SizeAtCompileTime, 3 );
+ // check different index types (C-style array, STL container, Eigen type)
+ {
+ Index size = 10;
+ ArrayXd r = ArrayXd::Random(size);
+ ArrayXi idx = ArrayXi::EqualSpaced(size, 0, 1);
+ std::shuffle(idx.begin(), idx.end(), std::random_device());
+
+ int c_array[3] = { idx[0], idx[1], idx[2] };
+ std::vector<int> std_vector{ idx[0], idx[1], idx[2] };
+ Matrix<int, 3, 1> eigen_matrix{ idx[0], idx[1], idx[2] };
+
+ // non-const access
+ VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(c_array));
+ VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(std_vector));
+ VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(eigen_matrix));
+ VERIFY_IS_CWISE_EQUAL(r(std_vector), r(c_array));
+ VERIFY_IS_CWISE_EQUAL(r(std_vector), r(eigen_matrix));
+ VERIFY_IS_CWISE_EQUAL(r(eigen_matrix), r(c_array));
+
+ const ArrayXd& r_ref = r;
+ // const access
+ VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(c_array));
+ VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(std_vector));
+ VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(eigen_matrix));
+ VERIFY_IS_CWISE_EQUAL(r_ref(std_vector), r_ref(c_array));
+ VERIFY_IS_CWISE_EQUAL(r_ref(std_vector), r_ref(eigen_matrix));
+ VERIFY_IS_CWISE_EQUAL(r_ref(eigen_matrix), r_ref(c_array));
+ }
+
+ {
+ Index rows = 8;
+ Index cols = 11;
+ ArrayXXd R = ArrayXXd::Random(rows, cols);
+ ArrayXi r_idx = ArrayXi::EqualSpaced(rows, 0, 1);
+ ArrayXi c_idx = ArrayXi::EqualSpaced(cols, 0, 1);
+ std::shuffle(r_idx.begin(), r_idx.end(), std::random_device());
+ std::shuffle(c_idx.begin(), c_idx.end(), std::random_device());
+
+ int c_array_rows[3] = { r_idx[0], r_idx[1], r_idx[2] };
+ int c_array_cols[4] = { c_idx[0], c_idx[1], c_idx[2], c_idx[3] };
+ std::vector<int> std_vector_rows{ r_idx[0], r_idx[1], r_idx[2] };
+ std::vector<int> std_vector_cols{ c_idx[0], c_idx[1], c_idx[2], c_idx[3] };
+ Matrix<int, 3, 1> eigen_matrix_rows{ r_idx[0], r_idx[1], r_idx[2] };
+ Matrix<int, 4, 1> eigen_matrix_cols{ c_idx[0], c_idx[1], c_idx[2], c_idx[3] };
+
+ // non-const access
+ VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(c_array_rows, c_array_cols));
+ VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(std_vector_rows, std_vector_cols));
+ VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(eigen_matrix_rows, eigen_matrix_cols));
+ VERIFY_IS_CWISE_EQUAL(R(std_vector_rows, std_vector_cols), R(c_array_rows, c_array_cols));
+ VERIFY_IS_CWISE_EQUAL(R(std_vector_rows, std_vector_cols), R(eigen_matrix_rows, eigen_matrix_cols));
+ VERIFY_IS_CWISE_EQUAL(R(eigen_matrix_rows, eigen_matrix_cols), R(c_array_rows, c_array_cols));
+
+ const ArrayXXd& R_ref = R;
+ // const access
+ VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(c_array_rows, c_array_cols));
+ VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(std_vector_rows, std_vector_cols));
+ VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(eigen_matrix_rows, eigen_matrix_cols));
+ VERIFY_IS_CWISE_EQUAL(R_ref(std_vector_rows, std_vector_cols), R_ref(c_array_rows, c_array_cols));
+ VERIFY_IS_CWISE_EQUAL(R_ref(std_vector_rows, std_vector_cols), R_ref(eigen_matrix_rows, eigen_matrix_cols));
+ VERIFY_IS_CWISE_EQUAL(R_ref(eigen_matrix_rows, eigen_matrix_cols), R_ref(c_array_rows, c_array_cols));
+ }
+
// check mat(i,j) with weird types for i and j
{
VERIFY_IS_APPROX( A(B.RowsAtCompileTime-1, 1), A(3,1) );
@@ -357,8 +420,33 @@
A(XX,Y) = 1;
A(X,YY) = 1;
// check symbolic indices
- a(last) = 1;
+ a(last) = 1.0;
A(last, last) = 1;
+ // check weird non-const, non-lvalue scenarios
+ {
+ // in these scenarios, the objects are not declared 'const', and the compiler will atttempt to use the non-const
+ // overloads without intervention
+
+ // non-const map to a const object
+ Map<const ArrayXd> a_map(a.data(), a.size());
+ Map<const ArrayXXi> A_map(A.data(), A.rows(), A.cols());
+
+ VERIFY_IS_EQUAL(a_map(last), a.coeff(a.size() - 1));
+ VERIFY_IS_EQUAL(A_map(last, last), A.coeff(A.rows() - 1, A.cols() - 1));
+
+ // non-const expressions that have no modifiable data
+ using Op = internal::scalar_constant_op<double>;
+ using VectorXpr = CwiseNullaryOp<Op, VectorXd>;
+ using MatrixXpr = CwiseNullaryOp<Op, MatrixXd>;
+ double constant_val = internal::random<double>();
+ Op op(constant_val);
+ VectorXpr vectorXpr(10, 1, op);
+ MatrixXpr matrixXpr(8, 11, op);
+
+ VERIFY_IS_EQUAL(vectorXpr.coeff(vectorXpr.size() - 1), vectorXpr(last));
+ VERIFY_IS_EQUAL(matrixXpr.coeff(matrixXpr.rows() - 1, matrixXpr.cols() - 1), matrixXpr(last, last));
+ }
+
// Check compilation of varying integer types as index types:
Index i = n/2;
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index a98a014..5dd4cbc 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -77,7 +77,7 @@
EIGEN_USING_STD(frexp)
const T out = static_cast<T>(frexp(x, &iexp));
exp = static_cast<T>(iexp);
-
+
// The exponent value is unspecified if the input is inf or NaN, but MSVC
// seems to set it to 1. We need to set it back to zero for consistency.
if (!(numext::isfinite)(x)) {
@@ -340,60 +340,78 @@
CHECK_CWISE2_IF(true, internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan);
}
+template <typename Scalar, typename Packet, typename EnableIf = void>
+struct packetmath_boolean_mask_ops_notcomplex_test {
+ static void run() {}
+};
+
template <typename Scalar, typename Packet>
-void packetmath_boolean_mask_ops_notcomplex() {
- const int PacketSize = internal::unpacket_traits<Packet>::size;
- const int size = 2 * PacketSize;
- EIGEN_ALIGN_MAX Scalar data1[size];
- EIGEN_ALIGN_MAX Scalar data2[size];
- EIGEN_ALIGN_MAX Scalar ref[size];
+struct packetmath_boolean_mask_ops_notcomplex_test<
+ Scalar, Packet,
+ std::enable_if_t<internal::packet_traits<Scalar>::HasCmp &&
+ !internal::is_same<Scalar, bool>::value>> {
+ static void run() {
+ const int PacketSize = internal::unpacket_traits<Packet>::size;
+ const int size = 2 * PacketSize;
+ EIGEN_ALIGN_MAX Scalar data1[size];
+ EIGEN_ALIGN_MAX Scalar data2[size];
+ EIGEN_ALIGN_MAX Scalar ref[size];
- for (int i = 0; i < PacketSize; ++i) {
- data1[i] = internal::random<Scalar>();
- data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
- }
+ for (int i = 0; i < PacketSize; ++i) {
+ data1[i] = internal::random<Scalar>();
+ data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
+ }
- CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
- CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
+ CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
+ CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
//Test (-0) <=/< (0) for signed operations
- for (int i = 0; i < PacketSize; ++i) {
- data1[i] = Scalar(-0.0);
- data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
- }
- CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
- CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
+ for (int i = 0; i < PacketSize; ++i) {
+ data1[i] = Scalar(-0.0);
+ data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
+ }
+ CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
+ CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
//Test NaN
- for (int i = 0; i < PacketSize; ++i) {
- data1[i] = NumTraits<Scalar>::quiet_NaN();
- data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
+ for (int i = 0; i < PacketSize; ++i) {
+ data1[i] = NumTraits<Scalar>::quiet_NaN();
+ data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
+ }
+ CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
+ CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
}
- CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le);
- CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt);
-}
+};
-// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
-// (for some compilers) compute the bitwise and with 0x1 of the results to keep the value in [0,1].
-template<>
+// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since
+// the scalar path (for some compilers) compute the bitwise and with 0x1 of the
+// results to keep the value in [0,1].
+template <>
void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {}
-template<>
-void packetmath_boolean_mask_ops_notcomplex<bool, internal::packet_traits<bool>::type>() {}
+
+template <typename Scalar, typename Packet, typename EnableIf = void>
+struct packetmath_minus_zero_add_test {
+ static void run() {}
+};
template <typename Scalar, typename Packet>
-void packetmath_minus_zero_add() {
- const int PacketSize = internal::unpacket_traits<Packet>::size;
- const int size = 2 * PacketSize;
- EIGEN_ALIGN_MAX Scalar data1[size] = {};
- EIGEN_ALIGN_MAX Scalar data2[size] = {};
- EIGEN_ALIGN_MAX Scalar ref[size] = {};
-
- for (int i = 0; i < PacketSize; ++i) {
- data1[i] = Scalar(-0.0);
- data1[i + PacketSize] = Scalar(-0.0);
+struct packetmath_minus_zero_add_test<
+ Scalar, Packet,
+ std::enable_if_t<!NumTraits<Scalar>::IsInteger>> {
+ static void run() {
+ const int PacketSize = internal::unpacket_traits<Packet>::size;
+ const int size = 2 * PacketSize;
+ EIGEN_ALIGN_MAX Scalar data1[size] = {};
+ EIGEN_ALIGN_MAX Scalar data2[size] = {};
+ EIGEN_ALIGN_MAX Scalar ref[size] = {};
+
+ for (int i = 0; i < PacketSize; ++i) {
+ data1[i] = Scalar(-0.0);
+ data1[i + PacketSize] = Scalar(-0.0);
+ }
+ CHECK_CWISE2_IF(internal::packet_traits<Scalar>::HasAdd, REF_ADD, internal::padd);
}
- CHECK_CWISE2_IF(internal::packet_traits<Scalar>::HasAdd, REF_ADD, internal::padd);
-}
+};
// Ensure optimization barrier compiles and doesn't modify contents.
// Only applies to raw types, so will not work for std::complex, Eigen::half
@@ -673,7 +691,7 @@
packetmath_boolean_mask_ops<Scalar, Packet>();
packetmath_pcast_ops_runner<Scalar, Packet>::run();
- packetmath_minus_zero_add<Scalar, Packet>();
+ packetmath_minus_zero_add_test<Scalar, Packet>::run();
for (int i = 0; i < size; ++i) {
data1[i] = numext::abs(internal::random<Scalar>());
@@ -682,9 +700,9 @@
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
- CHECK_CWISE3_IF(true, REF_NMSUB, internal::pnmsub);
+ CHECK_CWISE3_IF(PacketTraits::HasNegate, REF_NMSUB, internal::pnmsub);
}
-
+
// For pmsub, pnmadd, the values can cancel each other to become near zero,
// which can lead to very flaky tests. Here we ensure the signs are such that
// they do not cancel.
@@ -695,7 +713,7 @@
}
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
- CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
+ CHECK_CWISE3_IF(PacketTraits::HasNegate, REF_NMADD, internal::pnmadd);
}
}
@@ -714,7 +732,7 @@
T operator()(const T& val) const { \
return Func(val); \
} \
-}
+ }
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
@@ -742,20 +760,20 @@
// When EIGEN_FAST_MATH is 1 we relax the conditions slightly, and allow the function
// to return the same value for subnormals as the reference would return for zero with
// the same sign as the input.
- #if EIGEN_FAST_MATH
- data1[0] = Scalar(scale) * std::numeric_limits<Scalar>::denorm_min();
- data1[1] = -data1[0];
- test::packet_helper<Cond, Packet> h;
- h.store(data2, fun(h.load(data1)));
+#if EIGEN_FAST_MATH
+ data1[0] = Scalar(scale) * std::numeric_limits<Scalar>::denorm_min();
+ data1[1] = -data1[0];
+ test::packet_helper<Cond, Packet> h;
+ h.store(data2, fun(h.load(data1)));
for (int i=0; i < PacketSize; ++i) {
- const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0));
- const Scalar ref_val = ref_fun(data1[i]);
- VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero ||
- verifyIsApprox(data2[i], ref_val));
- }
- #else
- CHECK_CWISE1_IF(Cond, ref_fun, fun);
- #endif
+ const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0));
+ const Scalar ref_val = ref_fun(data1[i]);
+ VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero ||
+ verifyIsApprox(data2[i], ref_val));
+ }
+#else
+ CHECK_CWISE1_IF(Cond, ref_fun, fun);
+#endif
}
}
@@ -763,7 +781,7 @@
data1[0] = norm_min;
data1[1] = -data1[0];
CHECK_CWISE1_IF(Cond, ref_fun, fun);
-
+
// Test for largest floats.
data1[0] = norm_max;
data1[1] = -data1[0];
@@ -794,7 +812,7 @@
EIGEN_ALIGN_MAX Scalar data1[PacketSize * 4] = {};
EIGEN_ALIGN_MAX Scalar data2[PacketSize * 4] = {};
EIGEN_ALIGN_MAX Scalar ref[PacketSize * 4] = {};
-
+
// Negate with -0.
if (PacketTraits::HasNegate) {
test::packet_helper<PacketTraits::HasNegate,Packet> h;
@@ -831,7 +849,7 @@
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
packetmath_boolean_mask_ops_real<Scalar,Packet>();
-
+
// Rounding edge cases.
if (PacketTraits::HasRound || PacketTraits::HasCeil || PacketTraits::HasFloor || PacketTraits::HasRint) {
typedef typename internal::make_integer<Scalar>::type IntType;
@@ -864,7 +882,7 @@
values.push_back(NumTraits<Scalar>::infinity());
values.push_back(-NumTraits<Scalar>::infinity());
values.push_back(NumTraits<Scalar>::quiet_NaN());
-
+
for (size_t k=0; k<values.size(); ++k) {
data1[0] = values[k];
CHECK_CWISE1_EXACT_IF(PacketTraits::HasRound, numext::round, internal::pround);
@@ -890,7 +908,7 @@
data1[0] = -NumTraits<Scalar>::infinity();
}
CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp);
-
+
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
if (PacketTraits::HasExp) {
// Check denormals:
@@ -900,11 +918,11 @@
data1[0] = -data1[0];
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
}
-
+
// zero
data1[0] = Scalar(0);
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
-
+
// inf and NaN only compare output fraction, not exponent.
test::packet_helper<PacketTraits::HasExp,Packet> h;
Packet pout;
@@ -919,7 +937,7 @@
VERIFY(test::areApprox(ref, data2, 1) && "internal::pfrexp");
}
}
-
+
for (int i = 0; i < PacketSize; ++i) {
data1[i] = Scalar(internal::random<double>(-1, 1));
data2[i] = Scalar(internal::random<double>(-1, 1));
@@ -1166,7 +1184,7 @@
ref[i] = SCALAR(REFOP(static_cast<REFTYPE>(data1[i]))); \
h.store(data2, POP(h.load(data1))); \
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
-}
+ }
template <typename Scalar>
Scalar propagate_nan_max(const Scalar& a, const Scalar& b) {
@@ -1293,7 +1311,7 @@
CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pmax<PropagateNaN>);
}
- packetmath_boolean_mask_ops_notcomplex<Scalar, Packet>();
+ packetmath_boolean_mask_ops_notcomplex_test<Scalar, Packet>::run();
}
template <typename Scalar, typename Packet, bool ConjLhs, bool ConjRhs>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index f8e3f29..b16e5a6 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -817,14 +817,21 @@
{
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
typedef typename XprType::Scalar Scalar;
+
+ using TernarySelectOp = internal::scalar_boolean_select_op<typename internal::traits<ThenArgType>::Scalar,
+ typename internal::traits<ElseArgType>::Scalar,
+ typename internal::traits<IfArgType>::Scalar>;
+ static constexpr bool TernaryPacketAccess =
+ TensorEvaluator<ThenArgType, Device>::PacketAccess && TensorEvaluator<ElseArgType, Device>::PacketAccess &&
+ TensorEvaluator<IfArgType, Device>::PacketAccess && internal::functor_traits<TernarySelectOp>::PacketAccess;
static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
enum {
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
TensorEvaluator<ElseArgType, Device>::IsAligned,
- PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess &
- TensorEvaluator<ElseArgType, Device>::PacketAccess &
- PacketType<Scalar, Device>::HasBlend,
+ PacketAccess = (TensorEvaluator<ThenArgType, Device>::PacketAccess &&
+ TensorEvaluator<ElseArgType, Device>::PacketAccess &&
+ PacketType<Scalar, Device>::HasBlend) || TernaryPacketAccess,
BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
TensorEvaluator<ThenArgType, Device>::BlockAccess &&
TensorEvaluator<ElseArgType, Device>::BlockAccess,
@@ -922,7 +929,9 @@
{
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
}
- template<int LoadMode>
+
+ template<int LoadMode, bool UseTernary = TernaryPacketAccess,
+ std::enable_if_t<!UseTernary, bool> = true>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
{
internal::Selector<PacketSize> select;
@@ -936,6 +945,14 @@
}
+ template <int LoadMode, bool UseTernary = TernaryPacketAccess,
+ std::enable_if_t<UseTernary, bool> = true>
+ EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
+ return TernarySelectOp().template packetOp<PacketReturnType>(m_thenImpl.template packet<LoadMode>(index),
+ m_elseImpl.template packet<LoadMode>(index),
+ m_condImpl.template packet<LoadMode>(index));
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
costPerCoeff(bool vectorized) const {
return m_condImpl.costPerCoeff(vectorized) +
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
index d85729d..f9b3cd0 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -296,55 +296,54 @@
****************************************************************************/
/** \internal \returns the error function of \a a (coeff-wise)
- Doesn't do anything fancy, just a 13/8-degree rational interpolant which
- is accurate up to a couple of ulp in the range [-4, 4], outside of which
- fl(erf(x)) = +/-1.
+ Doesn't do anything fancy, just a 9/12-degree rational interpolant which
+ is accurate to 3 ulp for normalized floats in the range [-c;c], where
+ c = erfinv(1-2^-23), outside of which x should be +/-1 in single precision.
+ Strictly speaking c should be erfinv(1-2^-24), but we clamp slightly earlier
+ to avoid returning values greater than 1.
This implementation works on both scalars and Ts.
*/
template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
- // Clamp the inputs to the range [-4, 4] since anything outside
- // this range is +/-1.0f in single-precision.
- const T plus_4 = pset1<T>(4.f);
- const T minus_4 = pset1<T>(-4.f);
- const T x = pmax(pmin(a_x, plus_4), minus_4);
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& x) {
+ constexpr float kErfInvOneMinusHalfULP = 3.832506856900711f;
+ const T clamp = pcmp_le(pset1<T>(kErfInvOneMinusHalfULP), pabs(x));
// The monomial coefficients of the numerator polynomial (odd).
- const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
- const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
- const T alpha_5 = pset1<T>(-7.34990630326855e-04f);
- const T alpha_7 = pset1<T>(-5.69250639462346e-05f);
- const T alpha_9 = pset1<T>(-2.10102402082508e-06f);
- const T alpha_11 = pset1<T>(2.77068142495902e-08f);
- const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
+ const T alpha_1 = pset1<T>(1.128379143519084f);
+ const T alpha_3 = pset1<T>(0.18520832239976145f);
+ const T alpha_5 = pset1<T>(0.050955695062380861f);
+ const T alpha_7 = pset1<T>(0.0034082910107109506f);
+ const T alpha_9 = pset1<T>(0.00022905065861350646f);
// The monomial coefficients of the denominator polynomial (even).
- const T beta_0 = pset1<T>(-1.42647390514189e-02f);
- const T beta_2 = pset1<T>(-7.37332916720468e-03f);
- const T beta_4 = pset1<T>(-1.68282697438203e-03f);
- const T beta_6 = pset1<T>(-2.13374055278905e-04f);
- const T beta_8 = pset1<T>(-1.45660718464996e-05f);
+ const T beta_0 = pset1<T>(1.0f);
+ const T beta_2 = pset1<T>(0.49746925110067538f);
+ const T beta_4 = pset1<T>(0.11098505178285362f);
+ const T beta_6 = pset1<T>(0.014070470171167667f);
+ const T beta_8 = pset1<T>(0.0010179625278914885f);
+ const T beta_10 = pset1<T>(0.000023547966471313185f);
+ const T beta_12 = pset1<T>(-1.1791602954361697e-7f);
// Since the polynomials are odd/even, we need x^2.
const T x2 = pmul(x, x);
// Evaluate the numerator polynomial p.
- T p = pmadd(x2, alpha_13, alpha_11);
- p = pmadd(x2, p, alpha_9);
- p = pmadd(x2, p, alpha_7);
+ T p = pmadd(x2, alpha_9, alpha_7);
p = pmadd(x2, p, alpha_5);
p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
// Evaluate the denominator polynomial p.
- T q = pmadd(x2, beta_8, beta_6);
+ T q = pmadd(x2, beta_12, beta_10);
+ q = pmadd(x2, q, beta_8);
+ q = pmadd(x2, q, beta_6);
q = pmadd(x2, q, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
- return pdiv(p, q);
+ return pselect(clamp, psign(x), pdiv(p, q));
}
template <typename T>
diff --git a/unsupported/test/cxx11_tensor_comparisons.cpp b/unsupported/test/cxx11_tensor_comparisons.cpp
index 86c7335..e0bd90d 100644
--- a/unsupported/test/cxx11_tensor_comparisons.cpp
+++ b/unsupported/test/cxx11_tensor_comparisons.cpp
@@ -16,23 +16,41 @@
using Scalar = float;
+using TypedLTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
+using TypedLEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
+using TypedGTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
+using TypedGEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
+using TypedEQOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
+using TypedNEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
+
static void test_orderings()
{
Tensor<Scalar, 3> mat1(2,3,7);
Tensor<Scalar, 3> mat2(2,3,7);
+
+ mat1.setRandom();
+ mat2.setRandom();
+
Tensor<bool, 3> lt(2,3,7);
Tensor<bool, 3> le(2,3,7);
Tensor<bool, 3> gt(2,3,7);
Tensor<bool, 3> ge(2,3,7);
- mat1.setRandom();
- mat2.setRandom();
+ Tensor<Scalar, 3> typed_lt(2, 3, 7);
+ Tensor<Scalar, 3> typed_le(2, 3, 7);
+ Tensor<Scalar, 3> typed_gt(2, 3, 7);
+ Tensor<Scalar, 3> typed_ge(2, 3, 7);
lt = mat1 < mat2;
le = mat1 <= mat2;
gt = mat1 > mat2;
ge = mat1 >= mat2;
+ typed_lt = mat1.binaryExpr(mat2, TypedLTOp());
+ typed_le = mat1.binaryExpr(mat2, TypedLEOp());
+ typed_gt = mat1.binaryExpr(mat2, TypedGTOp());
+ typed_ge = mat1.binaryExpr(mat2, TypedGEOp());
+
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
@@ -40,6 +58,11 @@
VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k));
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k));
VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k));
+
+ VERIFY_IS_EQUAL(lt(i, j, k), (bool)typed_lt(i, j, k));
+ VERIFY_IS_EQUAL(le(i, j, k), (bool)typed_le(i, j, k));
+ VERIFY_IS_EQUAL(gt(i, j, k), (bool)typed_gt(i, j, k));
+ VERIFY_IS_EQUAL(ge(i, j, k), (bool)typed_ge(i, j, k));
}
}
}
@@ -65,14 +88,24 @@
Tensor<bool, 3> eq(2,3,7);
Tensor<bool, 3> ne(2,3,7);
+
+ Tensor<Scalar, 3> typed_eq(2, 3, 7);
+ Tensor<Scalar, 3> typed_ne(2, 3, 7);
+
eq = (mat1 == mat2);
ne = (mat1 != mat2);
+ typed_eq = mat1.binaryExpr(mat2, TypedEQOp());
+ typed_ne = mat1.binaryExpr(mat2, TypedNEOp());
+
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k));
VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k));
+
+ VERIFY_IS_EQUAL(eq(i, j, k), (bool)typed_eq(i,j,k));
+ VERIFY_IS_EQUAL(ne(i, j, k), (bool)typed_ne(i,j,k));
}
}
}
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp
index f99c80a..c76fbc5 100644
--- a/unsupported/test/cxx11_tensor_expr.cpp
+++ b/unsupported/test/cxx11_tensor_expr.cpp
@@ -280,6 +280,8 @@
static void test_select()
{
+ using TypedGTOp = internal::scalar_cmp_op<float, float, internal::cmp_GT, true>;
+
Tensor<float, 3> selector(2,3,7);
Tensor<float, 3> mat1(2,3,7);
Tensor<float, 3> mat2(2,3,7);
@@ -288,6 +290,8 @@
selector.setRandom();
mat1.setRandom();
mat2.setRandom();
+
+ // test select with a boolean condition
result = (selector > selector.constant(0.5f)).select(mat1, mat2);
for (int i = 0; i < 2; ++i) {
@@ -297,6 +301,18 @@
}
}
}
+
+ // test select with a typed condition
+ result = selector.binaryExpr(selector.constant(0.5f), TypedGTOp()).select(mat1, mat2);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(result(i, j, k), (selector(i, j, k) > 0.5f) ? mat1(i, j, k) : mat2(i, j, k));
+ }
+ }
+ }
+
}
template <typename Scalar>