#ifndef EIGEN_PACKET_MATH_GOOGLE_AVX_H
#define EIGEN_PACKET_MATH_GOOGLE_AVX_H

#include <stdio.h>
#include <x86intrin.h>

#include "Eigen/src/Core/util/Macros.h"

namespace Eigen {

namespace internal {

// Packet of eight single-precision floats held in one 256-bit AVX register.
typedef __m256  Packet8f;

// Per-lane 32-bit masks applied to the raw bits of each float lane.
// Bits 31..16: the part of the word kept by the bfloat16 truncation.
#define BFLOAT_MASK 0xFFFF0000
// Bits 31..15: the kept bits plus the one bit directly below them, which
// BFConvertRound uses as the rounding bit.
#define BFLOAT_ROUND_MASK 0xFFFF8000
// Bits 31..23: sign and exponent field of an IEEE-754 single.
#define BFLOAT_EXP_MASK 0xFF800000
// Bit 15 only: half of the least significant kept bit.
#define BFLOAT_HALF_MASK 0x00008000
// Bit 31 only: the sign bit.
#define BFLOAT_SIGN_BIT 0x80000000

// Defaults for the emulated custom float format. They are only provided when
// CUSTOM_NUMERICS is defined, and each one can be overridden by defining the
// macro before this header is included.
#if defined(CUSTOM_NUMERICS)
// Number of mantissa bits kept (default 7).
#if !defined(CUSTOM_NUMERICS_MANT_BITS)
#define CUSTOM_NUMERICS_MANT_BITS 7
#endif  // CUSTOM_NUMERICS_MANT_BITS
// Width of the exponent field (default 8).
#if !defined(CUSTOM_NUMERICS_EXP_BITS)
#define CUSTOM_NUMERICS_EXP_BITS 8
#endif  // CUSTOM_NUMERICS_EXP_BITS
// Exponent bias; defaults to 2^(EXP_BITS - 1) - 1, i.e. 127 for 8 bits.
#if !defined(CUSTOM_NUMERICS_EXP_BIAS)
#define CUSTOM_NUMERICS_EXP_BIAS ((1 << (CUSTOM_NUMERICS_EXP_BITS - 1)) - 1)
#endif  // CUSTOM_NUMERICS_EXP_BIAS
#endif  // CUSTOM_NUMERICS

// Reduces the precision of every lane of `value` by truncation (low mantissa
// bits are simply cleared). The variant is chosen at compile time:
//   * HALF_PRECISION_BF16: keep the upper 16 bits of each lane.
//   * CUSTOM_NUMERICS with an 8-bit exponent and bias 127: keep the sign, the
//     exponent and CUSTOM_NUMERICS_MANT_BITS mantissa bits.
//   * any other CUSTOM_NUMERICS format: additionally clamp the magnitude into
//     the custom exponent range; lanes whose truncated magnitude compares
//     equal to zero come back as +0.0.
//   * otherwise: `value` is returned untouched.
EIGEN_STRONG_INLINE Packet8f BFConvertTrunc(const Packet8f& value) {
#if defined(HALF_PRECISION_BF16)
  return _mm256_and_ps(_mm256_castsi256_ps(_mm256_set1_epi32(BFLOAT_MASK)),
                       value);
#elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \
    (CUSTOM_NUMERICS_EXP_BIAS == 127)
  // Exponent layout matches float, so dropping mantissa bits is sufficient.
  const int keep_bits = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  return _mm256_and_ps(_mm256_castsi256_ps(_mm256_set1_epi32(keep_bits)),
                       value);
#elif defined(CUSTOM_NUMERICS)
  // Scalar bit patterns first, vectors afterwards.
  const int keep_bits = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  const int magnitude_bits = keep_bits & 0x7FFFFFFF;
  // Float encoding of the lower clamp bound (biased exponent 128 - bias,
  // zero mantissa).
  const int lower_bound_bits = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23;
  // Float encoding of the upper clamp bound: top usable exponent with every
  // kept mantissa bit set.
  const int upper_bound_bits =
      (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS)
        << 23) | ((1 << 23) - 1)) & keep_bits;

  const __m256 zero = _mm256_setzero_ps();
  const __m256 lower_bound =
      _mm256_castsi256_ps(_mm256_set1_epi32(lower_bound_bits));
  const __m256 upper_bound =
      _mm256_castsi256_ps(_mm256_set1_epi32(upper_bound_bits));

  // Split each lane into its sign bit and its truncated magnitude.
  const __m256 sign_only = _mm256_and_ps(
      _mm256_castsi256_ps(_mm256_set1_epi32(BFLOAT_SIGN_BIT)), value);
  const __m256 magnitude = _mm256_and_ps(
      _mm256_castsi256_ps(_mm256_set1_epi32(magnitude_bits)), value);

  // Remember which lanes are zero before clamping pushes them up to the
  // lower bound.
  const __m256 is_zero = _mm256_cmp_ps(magnitude, zero, _CMP_EQ_OQ);

  // Clamp into [lower_bound, upper_bound], then put the sign back.
  __m256 clamped = _mm256_max_ps(magnitude, lower_bound);
  clamped = _mm256_min_ps(clamped, upper_bound);
  const __m256 signed_result = _mm256_or_ps(clamped, sign_only);

  // Zero lanes bypass the clamp result.
  return _mm256_blendv_ps(signed_result, zero, is_zero);
#else
  return value;
#endif  // HALF_PRECISION_BF16
}

// Reduces the precision of every lane of `value` with rounding instead of
// plain truncation. The variant is chosen at compile time, mirroring
// BFConvertTrunc:
//   * HALF_PRECISION_BF16: round to the upper 16 bits of each lane.
//   * CUSTOM_NUMERICS with an 8-bit exponent and bias 127: round to
//     CUSTOM_NUMERICS_MANT_BITS mantissa bits.
//   * any other CUSTOM_NUMERICS format: round the magnitude, clamp it into
//     the custom exponent range and re-apply the sign; lanes whose rounded
//     magnitude compares equal to zero come back as +0.0.
//   * otherwise: `value` is returned untouched.
//
// Rounding trick shared by all variants: keep the retained bits plus the one
// bit below them, subtract 2^e (the lane's own sign/exponent with a zero
// mantissa), add back 2^e with the "half lsb" bit set, then truncate. The net
// effect is adding half of the last kept bit before truncating.
//
// NOTE(review): for infinite lanes the subtract step evaluates inf - inf,
// which is NaN, so infinities are not preserved by this routine. Confirm
// whether callers can ever pass non-finite values.
EIGEN_STRONG_INLINE Packet8f BFConvertRound(const Packet8f& value) {
#if defined(HALF_PRECISION_BF16)
  const __m256 bfloat_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_MASK));
  const __m256 bfloat_round_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_ROUND_MASK));
  const __m256 bfloat_exp_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_EXP_MASK));
  const __m256 bfloat_half_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_HALF_MASK));
  // Kept bits plus the rounding bit; sign is retained here and in one_scaled
  // (BFLOAT_EXP_MASK covers the sign bit), so both operands agree in sign.
  __m256 result = _mm256_and_ps(bfloat_round_mask, value);
  __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled,
                                             bfloat_half_mask);
  // subtract 1 (scaled)
  result = _mm256_sub_ps(result, one_scaled);
  // add 1 + half the lsb (scaled)
  result = _mm256_add_ps(result, one_plus_half_scaled);
  // truncate after adding half lsb
  result = _mm256_and_ps(bfloat_mask, result);
  return result;
#elif defined(CUSTOM_NUMERICS) && (CUSTOM_NUMERICS_EXP_BITS == 8) && \
    (CUSTOM_NUMERICS_EXP_BIAS == 127)
  // Same scheme as the bfloat16 branch, with masks derived from the
  // configured mantissa width.
  const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  const int mask_scalar_r_ = 0xFFFFFFFF << (22 - CUSTOM_NUMERICS_MANT_BITS);
  const int mask_scalar_half_ = 0x1 << (22 - CUSTOM_NUMERICS_MANT_BITS);
  const __m256 bfloat_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_));
  const __m256 bfloat_round_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_r_));
  const __m256 bfloat_exp_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_EXP_MASK));
  const __m256 bfloat_half_mask = _mm256_castsi256_ps(
      _mm256_set1_epi32(mask_scalar_half_));
  __m256 result = _mm256_and_ps(bfloat_round_mask, value);
  __m256 one_scaled = _mm256_and_ps(bfloat_exp_mask, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled,
                                             bfloat_half_mask);
  // subtract 1 (scaled)
  result = _mm256_sub_ps(result, one_scaled);
  // add 1 + half the lsb (scaled)
  result = _mm256_add_ps(result, one_plus_half_scaled);
  // truncate after adding half lsb
  result = _mm256_and_ps(bfloat_mask, result);
  return result;
#elif defined(CUSTOM_NUMERICS)
  const __m256 sign_bit_simd_ = _mm256_castsi256_ps(
      _mm256_set1_epi32(BFLOAT_SIGN_BIT));
  const int mask_scalar_ = 0xFFFFFFFF << (23 - CUSTOM_NUMERICS_MANT_BITS);
  const int abs_mask_scalar_ = mask_scalar_ & 0x7FFFFFFF;
  const __m256 abs_mask_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(abs_mask_scalar_));
  // Single bit directly below the kept mantissa bits.
  const int half_bit_scalar_ =
      (mask_scalar_ ^ (mask_scalar_ >> 1)) & 0x7FFFFFFF;
  // Build the rounding mask from the sign-free mask. This branch works on
  // magnitudes (one_scaled below is masked with 0x7F800000 and carries no
  // sign), so the sign bit must not leak into abs_value; otherwise negative
  // lanes would compute -m*2^e - 2^e and round to the wrong magnitude.
  const int mask_plus_half_scalar_ = abs_mask_scalar_ | half_bit_scalar_;
  const __m256 abs_mask_plus_half_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(mask_plus_half_scalar_));
  const __m256 half_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(half_bit_scalar_));
  const __m256 exp_mask_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(0x7F800000));
  const __m256 zero_simd_ = _mm256_castsi256_ps(_mm256_set1_epi32(0));
  // Float encodings of the clamp bounds for the custom exponent range.
  const int min_float_i = (128 - CUSTOM_NUMERICS_EXP_BIAS) << 23;
  const int max_float_i =
      (((127 + (1 << CUSTOM_NUMERICS_EXP_BITS) - 2 - CUSTOM_NUMERICS_EXP_BIAS)
        << 23) | ((1 << 23) - 1)) & mask_scalar_;
  const __m256 min_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(min_float_i));
  const __m256 max_simd_ =
      _mm256_castsi256_ps(_mm256_set1_epi32(max_float_i));
  // extract sign bit and get the abs value (kept bits plus the rounding bit)
  __m256 sign = _mm256_and_ps(sign_bit_simd_, value);
  __m256 abs_value = _mm256_and_ps(abs_mask_plus_half_simd_, value);
  __m256 one_scaled = _mm256_and_ps(exp_mask_simd_, value);
  __m256 one_plus_half_scaled = _mm256_or_ps(one_scaled, half_simd_);
  // subtract 1 (scaled)
  abs_value = _mm256_sub_ps(abs_value, one_scaled);
  // add 1 + half the lsb (scaled)
  abs_value = _mm256_add_ps(abs_value, one_plus_half_scaled);
  // truncate after adding half lsb
  abs_value = _mm256_and_ps(abs_mask_simd_, abs_value);
  // mask the ones equal to 0 for SIMD construction
  __m256 mask = _mm256_cmp_ps(abs_value, zero_simd_, _CMP_EQ_OQ);
  // check against exponent range
  __m256 result = _mm256_max_ps(abs_value, min_simd_);
  result = _mm256_min_ps(result, max_simd_);
  // apply the sign bit
  result = _mm256_or_ps(result, sign);
  // blend back to SIMD float with proper 0 value
  result = _mm256_blendv_ps(result, zero_simd_, mask);
  return result;
#else
  return value;
#endif  // HALF_PRECISION_BF16
}

// Precision-reduction entry point: forwards to BFConvertRound when either
// HALF_PRECISION_ROUND or CUSTOM_NUMERICS_ROUND is defined, and to
// BFConvertTrunc otherwise.
EIGEN_STRONG_INLINE Packet8f BFConvert(const Packet8f& value) {
#if !defined(HALF_PRECISION_ROUND) && !defined(CUSTOM_NUMERICS_ROUND)
  return BFConvertTrunc(value);
#else
  return BFConvertRound(value);
#endif  // !HALF_PRECISION_ROUND && !CUSTOM_NUMERICS_ROUND
}

// Lane-wise product of `a` and `b` after both operands have been passed
// through BFConvert. When HALF_PRECISION_OUT16 or CUSTOM_NUMERICS_OUT is
// defined the product itself is passed through BFConvert as well.
EIGEN_STRONG_INLINE Packet8f pmul_half_bf16(
    const Packet8f& a, const Packet8f& b) {
  const __m256 lhs = BFConvert(a);
  const __m256 rhs = BFConvert(b);
  const __m256 product = _mm256_mul_ps(lhs, rhs);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return BFConvert(product);
#else
  return product;
#endif  // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT)
}

// Lane-wise product of `a` and `b` with optional fp16 emulation.
// Without HALF_PRECISION_FP16 this is a plain single-precision multiply.
// With it, both operands are converted to half precision and back before the
// multiply; when HALF_PRECISION_OUT16 or CUSTOM_NUMERICS_OUT is defined the
// product makes the same round trip.
EIGEN_STRONG_INLINE Packet8f pmul_half_fp16(
    const Packet8f& a, const Packet8f& b) {
#if !defined(HALF_PRECISION_FP16)
  return _mm256_mul_ps(a, b);
#else
  // Rounding immediate for every float -> half conversion in this function.
  // With only _MM_FROUND_NO_EXC set the rounding field is 0x0, i.e. the
  // default round-to-nearest-even (_MM_FROUND_TO_NEAREST_INT).
#if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND)
  enum { kToHalfRounding = _MM_FROUND_NO_EXC };
#else
  enum { kToHalfRounding = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC };
#endif  // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND)
  const __m128i lhs_half = _mm256_cvtps_ph(a, kToHalfRounding);
  const __m128i rhs_half = _mm256_cvtps_ph(b, kToHalfRounding);
  const __m256 product =
      _mm256_mul_ps(_mm256_cvtph_ps(lhs_half), _mm256_cvtph_ps(rhs_half));
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  const __m128i product_half = _mm256_cvtps_ph(product, kToHalfRounding);
  return _mm256_cvtph_ps(product_half);
#else
  return product;
#endif  // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT)
#endif  // !HALF_PRECISION_FP16
}

// Multiply dispatcher: uses pmul_half_bf16 when HALF_PRECISION_BF16 or
// CUSTOM_NUMERICS is defined, and pmul_half_fp16 in every other
// configuration.
EIGEN_STRONG_INLINE Packet8f pmul_custom(const Packet8f& a, const Packet8f& b) {
#if !defined(HALF_PRECISION_BF16) && !defined(CUSTOM_NUMERICS)
  return pmul_half_fp16(a, b);
#else
  return pmul_half_bf16(a, b);
#endif  // !HALF_PRECISION_BF16 && !CUSTOM_NUMERICS
}

// Fused multiply-add a * b + c where `a` and `b` are first passed through
// BFConvert; the addend `c` is used as given. When HALF_PRECISION_OUT16 or
// CUSTOM_NUMERICS_OUT is defined the fused result is passed through
// BFConvert too.
EIGEN_STRONG_INLINE Packet8f pmadd_half_bf16(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
  const __m256 lhs = BFConvert(a);
  const __m256 rhs = BFConvert(b);
  const __m256 fused = _mm256_fmadd_ps(lhs, rhs, c);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  return BFConvert(fused);
#else
  return fused;
#endif  // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT)
}

// Fused multiply-add a * b + c with optional fp16 emulation.
// Without HALF_PRECISION_FP16 this is a plain single-precision FMA.
// With it, `a` and `b` are converted to half precision and back before the
// FMA (the addend `c` is used as given); when HALF_PRECISION_OUT16 or
// CUSTOM_NUMERICS_OUT is defined the fused result makes the same round trip.
EIGEN_STRONG_INLINE Packet8f pmadd_half_fp16(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if !defined(HALF_PRECISION_FP16)
  return _mm256_fmadd_ps(a, b, c);
#else
  // Rounding immediate for every float -> half conversion in this function.
  // With only _MM_FROUND_NO_EXC set the rounding field is 0x0, i.e. the
  // default round-to-nearest-even (_MM_FROUND_TO_NEAREST_INT).
#if defined(HALF_PRECISION_ROUND) || defined(CUSTOM_NUMERICS_ROUND)
  enum { kToHalfRounding = _MM_FROUND_NO_EXC };
#else
  enum { kToHalfRounding = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC };
#endif  // HALF_PRECISION_ROUND || defined(CUSTOM_NUMERICS_ROUND)
  const __m128i lhs_half = _mm256_cvtps_ph(a, kToHalfRounding);
  const __m128i rhs_half = _mm256_cvtps_ph(b, kToHalfRounding);
  const __m256 fused = _mm256_fmadd_ps(_mm256_cvtph_ps(lhs_half),
                                       _mm256_cvtph_ps(rhs_half),
                                       c);
#if defined(HALF_PRECISION_OUT16) || defined(CUSTOM_NUMERICS_OUT)
  const __m128i fused_half = _mm256_cvtps_ph(fused, kToHalfRounding);
  return _mm256_cvtph_ps(fused_half);
#else
  return fused;
#endif  // HALF_PRECISION_OUT16 || defined(CUSTOM_NUMERICS_OUT)
#endif  // !HALF_PRECISION_FP16
}

// Multiply-add dispatcher: uses pmadd_half_bf16 when HALF_PRECISION_BF16 or
// CUSTOM_NUMERICS is defined, and pmadd_half_fp16 in every other
// configuration.
EIGEN_STRONG_INLINE Packet8f pmadd_custom(
    const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if !defined(HALF_PRECISION_BF16) && !defined(CUSTOM_NUMERICS)
  return pmadd_half_fp16(a, b, c);
#else
  return pmadd_half_bf16(a, b, c);
#endif  // !HALF_PRECISION_BF16 && !CUSTOM_NUMERICS
}

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_PACKET_MATH_GOOGLE_AVX_H
