// blob: 978818ecd2015521cf98970ecdc9bc7c8e032d08 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2018-2025 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FREXP_LDEXP_H
#define EIGEN_ARCH_GENERIC_PACKET_MATH_FREXP_LDEXP_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
// Type trait mapping a floating-point Scalar to the signed integer type that
// occupies the same number of bits:
//   float            -> numext::int32_t
//   double           -> numext::int64_t
//   half, bfloat16   -> numext::int16_t
// The primary template is only declared; the trait is usable solely through
// the specializations listed below.
template <typename T>
struct make_integer;

// 32-bit floating point.
template <>
struct make_integer<float> {
  using type = numext::int32_t;
};

// 64-bit floating point.
template <>
struct make_integer<double> {
  using type = numext::int64_t;
};

// 16-bit floating point (half).
template <>
struct make_integer<half> {
  using type = numext::int16_t;
};

// 16-bit floating point (bfloat16).
template <>
struct make_integer<bfloat16> {
  using type = numext::int16_t;
};
// Extracts the raw (still biased) exponent field of every coefficient of `a`
// and returns it converted to the floating-point packet type.
// The exponent field is obtained by reinterpreting |a| as integers and
// shifting the explicit mantissa bits out to the right.
template <typename Packet>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
  using Scalar = typename unpacket_traits<Packet>::type;
  using PacketI = typename unpacket_traits<Packet>::integer_packet;
  // Number of explicitly stored mantissa bits (digits counts the implicit one too).
  static constexpr int kStoredMantissaBits = numext::numeric_limits<Scalar>::digits - 1;

  // Work on |a| so that no sign bit is left above the exponent after the shift.
  const PacketI magnitude_bits = preinterpret<PacketI>(pabs(a));
  const PacketI biased_exponent = plogical_shift_right<kStoredMantissaBits>(magnitude_bits);
  return pcast<PacketI, Packet>(biased_exponent);
}
// Generic packet implementation of frexp that also copes with denormals.
// Assumes IEEE floating point format.
//
// For every coefficient of `a` the function returns a mantissa m and stores a
// matching power-of-two exponent (as a floating-point value) in `exponent`.
// The mantissa is built by replacing the exponent field of the input with the
// exponent field of 0.5, so its magnitude lies in [0.5, 1) and it keeps the
// sign of the input.
//
// Zero, Inf and NaN are returned unmodified with an exponent of zero
// (technically the exponent is unspecified for Inf/NaN, but GCC/Clang set it
// to zero).
//
// The numeric values quoted in the comments below are those for float.
template <typename Packet>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pfrexp_generic(const Packet& a, Packet& exponent) {
  using Scalar = typename unpacket_traits<Packet>::type;
  using ScalarUI = typename make_unsigned<typename make_integer<Scalar>::type>::type;
  static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT;
  static constexpr int MantissaBits = numext::numeric_limits<Scalar>::digits - 1;
  static constexpr int ExponentBits = TotalBits - MantissaBits - 1;

  // --- Scalar constants derived from the bit layout -------------------------
  // All bits set except the exponent field: keeps sign and mantissa.
  constexpr ScalarUI sign_mantissa_bits =
      ~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits);  // ~0x7f800000
  // Power of two used to lift denormals into the normal range.
  constexpr ScalarUI denormal_shift = ScalarUI(MantissaBits + 1);  // 24
  // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
  const Scalar denormal_scale = Scalar(ScalarUI(1) << int(denormal_shift));  // 2^24
  // Value added to the biased exponent of a normal number to obtain the frexp exponent.
  const Scalar normal_exponent_offset = -Scalar((ScalarUI(1) << (ExponentBits - 1)) - ScalarUI(2));  // -126
  // Exponent field with all bits set, which marks Inf and NaN.
  const Scalar all_ones_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1));  // 255

  // --- Broadcast packet constants --------------------------------------------
  const Packet zero = pzero(a);
  const Packet one_half = pset1<Packet>(Scalar(0.5));
  const Packet keep_sign_mantissa = pset1frombits<Packet>(static_cast<ScalarUI>(sign_mantissa_bits));
  const Packet smallest_normal =
      pset1<Packet>((numext::numeric_limits<Scalar>::min)());  // Minimum normal value, 2^-126

  // --- Denormal handling -------------------------------------------------------
  // Inputs below the smallest normal magnitude are multiplied by
  // 2^(MantissaBits+1); the exponent offset is lowered by the same amount to
  // compensate.
  const Packet is_denormal = pcmp_lt(pabs(a), smallest_normal);
  const Packet scaled_a = pselect(is_denormal, pmul(a, pset1<Packet>(denormal_scale)), a);

  // Exponent offset: -126 if normal, -126-24 if denormal.
  const Packet offset_if_normal = pset1<Packet>(normal_exponent_offset);
  const Packet offset_if_denormal = padd(offset_if_normal, pset1<Packet>(-Scalar(denormal_shift)));
  const Packet exponent_offset = pselect(is_denormal, offset_if_denormal, offset_if_normal);

  // --- Split into mantissa and exponent ----------------------------------------
  // `exponent` temporarily holds the biased exponent of the rescaled input.
  exponent = pfrexp_generic_get_biased_exponent(scaled_a);

  // Coefficients that must be passed through untouched: zero, Inf and NaN.
  const Packet is_zero_or_not_finite =
      por(pcmp_eq(a, zero), pcmp_eq(exponent, pset1<Packet>(all_ones_exponent)));

  // Swap the exponent field for the one of 0.5 while keeping sign and mantissa.
  const Packet mantissa = por(pand(scaled_a, keep_sign_mantissa), one_half);
  const Packet result = pselect(is_zero_or_not_finite, a, mantissa);

  exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
  return result;
}
// Generic packet implementation of ldexp that handles overflows, underflows
// and denormals. Assumes IEEE floating point format.
template <typename Packet>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, const Packet& exponent) {
  // Goal: return a * 2^exponent for every possible integer exponent, without
  // any intermediate product overflowing or underflowing prematurely.
  //
  // Both 'a' and the result may be denormal, so for float the exponents that
  // matter span
  //   -255-23 -> 255+23
  // Below -278 every finite float 'a' ends up as zero and above +278 every
  // finite float ends up as inf, even when 'a' is the smallest possible
  // denormal.
  //
  // 2^(278) cannot be written as one finite normal float, nor as the product
  // of two, so the scale factor has to be split into at least three parts.
  // Splitting 'exponent' into four factors turns out to be faster, because
  // [exponent>>2] is much faster to compute than [exponent/3].
  //
  //   Set e = min(max(exponent, -278), 278);
  //       b = floor(e/4);
  //     out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
  //
  // This avoids intermediate overflows and handles the 0, inf and NaN cases
  // correctly.
  using Scalar = typename unpacket_traits<Packet>::type;
  using PacketI = typename unpacket_traits<Packet>::integer_packet;
  using ScalarI = typename unpacket_traits<PacketI>::type;
  static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT;
  static constexpr int MantissaBits = numext::numeric_limits<Scalar>::digits - 1;
  static constexpr int ExponentBits = TotalBits - MantissaBits - 1;

  // Clamp the requested exponent to the useful range and convert it to integers.
  const Packet exponent_limit =
      pset1<Packet>(Scalar((ScalarI(1) << ExponentBits) + ScalarI(MantissaBits - 1)));  // 278
  const Packet clamped_exponent = pmin(pmax(exponent, pnegate(exponent_limit)), exponent_limit);
  const PacketI e = pcast<Packet, PacketI>(clamped_exponent);

  // Exponent bias of the format, needed to assemble powers of two bit-wise.
  const PacketI bias = pset1<PacketI>((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1));  // 127

  // Split e into three equal parts and one remainder part.
  const PacketI quarter = parithmetic_shift_right<2>(e);            // floor(e/4)
  const PacketI remainder = pnmadd(pset1<PacketI>(3), quarter, e);  // e - 3*floor(e/4)

  // Build the powers of two by writing the biased exponents into the exponent field.
  const Packet quarter_scale =
      preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(quarter, bias)));  // 2^b
  const Packet remainder_scale =
      preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(remainder, bias)));  // 2^(e-3*b)

  // Apply the factors one after the other, in this exact order.
  Packet out = pmul(a, quarter_scale);
  out = pmul(out, quarter_scale);
  out = pmul(out, quarter_scale);  // a * 2^(3b)
  return pmul(out, remainder_scale);
}
// Fast ldexp: computes
//   a * (2^e)
// with a single multiplication, after clamping e to the range
//   [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
//
// This is approx 7x faster than pldexp_impl, but the result over/underflows
// prematurely whenever 2^e itself doesn't fit into a normal floating-point
// Scalar.
//
// Assumes IEEE floating point format
template <typename Packet>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent) {
  using Scalar = typename unpacket_traits<Packet>::type;
  using PacketI = typename unpacket_traits<Packet>::integer_packet;
  using ScalarI = typename unpacket_traits<PacketI>::type;
  static constexpr int TotalBits = sizeof(Scalar) * CHAR_BIT;
  static constexpr int MantissaBits = numext::numeric_limits<Scalar>::digits - 1;
  static constexpr int ExponentBits = TotalBits - MantissaBits - 1;

  // Bounds of the biased exponent field: 0 and 255 for float.
  const Packet highest_biased = pset1<Packet>(Scalar((ScalarI(1) << ExponentBits) - ScalarI(1)));  // 255
  const Packet lowest_biased = pzero(highest_biased);

  // Biased exponent (exponent + 127 for float), restricted to the field's range.
  const Packet exponent_bias = pset1<Packet>(Scalar((ScalarI(1) << (ExponentBits - 1)) - ScalarI(1)));  // 127
  const Packet biased = padd(exponent, exponent_bias);
  const Packet clamped = pmin(pmax(biased, lowest_biased), highest_biased);

  // Move the biased exponent into the exponent field to form the scale factor 2^e.
  const PacketI exponent_field = plogical_shift_left<MantissaBits>(pcast<Packet, PacketI>(clamped));
  const Packet scale = preinterpret<Packet>(exponent_field);

  // return a * (2^e)
  return pmul(a, scale);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FREXP_LDEXP_H