blob: 6ce80ceed30890a50a4a18eb1da8c7d5ef2ab149 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CWISE_TERNARY_OP_H
#define EIGEN_CWISE_TERNARY_OP_H
namespace Eigen {
/** \class CwiseTernaryOp
* \ingroup Core_Module
*
* \brief Generic expression where a coefficient-wise ternary operator is
* applied to three expressions
*
* \tparam TernaryOp template functor implementing the operator
* \tparam Arg1 the type of the first argument
* \tparam Arg2 the type of the second argument
* \tparam Arg3 the type of the third argument
*
* This class represents an expression where a coefficient-wise ternary
* operator is applied to three expressions.
* It is the return type of ternary operators, by which we mean only those
* ternary operators where
* all three arguments are Eigen expressions.
* For example, the return type of betainc(matrix1, matrix2, matrix3) is a
* CwiseTernaryOp.
*
* Most of the time, this is the only way that it is used, so you typically
* don't have to name
* CwiseTernaryOp types explicitly.
*/
namespace internal {
// Compile-time properties of a CwiseTernaryOp expression.
// Sizes, storage kind and index type are taken from the first argument;
// flags and read cost combine all three arguments and the functor.
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
  // we must not inherit from traits<Arg1> since it has
  // the potential to cause problems with MSVC
  typedef typename remove_all<Arg1>::type Ancestor;
  typedef typename traits<Ancestor>::XprKind XprKind;
  enum {
    RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
    ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
    MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
    MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
  };
  // The CwiseTernaryOp constructor does not force the three argument scalar
  // types to match. The result type is whatever the functor yields for the
  // three argument scalars. Mismatched argument scalars only disable packet
  // access (see SameType12 / SameType13 below).
  typedef
      typename result_of<TernaryOp(typename Arg1::Scalar, typename Arg2::Scalar,
                                   typename Arg3::Scalar)>::type Scalar;
  typedef typename internal::traits<Arg1>::StorageKind StorageKind;
  typedef typename internal::traits<Arg1>::Index Index;
  typedef typename Arg1::Nested Arg1Nested;
  typedef typename Arg2::Nested Arg2Nested;
  typedef typename Arg3::Nested Arg3Nested;
  // Each _ArgNNested strips the reference from its *own* ArgNNested, so the
  // per-argument cost, flags and scalar checks below see the right argument.
  typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
  typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
  typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
  enum {
    Arg1CoeffReadCost = _Arg1Nested::CoeffReadCost,
    Arg2CoeffReadCost = _Arg2Nested::CoeffReadCost,
    Arg3CoeffReadCost = _Arg3Nested::CoeffReadCost,
    Arg1Flags = _Arg1Nested::Flags,
    Arg2Flags = _Arg2Nested::Flags,
    Arg3Flags = _Arg3Nested::Flags,
    SameType12 = is_same<typename _Arg1Nested::Scalar,
                         typename _Arg2Nested::Scalar>::value,
    SameType13 = is_same<typename _Arg1Nested::Scalar,
                         typename _Arg3Nested::Scalar>::value,
    StorageOrdersAgree =
        ((int(Arg1::Flags) & RowMajorBit) == (int(Arg2::Flags) & RowMajorBit) &&
         (int(Arg1::Flags) & RowMajorBit) == (int(Arg3::Flags) & RowMajorBit)),
    // Hereditary bits of any argument are kept. Aligned/linear/packet bits
    // survive only when all three arguments have them. Linear access also
    // needs agreeing storage orders. Packet access additionally needs a
    // vectorizable functor and identical argument scalar types.
    Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) &
             (HereditaryBits |
              (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) &
               (AlignedBit | (StorageOrdersAgree ? LinearAccessBit : 0) |
                (functor_traits<TernaryOp>::PacketAccess &&
                         StorageOrdersAgree && SameType12 && SameType13
                     ? PacketAccessBit
                     : 0)))),
    // The storage order of the expression is that of the first argument.
    Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit),
    // Reading one coefficient reads one from each argument and calls the
    // functor once.
    CoeffReadCost = Arg1CoeffReadCost + Arg2CoeffReadCost + Arg3CoeffReadCost +
                    functor_traits<TernaryOp>::Cost
  };
};
}  // end namespace internal
// Storage-kind specific base of CwiseTernaryOp; the StorageKind of the first
// argument selects the specialization (e.g. Dense).
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
          typename StorageKind>
class CwiseTernaryOpImpl;

template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
class CwiseTernaryOp
    : internal::no_assignment_operator,
      public CwiseTernaryOpImpl<TernaryOp, Arg1, Arg2, Arg3,
                                typename internal::traits<Arg1>::StorageKind> {
 public:
  typedef typename CwiseTernaryOpImpl<
      TernaryOp, Arg1, Arg2, Arg3,
      typename internal::traits<Arg1>::StorageKind>::Base Base;
  EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
  typedef typename internal::nested<Arg1>::type Arg1Nested;
  typedef typename internal::nested<Arg2>::type Arg2Nested;
  typedef typename internal::nested<Arg3>::type Arg3Nested;
  typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
  typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
  typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;

  /** Builds the expression func(arg1, arg2, arg3), applied coefficient-wise.
   * The three arguments must have the same dimensions. This is checked
   * statically where the sizes are known at compile time, and with
   * eigen_assert at run time. */
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& arg1,
                                     const Arg2& arg2,
                                     const Arg3& arg3,
                                     const TernaryOp& func = TernaryOp())
      : m_arg1(arg1), m_arg2(arg2), m_arg3(arg3), m_functor(func) {
    // require the sizes to match
    EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
    EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
    eigen_assert(arg1.rows() == arg2.rows() && arg1.cols() == arg2.cols());
    eigen_assert(arg1.rows() == arg3.rows() && arg1.cols() == arg3.cols());
  }

  /** \returns the number of rows. All arguments share it (asserted in the
   * constructor). The first argument whose row count is fixed at compile
   * time is preferred, to enable compile-time optimizations. If all three
   * are dynamic, Arg3's count is used. */
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index rows() const {
    typedef typename internal::remove_all<Arg1Nested>::type Arg1Plain;
    typedef typename internal::remove_all<Arg2Nested>::type Arg2Plain;
    if (internal::traits<Arg1Plain>::RowsAtCompileTime != Dynamic)
      return m_arg1.rows();
    if (internal::traits<Arg2Plain>::RowsAtCompileTime != Dynamic)
      return m_arg2.rows();
    return m_arg3.rows();
  }

  /** \returns the number of columns. All arguments share it (asserted in the
   * constructor). The first argument whose column count is fixed at compile
   * time is preferred, to enable compile-time optimizations. If all three
   * are dynamic, Arg3's count is used. */
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index cols() const {
    typedef typename internal::remove_all<Arg1Nested>::type Arg1Plain;
    typedef typename internal::remove_all<Arg2Nested>::type Arg2Plain;
    if (internal::traits<Arg1Plain>::ColsAtCompileTime != Dynamic)
      return m_arg1.cols();
    if (internal::traits<Arg2Plain>::ColsAtCompileTime != Dynamic)
      return m_arg2.cols();
    return m_arg3.cols();
  }

  /** \returns the first argument nested expression */
  EIGEN_DEVICE_FUNC
  const _Arg1Nested& arg1() const { return m_arg1; }
  /** \returns the second argument nested expression */
  EIGEN_DEVICE_FUNC
  const _Arg2Nested& arg2() const { return m_arg2; }
  /** \returns the third argument nested expression */
  EIGEN_DEVICE_FUNC
  const _Arg3Nested& arg3() const { return m_arg3; }
  /** \returns the functor representing the ternary operation */
  EIGEN_DEVICE_FUNC
  const TernaryOp& functor() const { return m_functor; }

 protected:
  Arg1Nested m_arg1;
  Arg2Nested m_arg2;
  Arg3Nested m_arg3;
  const TernaryOp m_functor;
};
/** Dense implementation of CwiseTernaryOp.
 * A coefficient (or packet) of the expression is obtained by passing the
 * coefficients (or packets) at the same position in arg1, arg2 and arg3 to
 * the expression's functor. */
template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
class CwiseTernaryOpImpl<TernaryOp, Arg1, Arg2, Arg3, Dense>
    : public internal::dense_xpr_base<
          CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
  typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> Derived;

 public:
  typedef typename internal::dense_xpr_base<Derived>::type Base;
  EIGEN_DENSE_PUBLIC_INTERFACE(Derived)

  // ---- scalar access -------------------------------------------------------

  /** Coefficient at (rowId, colId): functor(arg1(r,c), arg2(r,c), arg3(r,c)). */
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE const Scalar coeff(Index rowId, Index colId) const {
    const Derived& xpr = derived();
    return xpr.functor()(xpr.arg1().coeff(rowId, colId),
                         xpr.arg2().coeff(rowId, colId),
                         xpr.arg3().coeff(rowId, colId));
  }

  /** Coefficient at linear position index, combined the same way. */
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE const Scalar coeff(Index index) const {
    const Derived& xpr = derived();
    return xpr.functor()(xpr.arg1().coeff(index),
                         xpr.arg2().coeff(index),
                         xpr.arg3().coeff(index));
  }

  // ---- packet access -------------------------------------------------------

  /** Packet starting at (rowId, colId), combined with the functor's packetOp. */
  template <int LoadMode>
  EIGEN_STRONG_INLINE PacketScalar packet(Index rowId, Index colId) const {
    const Derived& xpr = derived();
    return xpr.functor().packetOp(
        xpr.arg1().template packet<LoadMode>(rowId, colId),
        xpr.arg2().template packet<LoadMode>(rowId, colId),
        xpr.arg3().template packet<LoadMode>(rowId, colId));
  }

  /** Packet starting at linear position index, combined with packetOp. */
  template <int LoadMode>
  EIGEN_STRONG_INLINE PacketScalar packet(Index index) const {
    const Derived& xpr = derived();
    return xpr.functor().packetOp(
        xpr.arg1().template packet<LoadMode>(index),
        xpr.arg2().template packet<LoadMode>(index),
        xpr.arg3().template packet<LoadMode>(index));
  }
};
} // end namespace Eigen
#endif // EIGEN_CWISE_TERNARY_OP_H