blob: ceb4ee33c62798e3a3da0b95e132a19b2e7a7e70 [file] [log] [blame]
#ifndef EIGEN_PACKET_MATH_GOOGLE_AVX_H
#define EIGEN_PACKET_MATH_GOOGLE_AVX_H
#include <stdio.h>
#include <x86intrin.h>
#include "Eigen/src/Core/util/Macros.h"
namespace Eigen {
namespace internal {
// One AVX register holding eight fp32 lanes.
typedef __m256 Packet8f;

// fp32 bit patterns used to emulate bfloat16 on fp32 lanes.
#define BFLOAT_MASK 0xFFFF0000        // sign, exponent and top 7 mantissa bits
#define BFLOAT_ROUND_MASK 0xFFFF8000  // BFLOAT_MASK plus the first dropped bit
#define BFLOAT_EXP_MASK 0xFF800000    // sign and exponent bits
#define BFLOAT_HALF_MASK 0x00008000   // highest bit below the bfloat16 lsb
#define BFLOAT_SIGN_BIT 0x80000000    // sign bit only

#ifdef CUSTOM_NUMERICS
// Parameters of the emulated custom floating-point format. Each may be
// overridden by the build; the defaults (7 mantissa bits, 8 exponent bits,
// bias 2^(exp_bits - 1) - 1) match the bfloat16 layout.
#ifndef CUSTOM_NUMERICS_MANT_BITS
#define CUSTOM_NUMERICS_MANT_BITS 7
#endif
#ifndef CUSTOM_NUMERICS_EXP_BITS
#define CUSTOM_NUMERICS_EXP_BITS 8
#endif
#ifndef CUSTOM_NUMERICS_EXP_BIAS
#define CUSTOM_NUMERICS_EXP_BIAS ((1 << (CUSTOM_NUMERICS_EXP_BITS - 1)) - 1)
#endif
#endif  // CUSTOM_NUMERICS
// Quantizes every lane of `value` by truncating (rounding toward zero) its
// mantissa. The target format is fixed at compile time:
//  - HALF_PRECISION_BF16: keep the upper 16 bits (bfloat16 layout).
//  - CUSTOM_NUMERICS with an 8-bit, bias-127 exponent: keep
//    CUSTOM_NUMERICS_MANT_BITS mantissa bits; the exponent already fits.
//  - CUSTOM_NUMERICS otherwise: keep CUSTOM_NUMERICS_MANT_BITS mantissa bits,
//    then clamp the magnitude into [smallest normal, largest finite] of the
//    custom format. Lanes whose truncated magnitude is exactly zero become
//    +0 (so -0 also becomes +0); nonzero values below the smallest normal are
//    raised to it.
//  - none of the above: `value` is returned unchanged.
EIGEN_STRONG_INLINE Packet8f BFConvertTrunc(const Packet8f& value) {
#if defined(HALF_PRECISION_BF16)
  return _mm256_and_ps(
      value, _mm256_castsi256_ps(_mm256_set1_epi32(BFLOAT_MASK)));
#elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \
    (CUSTOM_NUMERICS_EXP_BIAS == 127)
  // Exponent range matches fp32, so only surplus mantissa bits are dropped.
  const int keep_bits = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  return _mm256_and_ps(
      value, _mm256_castsi256_ps(_mm256_set1_epi32(keep_bits)));
#elif defined(CUSTOM_NUMERICS)
  const int keep_bits = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  // fp32 bit patterns of the custom format's smallest normal magnitude and
  // its largest finite magnitude (the latter already mantissa-truncated).
  const int lo_bits = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23;
  const int hi_bits =
      (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS)
        << 23) | ((1 << 23) - 1)) & keep_bits;
  const __m256 zero = _mm256_setzero_ps();

  // Separate the sign from the truncated magnitude.
  const __m256 sign = _mm256_and_ps(
      value, _mm256_castsi256_ps(_mm256_set1_epi32(BFLOAT_SIGN_BIT)));
  const __m256 magnitude = _mm256_and_ps(
      value, _mm256_castsi256_ps(_mm256_set1_epi32(keep_bits & 0x7FFFFFFF)));
  // Lanes that must come out as zero rather than be clamped up.
  const __m256 is_zero = _mm256_cmp_ps(magnitude, zero, _CMP_EQ_OQ);

  // Clamp into the representable range, then restore the sign.
  const __m256 clamped = _mm256_min_ps(
      _mm256_max_ps(magnitude,
                    _mm256_castsi256_ps(_mm256_set1_epi32(lo_bits))),
      _mm256_castsi256_ps(_mm256_set1_epi32(hi_bits)));
  return _mm256_blendv_ps(_mm256_or_ps(clamped, sign), zero, is_zero);
#else
  return value;
#endif  // HALF_PRECISION_BF16
}
// Quantizes every lane of `value` to the configured reduced-precision format,
// rounding to nearest with ties away from zero: half of the target lsb is
// added to the magnitude, and the surplus mantissa bits are then truncated.
//  - HALF_PRECISION_BF16: round to the bfloat16 layout (upper 16 bits).
//  - CUSTOM_NUMERICS with an 8-bit, bias-127 exponent: round to
//    CUSTOM_NUMERICS_MANT_BITS mantissa bits.
//  - CUSTOM_NUMERICS otherwise: round to CUSTOM_NUMERICS_MANT_BITS mantissa
//    bits, then clamp the magnitude into [smallest normal, largest finite] of
//    the custom format. Lanes that round to exactly zero become +0; nonzero
//    values below the smallest normal are raised to it.
//  - none of the above: `value` is returned unchanged.
//
// The half-lsb value is never built directly for each lane's exponent.
// Instead, 2^e is taken by masking the input's exponent bits, and the code
// evaluates (x - 2^e) + (2^e with the half-lsb mantissa bit ORed in).
//
// NOTE(review): for +/-Inf inputs 2^e is itself Inf, so the subtraction
// produces NaN. In the general CUSTOM_NUMERICS branch that NaN then goes
// through the max/min clamp. Confirm whether Inf/NaN need explicit handling.
EIGEN_STRONG_INLINE Packet8f BFConvertRound(const Packet8f& value) {
#if defined(HALF_PRECISION_BF16)
  const __m256 bfloat_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_MASK));
  const __m256 bfloat_round_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_ROUND_MASK));
  const __m256 bfloat_exp_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_EXP_MASK));
  const __m256 bfloat_half_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_HALF_MASK));
  // Keep the bfloat16 bits plus the first dropped bit (the rounding bit).
  __m256 result = _mm256_and_ps(bfloat_round_mask, value);
  // +/-2^e: sign and exponent of the input, so the step is sign-symmetric.
  __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled,
                                             bfloat_half_mask);
  // subtract 1 (scaled)
  result = _mm256_sub_ps(result, one_scaled);
  // add 1 + half the lsb (scaled)
  result = _mm256_add_ps(result, one_plus_half_scaled);
  // truncate after adding half lsb
  result = _mm256_and_ps(bfloat_mask, result);
  return result;
#elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \
    (CUSTOM_NUMERICS_EXP_BIAS == 127)
  // Same scheme as the bfloat16 branch, with the mantissa width taken from
  // CUSTOM_NUMERICS_MANT_BITS.
  const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  const int mask_scalar_r_ = 0xFFFFFFFF << (22 - CUSTOM_NUMERICS_MANT_BITS);
  const int mask_scalar_half_ = 0x1 << (22 - CUSTOM_NUMERICS_MANT_BITS);
  const __m256 bfloat_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_));
  const __m256 bfloat_round_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_r_));
  const __m256 bfloat_exp_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_EXP_MASK));
  const __m256 bfloat_half_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_half_));
  __m256 result = _mm256_and_ps(bfloat_round_mask, value);
  // +/-2^e: sign and exponent of the input.
  __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled,
                                             bfloat_half_mask);
  // subtract 1 (scaled)
  result = _mm256_sub_ps(result, one_scaled);
  // add 1 + half the lsb (scaled)
  result = _mm256_add_ps(result, one_plus_half_scaled);
  // truncate after adding half lsb
  result = _mm256_and_ps(bfloat_mask, result);
  return result;
#elif defined(CUSTOM_NUMERICS)
  const __m256 sign_bit_simd_ = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_SIGN_BIT));
  // Mantissa mask, built in unsigned arithmetic so deriving the half-lsb bit
  // below needs no right shift of a negative int.
  const unsigned mask_u_ = 0xFFFFFFFFu << (23 - CUSTOM_NUMERICS_MANT_BITS);
  const int mask_scalar_ = static_cast<int>(mask_u_);
  const int abs_mask_scalar_ = mask_scalar_ & 0x7FFFFFFF;
  // Highest bit below the kept mantissa (zero when all 23 bits are kept).
  const int half_bit_scalar_ =
      static_cast<int>((mask_u_ ^ (mask_u_ >> 1)) & 0x7FFFFFFFu);
  // Magnitude mask that also keeps the half-lsb bit. It must exclude the
  // sign bit: one_scaled below is always positive, so a negative abs_value
  // would be pulled toward zero instead of rounded (e.g. with one mantissa
  // bit, -1.0 would come out as -0.75).
  const int abs_mask_plus_half_scalar_ = abs_mask_scalar_ | half_bit_scalar_;
  const __m256 abs_mask_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(abs_mask_scalar_));
  const __m256 abs_mask_plus_half_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(abs_mask_plus_half_scalar_));
  const __m256 half_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(half_bit_scalar_));
  const __m256 exp_mask_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(0x7F800000));
  const __m256 zero_simd_ = _mm256_setzero_ps();
  // fp32 bit patterns of the custom format's smallest normal magnitude and
  // its largest finite magnitude (the latter already mantissa-truncated).
  const int min_float_i = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23;
  const int max_float_i =
      (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS)
        << 23) | ((1 << 23) - 1)) & mask_scalar_;
  const __m256 min_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(min_float_i));
  const __m256 max_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(max_float_i));
  // Extract the sign; the rounding below works on the magnitude only.
  __m256 sign = _mm256_and_ps(sign_bit_simd_, value);
  __m256 abs_value = _mm256_and_ps(abs_mask_plus_half_simd_, value);
  // 2^e of each lane's magnitude (exponent bits only, always positive).
  __m256 one_scaled = _mm256_and_ps(exp_mask_simd_, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled, half_simd_);
  // subtract 1 (scaled)
  abs_value = _mm256_sub_ps(abs_value, one_scaled);
  // add 1 + half the lsb (scaled)
  abs_value = _mm256_add_ps(abs_value, one_plus_half_scaled);
  // truncate after adding half lsb
  abs_value = _mm256_and_ps(abs_mask_simd_, abs_value);
  // Lanes that rounded to zero must stay zero instead of being clamped up.
  __m256 mask = _mm256_cmp_ps(abs_value, zero_simd_, _CMP_EQ_OQ);
  // Clamp into the custom format's representable range.
  __m256 result = _mm256_max_ps(abs_value, min_simd_);
  result = _mm256_min_ps(result, max_simd_);
  // apply the sign bit
  result = _mm256_or_ps(result, sign);
  // blend back to SIMD float with proper 0 value
  result = _mm256_blendv_ps(result, zero_simd_, mask);
  return result;
#else
  return value;
#endif  // HALF_PRECISION_BF16
}
// Quantizes `value` to the configured reduced-precision format. It uses
// BFConvertRound when HALF_PRECISION_ROUND or CUSTOM_NUMERICS_ROUND is
// defined, and BFConvertTrunc otherwise.
EIGEN_STRONG_INLINE Packet8f BFConvert(const Packet8f& value) {
#if !defined(HALF_PRECISION_ROUND) && !defined(CUSTOM_NUMERICS_ROUND)
  return BFConvertTrunc(value);
#else
  return BFConvertRound(value);
#endif  // !HALF_PRECISION_ROUND && !CUSTOM_NUMERICS_ROUND
}
// Lane-wise a * b with both operands first quantized by BFConvert; the
// multiply itself is done in fp32. With HALF_PRECISION_OUT16 or
// CUSTOM_NUMERICS_OUT the product is quantized by BFConvert as well.
EIGEN_STRONG_INLINE Packet8f pmul_half_bf16(
    const Packet8f& a, const Packet8f& b) {
  const __m256 product = _mm256_mul_ps(BFConvert(a), BFConvert(b));
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return BFConvert(product);
#else
  return product;
#endif  // HALF_PRECISION_OUT16 || CUSTOM_NUMERICS_OUT
}
// Lane-wise a * b with IEEE half-precision emulation when HALF_PRECISION_FP16
// is defined. Both operands are converted to fp16 and back
// (_mm256_cvtps_ph / _mm256_cvtph_ps) and then multiplied in fp32. With
// HALF_PRECISION_OUT16 or CUSTOM_NUMERICS_OUT the product is passed through
// fp16 too. The fp16 conversions round to nearest-even when
// HALF_PRECISION_ROUND or CUSTOM_NUMERICS_ROUND is defined, and toward zero
// otherwise. Without HALF_PRECISION_FP16 this is a plain fp32 multiply.
EIGEN_STRONG_INLINE Packet8f pmul_half_fp16(
    const Packet8f& a, const Packet8f& b) {
#if !defined(HALF_PRECISION_FP16)
  return _mm256_mul_ps(a, b);
#else
  // Rounding immediate for _mm256_cvtps_ph. It must be an integer constant
  // expression, hence an enum. RC bits 0 mean round to nearest even.
#if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND)
  enum { kToHalfRounding = _MM_FROUND_NO_EXC };
#else
  enum { kToHalfRounding = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC };
#endif  // HALF_PRECISION_ROUND || CUSTOM_NUMERICS_ROUND
  const __m256 lhs = _mm256_cvtph_ps(_mm256_cvtps_ph(a, kToHalfRounding));
  const __m256 rhs = _mm256_cvtph_ps(_mm256_cvtps_ph(b, kToHalfRounding));
  const __m256 product = _mm256_mul_ps(lhs, rhs);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return _mm256_cvtph_ps(_mm256_cvtps_ph(product, kToHalfRounding));
#else
  return product;
#endif  // HALF_PRECISION_OUT16 || CUSTOM_NUMERICS_OUT
#endif  // !HALF_PRECISION_FP16
}
// Reduced-precision multiply dispatcher. It uses the BFConvert-based path
// when HALF_PRECISION_BF16 or CUSTOM_NUMERICS is defined, and the fp16 path
// otherwise. The fp16 path is itself a plain fp32 multiply unless
// HALF_PRECISION_FP16 is defined.
EIGEN_STRONG_INLINE Packet8f pmul_custom(const Packet8f& a, const Packet8f& b) {
#if !defined(HALF_PRECISION_BF16) && !defined(CUSTOM_NUMERICS)
  return pmul_half_fp16(a, b);
#else
  return pmul_half_bf16(a, b);
#endif  // !HALF_PRECISION_BF16 && !CUSTOM_NUMERICS
}
// Lane-wise fused a * b + c. The multiplicands are first quantized by
// BFConvert; the accumulator c is used as given. With HALF_PRECISION_OUT16
// or CUSTOM_NUMERICS_OUT the FMA result is quantized by BFConvert as well.
EIGEN_STRONG_INLINE Packet8f pmadd_half_bf16(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
  const __m256 fused = _mm256_fmadd_ps(BFConvert(a), BFConvert(b), c);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return BFConvert(fused);
#else
  return fused;
#endif  // HALF_PRECISION_OUT16 || CUSTOM_NUMERICS_OUT
}
// Lane-wise fused a * b + c with IEEE half-precision emulation when
// HALF_PRECISION_FP16 is defined. a and b are converted to fp16 and back
// before the fp32 FMA; the accumulator c is used as given. With
// HALF_PRECISION_OUT16 or CUSTOM_NUMERICS_OUT the FMA result is passed
// through fp16 too. The fp16 conversions round to nearest-even when
// HALF_PRECISION_ROUND or CUSTOM_NUMERICS_ROUND is defined, and toward zero
// otherwise. Without HALF_PRECISION_FP16 this is a plain fp32 FMA.
EIGEN_STRONG_INLINE Packet8f pmadd_half_fp16(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if !defined(HALF_PRECISION_FP16)
  return _mm256_fmadd_ps(a, b, c);
#else
  // Rounding immediate for _mm256_cvtps_ph. It must be an integer constant
  // expression, hence an enum. RC bits 0 mean round to nearest even.
#if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND)
  enum { kToHalfRounding = _MM_FROUND_NO_EXC };
#else
  enum { kToHalfRounding = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC };
#endif  // HALF_PRECISION_ROUND || CUSTOM_NUMERICS_ROUND
  const __m256 lhs = _mm256_cvtph_ps(_mm256_cvtps_ph(a, kToHalfRounding));
  const __m256 rhs = _mm256_cvtph_ps(_mm256_cvtps_ph(b, kToHalfRounding));
  const __m256 fused = _mm256_fmadd_ps(lhs, rhs, c);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return _mm256_cvtph_ps(_mm256_cvtps_ph(fused, kToHalfRounding));
#else
  return fused;
#endif  // HALF_PRECISION_OUT16 || CUSTOM_NUMERICS_OUT
#endif  // !HALF_PRECISION_FP16
}
// Reduced-precision fused multiply-add dispatcher. It uses the
// BFConvert-based path when HALF_PRECISION_BF16 or CUSTOM_NUMERICS is
// defined, and the fp16 path otherwise. The fp16 path is itself a plain fp32
// FMA unless HALF_PRECISION_FP16 is defined.
EIGEN_STRONG_INLINE Packet8f pmadd_custom(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if !defined(HALF_PRECISION_BF16) && !defined(CUSTOM_NUMERICS)
  return pmadd_half_fp16(a, b, c);
#else
  return pmadd_half_bf16(a, b, c);
#endif  // !HALF_PRECISION_BF16 && !CUSTOM_NUMERICS
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_PACKET_MATH_GOOGLE_AVX_H