| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr> |
| // Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com> |
| // Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| #ifndef EIGEN_CWISE_TERNARY_OP_H |
| #define EIGEN_CWISE_TERNARY_OP_H |
| |
| namespace Eigen { |
| |
| /** \class CwiseTernaryOp |
| * \ingroup Core_Module |
| * |
| * \brief Generic expression where a coefficient-wise ternary operator is |
| * applied to three expressions |
| * |
| * \tparam TernaryOp template functor implementing the operator |
| * \tparam Arg1 the type of the first argument |
| * \tparam Arg2 the type of the second argument |
| * \tparam Arg3 the type of the third argument |
| * |
| * This class represents an expression where a coefficient-wise ternary |
| * operator is applied to three expressions. |
| * It is the return type of ternary operators, by which we mean only those |
| * ternary operators where |
| * all three arguments are Eigen expressions. |
| * For example, the return type of betainc(matrix1, matrix2, matrix3) is a |
| * CwiseTernaryOp. |
| * |
| * Most of the time, this is the only way that it is used, so you typically |
| * don't have to name |
| * CwiseTernaryOp types explicitly. |
| */ |
| |
| namespace internal { |
| // Compile-time traits of a CwiseTernaryOp expression. |
| // |
| // The expression kind, the compile-time sizes, the storage kind and the index |
| // type are taken from the first argument only. The flags and the coefficient |
| // read cost are combined from all three nested arguments. |
| template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> |
| struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > { |
| // we must not inherit from traits<Arg1> since it has |
| // the potential to cause problems with MSVC |
| typedef typename remove_all<Arg1>::type Ancestor; |
| typedef typename traits<Ancestor>::XprKind XprKind; |
| enum { |
| RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime, |
| ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime, |
| MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime, |
| MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime |
| }; |
| |
| // The scalar type of the expression is whatever the functor returns when |
| // called with the three argument scalar types, so it may differ from the |
| // scalar types of the arguments. |
| typedef |
| typename result_of<TernaryOp(typename Arg1::Scalar, typename Arg2::Scalar, |
| typename Arg3::Scalar)>::type Scalar; |
| |
| typedef typename internal::traits<Arg1>::StorageKind StorageKind; |
| typedef typename internal::traits<Arg1>::Index Index; |
| |
| typedef typename Arg1::Nested Arg1Nested; |
| typedef typename Arg2::Nested Arg2Nested; |
| typedef typename Arg3::Nested Arg3Nested; |
| // Each _ArgNNested must come from its own ArgNNested: the read costs, the |
| // flags and the scalar-type comparisons below are all read from these types. |
| typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; |
| typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; |
| typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; |
| enum { |
| Arg1CoeffReadCost = _Arg1Nested::CoeffReadCost, |
| Arg2CoeffReadCost = _Arg2Nested::CoeffReadCost, |
| Arg3CoeffReadCost = _Arg3Nested::CoeffReadCost, |
| Arg1Flags = _Arg1Nested::Flags, |
| Arg2Flags = _Arg2Nested::Flags, |
| Arg3Flags = _Arg3Nested::Flags, |
| // Whether the second / third nested argument has the same scalar type as |
| // the first one. |
| SameType12 = is_same<typename _Arg1Nested::Scalar, |
| typename _Arg2Nested::Scalar>::value, |
| SameType13 = is_same<typename _Arg1Nested::Scalar, |
| typename _Arg3Nested::Scalar>::value, |
| // True when all three arguments have the same RowMajorBit setting. |
| StorageOrdersAgree = |
| ((int(Arg1::Flags) & RowMajorBit) == (int(Arg2::Flags) & RowMajorBit) && |
| (int(Arg1::Flags) & RowMajorBit) == (int(Arg3::Flags) & RowMajorBit)), |
| // HereditaryBits survive if any argument has them. AlignedBit, |
| // LinearAccessBit and PacketAccessBit survive only if all three arguments |
| // have them and, in addition: |
| // - LinearAccessBit requires agreeing storage orders; |
| // - PacketAccessBit requires a functor with PacketAccess, agreeing storage |
| // orders and identical scalar types for all three arguments. |
| // Every other bit is masked out. |
| Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) & |
| (HereditaryBits | |
| (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) & |
| (AlignedBit | (StorageOrdersAgree ? LinearAccessBit : 0) | |
| (functor_traits<TernaryOp>::PacketAccess && |
| StorageOrdersAgree && SameType12 && SameType13 |
| ? PacketAccessBit |
| : 0)))), |
| // The storage order of the expression is always that of the first argument. |
| Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit), |
| // Reading one coefficient reads one coefficient of each argument and |
| // applies the functor once. |
| CoeffReadCost = Arg1CoeffReadCost + Arg2CoeffReadCost + Arg3CoeffReadCost + |
| functor_traits<TernaryOp>::Cost |
| }; |
| }; |
| } // end namespace internal |
| |
| // Storage-kind-specific base class of CwiseTernaryOp. It is only declared |
| // here; specializations are provided per StorageKind. |
| template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3, |
| typename StorageKind> |
| class CwiseTernaryOpImpl; |
| |
| // Expression object that stores three nested argument expressions together |
| // with the functor that is applied to them. The implementation base class is |
| // selected by the storage kind of the first argument. |
| template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> |
| class CwiseTernaryOp |
| : internal::no_assignment_operator, |
| public CwiseTernaryOpImpl<TernaryOp, Arg1, Arg2, Arg3, |
| typename internal::traits<Arg1>::StorageKind> { |
| public: |
| typedef typename CwiseTernaryOpImpl< |
| TernaryOp, Arg1, Arg2, Arg3, |
| typename internal::traits<Arg1>::StorageKind>::Base Base; |
| EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp) |
| |
| typedef typename internal::nested<Arg1>::type Arg1Nested; |
| typedef typename internal::nested<Arg2>::type Arg2Nested; |
| typedef typename internal::nested<Arg3>::type Arg3Nested; |
| typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested; |
| typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested; |
| typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested; |
| |
| /** Builds the expression from its three arguments and the functor. |
| * |
| * The sizes of \a arg2 and \a arg3 must match those of \a arg1; this is |
| * checked with a static assertion and with a run-time assertion. |
| */ |
| EIGEN_DEVICE_FUNC |
| EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& arg1, |
| const Arg2& arg2, |
| const Arg3& arg3, |
| const TernaryOp& func = TernaryOp()) |
| : m_arg1(arg1), m_arg2(arg2), m_arg3(arg3), m_functor(func) { |
| // require the sizes to match |
| EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2) |
| EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3) |
| eigen_assert(arg1.rows() == arg2.rows() && arg1.cols() == arg2.cols()); |
| eigen_assert(arg1.rows() == arg3.rows() && arg1.cols() == arg3.cols()); |
| } |
| |
| /** \returns the number of rows of the expression */ |
| EIGEN_DEVICE_FUNC |
| EIGEN_STRONG_INLINE Index rows() const { |
| // All three arguments have the same number of rows (asserted in the |
| // constructor), so query the first one whose row count is not Dynamic. |
| // Only when all of them are Dynamic does this end up on a dynamic-sized |
| // argument. |
| typedef internal::traits<typename internal::remove_all<Arg1Nested>::type> |
| Arg1Traits; |
| typedef internal::traits<typename internal::remove_all<Arg2Nested>::type> |
| Arg2Traits; |
| if (int(Arg1Traits::RowsAtCompileTime) != Dynamic) |
| return m_arg1.rows(); |
| else if (int(Arg2Traits::RowsAtCompileTime) != Dynamic) |
| return m_arg2.rows(); |
| else |
| return m_arg3.rows(); |
| } |
| /** \returns the number of columns of the expression */ |
| EIGEN_DEVICE_FUNC |
| EIGEN_STRONG_INLINE Index cols() const { |
| // Same selection as in rows(), applied to the column count. |
| typedef internal::traits<typename internal::remove_all<Arg1Nested>::type> |
| Arg1Traits; |
| typedef internal::traits<typename internal::remove_all<Arg2Nested>::type> |
| Arg2Traits; |
| if (int(Arg1Traits::ColsAtCompileTime) != Dynamic) |
| return m_arg1.cols(); |
| else if (int(Arg2Traits::ColsAtCompileTime) != Dynamic) |
| return m_arg2.cols(); |
| else |
| return m_arg3.cols(); |
| } |
| |
| /** \returns the first argument nested expression */ |
| EIGEN_DEVICE_FUNC |
| const _Arg1Nested& arg1() const { return m_arg1; } |
| /** \returns the second argument nested expression */ |
| EIGEN_DEVICE_FUNC |
| const _Arg2Nested& arg2() const { return m_arg2; } |
| /** \returns the third argument nested expression */ |
| EIGEN_DEVICE_FUNC |
| const _Arg3Nested& arg3() const { return m_arg3; } |
| /** \returns the functor representing the ternary operation */ |
| EIGEN_DEVICE_FUNC |
| const TernaryOp& functor() const { return m_functor; } |
| |
| protected: |
| // The three nested arguments, in constructor order, and a copy of the |
| // functor applied to them. |
| Arg1Nested m_arg1; |
| Arg2Nested m_arg2; |
| Arg3Nested m_arg3; |
| const TernaryOp m_functor; |
| }; |
| |
| // Dense specialization of the implementation base: provides coefficient and |
| // packet access by forwarding to the three arguments of the derived |
| // CwiseTernaryOp and combining the results with its functor. |
| template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> |
| class CwiseTernaryOpImpl<TernaryOp, Arg1, Arg2, Arg3, Dense> |
| : public internal::dense_xpr_base< |
| CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type { |
| typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> Derived; |
| |
| public: |
| typedef typename internal::dense_xpr_base< |
| CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base; |
| EIGEN_DENSE_PUBLIC_INTERFACE(Derived) |
| |
| /** \returns the functor applied to the coefficients of the three arguments |
| * at position (\a rowId, \a colId) */ |
| EIGEN_DEVICE_FUNC |
| EIGEN_STRONG_INLINE const Scalar coeff(Index rowId, Index colId) const { |
| const Derived& xpr = derived(); |
| return xpr.functor()(xpr.arg1().coeff(rowId, colId), |
| xpr.arg2().coeff(rowId, colId), |
| xpr.arg3().coeff(rowId, colId)); |
| } |
| |
| /** \returns the functor's packetOp applied to the packets of the three |
| * arguments at position (\a rowId, \a colId) */ |
| template <int LoadMode> |
| EIGEN_STRONG_INLINE PacketScalar packet(Index rowId, Index colId) const { |
| const Derived& xpr = derived(); |
| return xpr.functor().packetOp( |
| xpr.arg1().template packet<LoadMode>(rowId, colId), |
| xpr.arg2().template packet<LoadMode>(rowId, colId), |
| xpr.arg3().template packet<LoadMode>(rowId, colId)); |
| } |
| |
| /** \returns the functor applied to the coefficients of the three arguments |
| * at linear position \a index */ |
| EIGEN_DEVICE_FUNC |
| EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { |
| const Derived& xpr = derived(); |
| return xpr.functor()(xpr.arg1().coeff(index), |
| xpr.arg2().coeff(index), |
| xpr.arg3().coeff(index)); |
| } |
| |
| /** \returns the functor's packetOp applied to the packets of the three |
| * arguments at linear position \a index */ |
| template <int LoadMode> |
| EIGEN_STRONG_INLINE PacketScalar packet(Index index) const { |
| const Derived& xpr = derived(); |
| return xpr.functor().packetOp( |
| xpr.arg1().template packet<LoadMode>(index), |
| xpr.arg2().template packet<LoadMode>(index), |
| xpr.arg3().template packet<LoadMode>(index)); |
| } |
| }; |
| |
| } // end namespace Eigen |
| |
| #endif // EIGEN_CWISE_TERNARY_OP_H |