blob: 8ceaf967b61c6a8b675516ba73fe9233cd844c2d [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2018-2025 Rasmus Munk Larsen <rmlarsen@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H
#define EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
//----------------------------------------------------------------------
// Complex Arithmetic and Functions
//----------------------------------------------------------------------
// \internal \returns the element-wise complex quotient x / y.
//
// The lane annotations below describe a single complex pair
// x = a + i*b, y = c + i*d stored as the length-2 real vectors [a, b], [c, d].
// The divisor is replaced by a scaled copy y' = c' + i*d' whose larger
// component has magnitude one, and the quotient is formed as
//   x * conj(y') / (c*c' + d*d').
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) {
  using RealPacket = typename unpacket_traits<Packet>::as_real;
  using RealScalar = typename unpacket_traits<RealPacket>::type;

  const RealPacket magnitude = pabs(y.v);                               // |c|, |d|
  const RealPacket magnitude_swapped = pcplxflip(Packet(magnitude)).v;  // |d|, |c|

  // Per lane: this component is strictly smaller in magnitude than its partner.
  const RealPacket is_smaller = pcmp_lt(magnitude, magnitude_swapped);

  // The smaller component becomes the ratio small/large; the other one becomes 1.
  // If |c| == |d| both lanes become 1.
  const RealPacket ratio = pdiv(magnitude, magnitude_swapped);
  const RealPacket unit = pset1<RealPacket>(RealScalar(1));
  const RealPacket scaled_magnitude = pselect(is_smaller, ratio, unit);

  // Re-attach the signs of c and d to the (non-negative) scaled magnitudes.
  const RealPacket sign_bits = pandnot(y.v, magnitude);
  const RealPacket scaled_y = por(scaled_magnitude, sign_bits);  // c', d'

  // Denominator, replicated in both lanes: c * c' + d * d'.
  const RealPacket products = pmul(y.v, scaled_y);  // c * c', d * d'
  const RealPacket denominator = padd(products, pcplxflip(Packet(products)).v);

  // Numerator: a * c' + b * d', -a * d' + b * c'.
  const Packet numerator = pmul(x, pconj(Packet(scaled_y)));

  return Packet(pdiv(numerator.v, denominator));
}
// \internal \returns the element-wise complex product x * y.
//
// The lane annotations describe a single complex pair
// x = a + i*b, y = c + i*d stored as the length-2 real vectors [a, b], [c, d].
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pmul_complex(const Packet& x, const Packet& y) {
  const Packet real_bcast = pdupreal(x);  // a, a
  const Packet imag_bcast = pdupimag(x);  // b, b

  const Packet real_terms = Packet(pmul(real_bcast.v, y.v));  // a*c, a*d
  const Packet imag_terms = Packet(pmul(imag_bcast.v, y.v));  // b*c, b*d

  // Conjugate then swap lanes so the b-terms line up with the a-terms.
  const Packet cross_terms = pcplxflip(pconj(imag_terms));  // -b*d, b*c

  return padd(cross_terms, real_terms);  // a*c - b*d, a*d + b*c
}
// \internal \returns the element-wise complex logarithm of x.
//
// For x = a + i*b the result is log(sqrt(a^2 + b^2)) + i*atan2(b, a).
// If either component of x is +/-inf, the real part of the result is forced
// to +inf instead of using the value produced by plog(phypot_complex(x)).
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Packet& x) {
  using Scalar = typename unpacket_traits<Packet>::type;
  using RealScalar = typename Scalar::value_type;
  using RealPacket = typename unpacket_traits<Packet>::as_real;

  // Imaginary part: only the odd lane, atan2(b, a), is kept in the final select.
  const RealPacket swapped = pcplxflip(x).v;      // b, a
  const RealPacket angle = patan2(x.v, swapped);  // atan2(a, b), atan2(b, a)

  // Real part: log of the modulus, replicated in both lanes.
  const Packet modulus = phypot_complex(x);        // sqrt(a^2 + b^2), sqrt(a^2 + b^2)
  const RealPacket log_modulus = plog(modulus.v);  // log(sqrt(a^2 + b^2)) in both lanes

  // Flag complex numbers with an infinite component (in either lane).
  const RealPacket pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
  const RealPacket lane_is_inf = pcmp_eq(pabs(x.v), pos_inf);
  const RealPacket partner_is_inf = pcplxflip(Packet(lane_is_inf)).v;
  const RealPacket any_inf = por(lane_is_inf, partner_is_inf);
  const RealPacket real_part = pselect(any_inf, pos_inf, log_modulus);

  // Even lanes take the real part, odd lanes take the angle.
  return Packet(pselect(peven_mask(real_part), real_part, angle));  // log(sqrt(a^2 + b^2)), atan2(b, a)
}
// \internal \returns the element-wise complex exponential of a.
//
// With a = x + i*y, the result is exp(x) * cis(y) = exp(x) * (cos(y) + i*sin(y)),
// followed by a handful of fix-ups for infinite x, non-finite y and y == 0.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) {
  using RealPacket = typename unpacket_traits<Packet>::as_real;
  using Scalar = typename unpacket_traits<Packet>::type;
  using RealScalar = typename Scalar::value_type;

  const RealPacket real_lanes = peven_mask(a.v);
  const RealPacket imag_lanes = pcplxflip(Packet(real_lanes)).v;
  const RealPacket pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
  const RealPacket neg_inf = pset1<RealPacket>(-NumTraits<RealScalar>::infinity());
  const RealPacket plus_one = pset1<RealPacket>(RealScalar(1));

  // exp(x): isolate the real part and replicate it into both lanes.
  const RealPacket re_only = pand(a.v, real_lanes);
  const RealPacket re = por(re_only, pcplxflip(Packet(re_only)).v);
  const RealPacket exp_re = pexp(re);

  // cis(y): isolate the imaginary part and replicate it into both lanes.
  const RealPacket im_only = pand(imag_lanes, a.v);
  const RealPacket im = por(im_only, pcplxflip(Packet(im_only)).v);
  // The selector output is lane-swapped to obtain cos(y) + i * sin(y).
  RealPacket cis = pcplxflip(Packet(psincos_selector<RealPacket>(im))).v;

  // x == -inf: cis(y) is bounded, so the result is (0, +/-0) with the signs
  // taken from cis(y). Replace cis(y) by copysign(1, cis(y)).
  const RealPacket cis_unit = por(pandnot(cis, pabs(cis)), plus_one);
  cis = pselect(pcmp_eq(re, neg_inf), cis_unit, cis);

  // x == +inf and cis(y) is NaN (y is inf or NaN): the result is (+/-inf, NaN)
  // with undetermined signs; take the sign of y. The factor is
  // (copysign(1, y), 0), so the imaginary lane becomes inf * 0 = NaN.
  const RealPacket im_unit = por(pandnot(im, pabs(im)), plus_one);
  const RealPacket re_inf_and_cis_nan = pand(pcmp_eq(re, pos_inf), pisnan(cis));
  cis = pselect(re_inf_and_cis_nan, pand(im_unit, real_lanes), cis);

  // exp(x) == +inf and y finite: use copysign(1, cis(y)) to prevent
  // inf * 0 = NaN. The vectorized sincos may return an exact zero for values
  // that are only near zero, such as cos(pi/2), and inf * +-1 = +-inf is correct.
  // The y == 0 case is handled separately below.
  const RealPacket sign_bit = pset1<RealPacket>(RealScalar(-0.0));
  const RealPacket cis_copysign_one = por(pand(cis, sign_bit), plus_one);
  const RealPacket overflow_with_finite_im = pand(pcmp_eq(exp_re, pos_inf), pcmp_lt(pabs(im), pos_inf));
  cis = pselect(overflow_with_finite_im, cis_copysign_one, cis);

  const Packet product = Packet(pmul(exp_re, cis));

  // y == +/-0: the input is real, so return (exp(x), y) for consistency with
  // the real exponential.
  const Packet im_is_zero = Packet(pcmp_eq(im, pzero(im)));
  const Packet real_input_result = Packet(por(pand(exp_re, real_lanes), pand(im, imag_lanes)));
  return pselect(im_is_zero, real_input_result, product);
}
// \internal \returns the element-wise principal square root of the complex packet a.
// Infinite components are handled explicitly in Step 6; any other result that
// contains a NaN is turned into an all-NaN result.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// Computes the principal sqrt of the complex numbers in the input.
//
// For example, for packets containing 2 complex numbers stored in interleaved format
// a = [a0, a1] = [x0, y0, x1, y1],
// where x0 = real(a0), y0 = imag(a0) etc., this function returns
// b = [b0, b1] = [u0, v0, u1, v1],
// such that b0^2 = a0, b1^2 = a1.
//
// To derive the formula for the complex square roots, let's consider the equation for
// a single complex square root of the number x + i*y. We want to find real numbers
// u and v such that
// (u + i*v)^2 = x + i*y <=>
// u^2 - v^2 + i*2*u*v = x + i*y.
// By equating the real and imaginary parts we get:
// u^2 - v^2 = x
// 2*u*v = y.
//
// For x >= 0, this has the numerically stable solution
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// v = 0.5 * (y / u)
// and for x < 0,
// v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
// u = 0.5 * (y / v)
//
// To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
// l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,
// In the following, without loss of generality, we have annotated the code, assuming
// that the input is a packet of 2 complex numbers.
//
// Step 1. Compute l = [l0, l0, l1, l1], where
// l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2)
// To avoid over- and underflow, we use the stable formula for each hypotenuse
// l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)),
// where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1.
RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|]
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|]
RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip);
// Masks for "smaller component is zero" and "both components are zero" (max == 0).
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
// r = min / max; the lanes where this is not usable (min == 0) are overwritten below.
RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero.
l = pselect(a_min_zero_mask, a_max, l);
// Step 2. Compute [rho0, *, rho1, *], where
// rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|))
// We don't care about the imaginary parts computed here. They will be overwritten later.
const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5));
Packet rho;
rho.v = psqrt(pmul(cst_half, padd(a_abs, l)));
// Step 3. Compute [rho0, eta0, rho1, eta1], where
// eta0 = (y0 / rho0) / 2, and eta1 = (y1 / rho1) / 2.
// set eta = 0 if input is 0 + i0.
// rho is lane-swapped so that rho0 lines up with y0 in the odd lane before dividing.
RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask);
RealPacket real_mask = peven_mask(a.v);
Packet positive_real_result;
// Compute result for inputs with positive real part.
positive_real_result.v = pselect(real_mask, rho.v, eta);
// Step 4. Compute solution for inputs with negative real part:
// [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
// cst_imag_sign_mask holds -0.0 in the imaginary lanes and +0.0 in the real lanes,
// so the pand below extracts the sign bits of y0, y1 only.
const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), RealScalar(-0.0))).v;
RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
Packet negative_real_result;
// Notice that rho is positive, so taking its absolute value is a noop.
negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs);
// Step 5. Select solution branch based on the sign of the real parts.
Packet negative_real_mask;
// Compare only the real lanes against zero (the imaginary lanes are masked out first),
// then copy the comparison result into the partner lane so both lanes of a complex
// number pick the same branch.
negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v));
negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v);
Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result);
// Step 6. Handle special cases for infinities:
// * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN
// * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN
// * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y
// * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y
const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
Packet is_inf;
is_inf.v = pcmp_eq(a_abs, cst_pos_inf);
// Mask of complex numbers whose real part is infinite, replicated into both lanes.
Packet is_real_inf;
is_real_inf.v = pand(is_inf.v, real_mask);
is_real_inf = por(is_real_inf, pcplxflip(is_real_inf));
// prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part.
Packet real_inf_result;
real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v);
real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v);
// prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part.
// Mask of complex numbers whose imaginary part is infinite, replicated into both lanes.
Packet is_imag_inf;
is_imag_inf.v = pandnot(is_inf.v, real_mask);
is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf));
// Real lanes are set to +∞; imaginary lanes keep the input's imaginary part.
Packet imag_inf_result;
imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
// unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan
Packet result_is_nan = pisnan(result);
result = por(result_is_nan, result);
// Priority of the final selection: infinite imaginary part first, then infinite real part,
// then the general (possibly NaN) result.
return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result));
}
// \internal \returns the norm of a complex number z = a + i*b, defined as sqrt(a^2 + b^2),
// replicated into both the real and the imaginary lane of the result.
//
// Implemented using the hypot(a,b) algorithm from https://doi.org/10.48550/arXiv.1904.09481:
// a first estimate h = sqrt(a^2 + b^2) is refined by one correction step
// h - x / (2*h), where x is the residual of h^2 - (a^2 + b^2) assembled from
// pmadd terms. An input of exactly 0 + i*0 yields 0.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet phypot_complex(const Packet& a) {
  using Scalar = typename unpacket_traits<Packet>::type;
  using RealScalar = typename Scalar::value_type;
  using RealPacket = typename unpacket_traits<Packet>::as_real;

  const RealPacket zero = pset1<RealPacket>(static_cast<RealScalar>(0.0));
  const RealPacket minus_one = pset1<RealPacket>(static_cast<RealScalar>(-1.0));
  const RealPacket two = pset1<RealPacket>(static_cast<RealScalar>(2.0));
  const RealPacket real_lanes = peven_mask(a.v);

  // Component magnitudes, each replicated into both lanes.
  const RealPacket magnitude = pabs(a.v);                                       // |a|, |b|
  const RealPacket magnitude_swapped = pcplxflip(Packet(magnitude)).v;          // |b|, |a|
  const RealPacket abs_re = pselect(real_lanes, magnitude, magnitude_swapped);  // |a|, |a|
  const RealPacket abs_im = pselect(real_lanes, magnitude_swapped, magnitude);  // |b|, |b|

  // First estimate of the norm.
  const RealPacket squares = pmul(a.v, a.v);                         // a^2, b^2
  const RealPacket squares_swapped = pcplxflip(Packet(squares)).v;   // b^2, a^2
  const RealPacket estimate = psqrt(padd(squares, squares_swapped)); // sqrt(a^2 + b^2) in both lanes

  // Rounded squares and their negations, used as pmadd addends.
  const RealPacket estimate_sq = pmul(estimate, estimate);                   // a^2 + b^2, a^2 + b^2
  const RealPacket re_sq = pselect(real_lanes, squares, squares_swapped);    // a^2, a^2
  const RealPacket neg_estimate_sq = pmul(estimate_sq, minus_one);
  const RealPacket neg_re_sq = pmul(re_sq, minus_one);

  // Residual x = (h*h - h_sq) - (b*b + (a_sq - h_sq)) - (a*a - a_sq).
  const RealPacket term_h = pmadd(estimate, estimate, neg_estimate_sq);
  const RealPacket term_b = pmadd(abs_im, abs_im, psub(re_sq, estimate_sq));
  const RealPacket term_a = pmadd(abs_re, abs_re, neg_re_sq);
  const RealPacket residual = psub(psub(term_h, term_b), term_a);

  // Correction step: h - x / (2*h).
  const RealPacket refined = psub(estimate, pdiv(residual, pmul(two, estimate)));

  // Zero case: when both components are zero the correction divides by zero,
  // so clear those lanes to 0.
  const RealPacket both_zero = pcmp_eq(por(magnitude, magnitude_swapped), zero);
  return Packet(pandnot(refined, both_zero));  // sqrt(a^2 + b^2), sqrt(a^2 + b^2)
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_COMPLEX_H