| #ifndef EIGEN_PACKET_MATH_GOOGLE_AVX_H |
| #define EIGEN_PACKET_MATH_GOOGLE_AVX_H |
| |
| #include <stdio.h> |
| #include <x86intrin.h> |
| |
| #include "Eigen/src/Core/util/Macros.h" |
| |
| namespace Eigen { |
| |
| namespace internal { |
| |
// Packet type used by every helper in this header: one 256-bit AVX register
// viewed as eight single-precision lanes.
typedef __m256 Packet8f;

// Bit patterns applied to the raw 32-bit representation of each float lane.
#define BFLOAT_MASK 0xFFFF0000        // sign, exponent, top 7 mantissa bits
#define BFLOAT_ROUND_MASK 0xFFFF8000  // BFLOAT_MASK plus the next lower bit
#define BFLOAT_EXP_MASK 0xFF800000    // sign and the 8 exponent bits
#define BFLOAT_HALF_MASK 0x00008000   // the single bit just below BFLOAT_MASK
#define BFLOAT_SIGN_BIT 0x80000000    // sign bit only

// When CUSTOM_NUMERICS is enabled, any format parameter the build did not
// supply falls back to 7 mantissa bits, 8 exponent bits and a bias of
// 2^(EXP_BITS - 1) - 1.
#ifdef CUSTOM_NUMERICS
#ifndef CUSTOM_NUMERICS_MANT_BITS
#define CUSTOM_NUMERICS_MANT_BITS 7
#endif  // CUSTOM_NUMERICS_MANT_BITS
#ifndef CUSTOM_NUMERICS_EXP_BITS
#define CUSTOM_NUMERICS_EXP_BITS 8
#endif  // CUSTOM_NUMERICS_EXP_BITS
#ifndef CUSTOM_NUMERICS_EXP_BIAS
#define CUSTOM_NUMERICS_EXP_BIAS ((1 << (CUSTOM_NUMERICS_EXP_BITS - 1)) - 1)
#endif  // CUSTOM_NUMERICS_EXP_BIAS
#endif  // CUSTOM_NUMERICS
| |
| EIGEN_STRONG_INLINE Packet8f BFConvertTrunc(const Packet8f& value) { |
| #if defined(HALF_PRECISION_BF16) |
| const __m256 bfloat_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_MASK)); |
| __m256 result = _mm256_and_ps(bfloat_mask, value); |
| return result; |
| #elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \ |
| (CUSTOM_NUMERICS_EXP_BIAS == 127) |
| const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS); |
| const __m256 bfloat_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(mask_scalar_)); |
| __m256 result = _mm256_and_ps(bfloat_mask, value); |
| return result; |
| #elif defined(CUSTOM_NUMERICS) |
| const __m256 sign_bit_simd_ = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_SIGN_BIT)); |
| const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS); |
| const int abs_mask_scalar_ = mask_scalar_ & 0x7FFFFFFF; |
| const __m256 abs_mask_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(abs_mask_scalar_)); |
| const __m256 zero_simd_ = _mm256_castsi256_ps(_mm256_set1_epi32(0)); |
| const int min_float_i = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23; |
| const int max_float_i = |
| (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS) |
| << 23) | ((1 << 23) - 1)) & mask_scalar_; |
| const __m256 min_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(min_float_i)); |
| const __m256 max_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(max_float_i)); |
| // extract sign bit and get the abs value |
| __m256 sign = _mm256_and_ps(sign_bit_simd_, value); |
| __m256 abs_value = _mm256_and_ps(abs_mask_simd_, value); |
| // mask the ones equal to 0 for SIMD construction |
| __m256 mask = _mm256_cmp_ps(abs_value, zero_simd_, 0); |
| // check against exponent range |
| __m256 result = _mm256_max_ps(abs_value, min_simd_); |
| result = _mm256_min_ps(result, max_simd_); |
| // apply the sign bit |
| result = _mm256_or_ps(result, sign); |
| // blend back to SIMD float with proper 0 value |
| result = _mm256_blendv_ps(result, zero_simd_, mask); |
| return result; |
| #else |
| return value; |
| #endif // HALF_PRECISION_BF16 |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f BFConvertRound(const Packet8f& value) { |
| #if defined(HALF_PRECISION_BF16) |
| const __m256 bfloat_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_MASK)); |
| const __m256 bfloat_round_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_ROUND_MASK)); |
| const __m256 bfloat_exp_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_EXP_MASK)); |
| const __m256 bfloat_half_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_HALF_MASK)); |
| __m256 result = _mm256_and_ps(bfloat_round_mask, value); |
| __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value); |
| __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled, |
| bfloat_half_mask); |
| // subtract 1 (scaled) |
| result = _mm256_sub_ps(result, one_scaled); |
| // add 1 + half the lsb (scaled) |
| result = _mm256_add_ps(result, one_plus_half_scaled); |
| // truncate after adding half lsb |
| result = _mm256_and_ps(bfloat_mask, result); |
| return result; |
| #elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \ |
| (CUSTOM_NUMERICS_EXP_BIAS == 127) |
| const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS); |
| const int mask_scalar_r_ = 0xFFFFFFFF << (22 - CUSTOM_NUMERICS_MANT_BITS); |
| const int mask_scalar_half_ = 0x1 << (22 - CUSTOM_NUMERICS_MANT_BITS); |
| const __m256 bfloat_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(mask_scalar_)); |
| const __m256 bfloat_round_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(mask_scalar_r_)); |
| const __m256 bfloat_exp_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_EXP_MASK)); |
| const __m256 bfloat_half_mask = _mm256_castsi256_ps( |
| _mm256_set1_epi32(mask_scalar_half_)); |
| __m256 result = _mm256_and_ps(bfloat_round_mask, value); |
| __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value); |
| __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled, |
| bfloat_half_mask); |
| // subtract 1 (scaled) |
| result = _mm256_sub_ps(result, one_scaled); |
| // add 1 + half the lsb (scaled) |
| result = _mm256_add_ps(result, one_plus_half_scaled); |
| // truncate after adding half lsb |
| result = _mm256_and_ps(bfloat_mask, result); |
| return result; |
| #elif defined(CUSTOM_NUMERICS) |
| const __m256 sign_bit_simd_ = _mm256_castsi256_ps( |
| _mm256_set1_epi32(BFLOAT_SIGN_BIT)); |
| const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS); |
| const int abs_mask_scalar_ = mask_scalar_ & 0x7FFFFFFF; |
| const __m256 abs_mask_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(abs_mask_scalar_)); |
| const int half_bit_scalar_ = |
| (mask_scalar_ ^ (mask_scalar_ >> 1)) & 0x7FFFFFFF; |
| const int mask_plus_half_scalar_ = mask_scalar_ | half_bit_scalar_; |
| const __m256 abs_mask_plus_half_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(mask_plus_half_scalar_)); |
| const __m256 half_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(half_bit_scalar_)); |
| const __m256 exp_mask_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(0x7F800000)); |
| const __m256 zero_simd_ = _mm256_castsi256_ps(_mm256_set1_epi32(0)); |
| const int min_float_i = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23; |
| const int max_float_i = |
| (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS) |
| << 23) | ((1 << 23) - 1)) & mask_scalar_; |
| const __m256 min_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(min_float_i)); |
| const __m256 max_simd_ = |
| _mm256_castsi256_ps(_mm256_set1_epi32(max_float_i)); |
| // extract sign bit and get the abs value |
| __m256 sign = _mm256_and_ps(sign_bit_simd_, value); |
| __m256 abs_value = _mm256_and_ps(abs_mask_plus_half_simd_, value); |
| __m256 one_scaled = _mm256_and_ps(exp_mask_simd_, value); |
| __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled, half_simd_); |
| // subtract 1 (scaled) |
| abs_value = _mm256_sub_ps(abs_value, one_scaled); |
| // add 1 + half the lsb (scaled) |
| abs_value = _mm256_add_ps(abs_value, one_plus_half_scaled); |
| // truncate after adding half lsb |
| abs_value = _mm256_and_ps(abs_mask_simd_, abs_value); |
| // mask the ones equal to 0 for SIMD construction |
| __m256 mask = _mm256_cmp_ps(abs_value, zero_simd_, 0); |
| // check against exponent range |
| __m256 result = _mm256_max_ps(abs_value, min_simd_); |
| result = _mm256_min_ps(result, max_simd_); |
| // apply the sign bit |
| result = _mm256_or_ps(result, sign); |
| // blend back to SIMD float with proper 0 value |
| result = _mm256_blendv_ps(result, zero_simd_, mask); |
| return result; |
| #else |
| return value; |
| #endif // HALF_PRECISION_BF16 |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f BFConvert(const Packet8f& value) { |
| #if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND) |
| return BFConvertRound(value); |
| #else |
| return BFConvertTrunc(value); |
| #endif // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND) |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmul_half_bf16( |
| const Packet8f& a, const Packet8f& b) { |
| __m256 a_bf16 = BFConvert(a); |
| __m256 b_bf16 = BFConvert(b); |
| #if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT) |
| return BFConvert(_mm256_mul_ps(a_bf16, b_bf16)); |
| #else |
| return _mm256_mul_ps(a_bf16, b_bf16); |
| #endif // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT) |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmul_half_fp16( |
| const Packet8f& a, const Packet8f& b) { |
| #if defined(HALF_PRECISION_FP16) |
| #if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND) |
| // default round to nearest even is _MM_FROUND_TO_NEAREST_INT = 0x0 |
| __m128i a_fp16 = _mm256_cvtps_ph(a, _MM_FROUND_NO_EXC); |
| __m128i b_fp16 = _mm256_cvtps_ph(b, _MM_FROUND_NO_EXC); |
| #else |
| __m128i a_fp16 = _mm256_cvtps_ph(a, _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| __m128i b_fp16 = _mm256_cvtps_ph(b, _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| #endif // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND) |
| __m256 tmp = _mm256_mul_ps(_mm256_cvtph_ps(a_fp16), _mm256_cvtph_ps(b_fp16)); |
| #if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT) |
| #if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND) |
| __m128i tmp_fp16 = _mm256_cvtps_ph(tmp, _MM_FROUND_NO_EXC); |
| #else |
| __m128i tmp_fp16 = _mm256_cvtps_ph(tmp, |
| _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| #endif // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND) |
| return _mm256_cvtph_ps(tmp_fp16); |
| #else |
| return tmp; |
| #endif // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT) |
| #else |
| return _mm256_mul_ps(a, b); |
| #endif // HALF_PRECISION_FP16 |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmul_custom(const Packet8f& a, const Packet8f& b) { |
| #if defined(HALF_PRECISION_BF16) || defined(CUSTOM_NUMERICS) |
| return pmul_half_bf16(a, b); |
| #else |
| return pmul_half_fp16(a, b); |
| #endif // HALF_PRECISION_BF16 || defined(CUSTOM_NUMERICS) |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmadd_half_bf16( |
| const Packet8f& a, const Packet8f& b, const Packet8f& c) { |
| __m256 a_bf16 = BFConvert(a); |
| __m256 b_bf16 = BFConvert(b); |
| #if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT) |
| return BFConvert(_mm256_fmadd_ps(a_bf16, b_bf16, c)); |
| #else |
| return _mm256_fmadd_ps(a_bf16, b_bf16, c); |
| #endif // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT) |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmadd_half_fp16( |
| const Packet8f& a, const Packet8f& b, const Packet8f& c) { |
| #if defined(HALF_PRECISION_FP16) |
| #if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND) |
| // default round to nearest even is _MM_FROUND_TO_NEAREST_INT = 0x0 |
| __m128i a_fp16 = _mm256_cvtps_ph(a, _MM_FROUND_NO_EXC); |
| __m128i b_fp16 = _mm256_cvtps_ph(b, _MM_FROUND_NO_EXC); |
| #else |
| __m128i a_fp16 = _mm256_cvtps_ph(a, _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| __m128i b_fp16 = _mm256_cvtps_ph(b, _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| #endif // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND) |
| __m256 tmp = _mm256_fmadd_ps(_mm256_cvtph_ps(a_fp16), |
| _mm256_cvtph_ps(b_fp16), |
| c); |
| #if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT) |
| #if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND) |
| __m128i tmp_fp16 = _mm256_cvtps_ph(tmp, _MM_FROUND_NO_EXC); |
| #else |
| __m128i tmp_fp16 = _mm256_cvtps_ph(tmp, |
| _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); |
| #endif // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND) |
| return _mm256_cvtph_ps(tmp_fp16); |
| #else |
| return tmp; |
| #endif // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT) |
| #else |
| return _mm256_fmadd_ps(a, b, c); |
| #endif // HALF_PRECISION_FP16 |
| } |
| |
| EIGEN_STRONG_INLINE Packet8f pmadd_custom( |
| const Packet8f& a, const Packet8f& b, const Packet8f& c) { |
| #if defined(HALF_PRECISION_BF16) || defined(CUSTOM_NUMERICS) |
| return pmadd_half_bf16(a, b, c); |
| #else |
| return pmadd_half_fp16(a, b, c); |
| #endif // HALF_PRECISION_BF16 || defined(CUSTOM_NUMERICS) |
| } |
| |
| } // end namespace internal |
| |
| } // end namespace Eigen |
| |
| #endif // EIGEN_PACKET_MATH_GOOGLE_AVX_H |