blob: 7d3a3fb3607e547a2ec1d9b95b66cef55371d104 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 The Eigen Team.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_COHERENT_PAD_OP_H
#define EIGEN_COHERENT_PAD_OP_H
#include "./InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
// Pads a vector with zeros to a given size.
//
// Forward declaration, so that the internal::traits specialization can name
// the expression type before the expression itself is defined.
//   XprType            - type of the wrapped vector expression.
//   SizeAtCompileTime_ - compile-time size of the padded result.
template <typename XprType, int SizeAtCompileTime_>
struct CoherentPadOp;
// Compile-time traits of CoherentPadOp. Everything is inherited from the
// traits of the wrapped expression, except the shape (a vector holding
// SizeAtCompileTime_ coefficients) and the NestByRefBit flag, which is cleared.
template <typename XprType, int SizeAtCompileTime_>
struct traits<CoherentPadOp<XprType, SizeAtCompileTime_>> : public traits<XprType> {
  // Wrapped expression type stripped of cv/ref qualifiers.
  using PlainXprType = internal::remove_all_t<XprType>;
  // How the wrapped expression is stored inside the padding expression.
  using XprNested = typename internal::ref_selector<XprType>::type;
  using XprNested_ = std::remove_reference_t<XprNested>;

  enum : int {
    // Orientation follows the storage-order bit of the wrapped expression:
    // row-major means a 1 x Size row vector, otherwise a Size x 1 column vector.
    IsRowMajor = traits<PlainXprType>::Flags & RowMajorBit,
    SizeAtCompileTime = SizeAtCompileTime_,
    RowsAtCompileTime = IsRowMajor ? 1 : SizeAtCompileTime,
    ColsAtCompileTime = IsRowMajor ? SizeAtCompileTime : 1,
    // The maximum dimensions coincide with the actual ones.
    MaxRowsAtCompileTime = RowsAtCompileTime,
    MaxColsAtCompileTime = ColsAtCompileTime,
    Flags = traits<XprType>::Flags & ~NestByRefBit,
  };
};
// Expression that pads a vector with zeros to a given size.
//
// The expression only records the wrapped vector and the requested size; the
// coefficients themselves are produced by the matching unary_evaluator.
template <typename XprType, int SizeAtCompileTime_>
struct CoherentPadOp : public dense_xpr_base<CoherentPadOp<XprType, SizeAtCompileTime_>>::type {
  using Base = typename internal::generic_xpr_base<CoherentPadOp<XprType, SizeAtCompileTime_>>::type;
  EIGEN_GENERIC_PUBLIC_INTERFACE(CoherentPadOp)

  using XprNested = typename traits<CoherentPadOp>::XprNested;
  using XprNested_ = typename traits<CoherentPadOp>::XprNested_;
  using NestedExpression = XprNested_;

  // A padding expression always wraps an existing vector, so it cannot be
  // default-constructed; copying and moving use the compiler-generated versions.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoherentPadOp() = delete;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoherentPadOp(const CoherentPadOp&) = default;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoherentPadOp(CoherentPadOp&&) = default;

  // Wraps `xpr` and records `size` as the number of coefficients of the result.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoherentPadOp(const XprType& xpr, Index size) : xpr_(xpr), size_(size) {
    static_assert(XprNested_::IsVectorAtCompileTime, "input type must be a vector");
  }

  // The wrapped (unpadded) vector expression.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const XprNested_& nestedExpression() const { return xpr_; }

  // Number of coefficients of the padded vector.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return size_.value(); }

  // Row count: one for a row vector, size() for a column vector.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rows() const {
    if (traits<CoherentPadOp>::IsRowMajor) {
      return Index(1);
    }
    return size();
  }

  // Column count: size() for a row vector, one for a column vector.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index cols() const {
    if (traits<CoherentPadOp>::IsRowMajor) {
      return size();
    }
    return Index(1);
  }

 private:
  XprNested xpr_;
  const internal::variable_if_dynamic<Index, SizeAtCompileTime> size_;
};
// Adapted from the Replicate evaluator.
//
// Evaluator for CoherentPadOp: coefficients whose index lies inside the wrapped
// vector are forwarded to the wrapped vector's evaluator; every coefficient
// past its end evaluates to zero.
template <typename ArgType, int SizeAtCompileTime>
struct unary_evaluator<CoherentPadOp<ArgType, SizeAtCompileTime>>
    : evaluator_base<CoherentPadOp<ArgType, SizeAtCompileTime>> {
  typedef CoherentPadOp<ArgType, SizeAtCompileTime> XprType;
  typedef typename internal::remove_all_t<typename XprType::CoeffReturnType> CoeffReturnType;
  typedef typename internal::nested_eval<ArgType, 1>::type ArgTypeNested;
  typedef internal::remove_all_t<ArgTypeNested> ArgTypeNestedCleaned;

  enum {
    CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
    LinearAccessMask = XprType::IsVectorAtCompileTime ? LinearAccessBit : 0,
    Flags = evaluator<ArgTypeNestedCleaned>::Flags & (HereditaryBits | LinearAccessMask | RowMajorBit),
    Alignment = evaluator<ArgTypeNestedCleaned>::Alignment
  };

  // Builds the evaluator of the wrapped vector and caches that vector's size,
  // which is the boundary between forwarded coefficients and zero padding.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& pad)
      : m_arg(pad.nestedExpression()), m_argImpl(m_arg), m_size(pad.nestedExpression().size()) {}

  // Two-index coefficient access. The wrapped expression is a vector, so the
  // index along its singleton dimension is always 0.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
    EIGEN_IF_CONSTEXPR(XprType::IsRowMajor) {
      if (col < m_size.value()) {
        return m_argImpl.coeff(0, col);
      }
    }
    else {
      if (row < m_size.value()) {
        return m_argImpl.coeff(row, 0);
      }
    }
    return CoeffReturnType(0);
  }

  // Linear coefficient access: the wrapped coefficient if `index` is inside the
  // wrapped vector, zero otherwise.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    if (index < m_size.value()) {
      return m_argImpl.coeff(index);
    }
    return CoeffReturnType(0);
  }

  template <int LoadMode, typename PacketType>
  EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
    // AutoDiff scalar's derivative must be a vector, which is enforced by static assert.
    // Defer to linear access for simplicity.
    // LoadMode and PacketType cannot be deduced from the argument, so they are
    // forwarded explicitly to the linear overload.
    EIGEN_IF_CONSTEXPR(XprType::IsRowMajor) { return this->template packet<LoadMode, PacketType>(col); }
    return this->template packet<LoadMode, PacketType>(row);
  }

  // Linear packet access. Three cases, depending on where the packet starting
  // at `index` lies relative to the end of the wrapped vector:
  //   - entirely inside: forwarded to the wrapped evaluator;
  //   - straddling the end: assembled coefficient by coefficient, zero-filled;
  //   - entirely outside: a packet of zeros.
  template <int LoadMode, typename PacketType>
  EIGEN_STRONG_INLINE PacketType packet(Index index) const {
    constexpr int kPacketSize = unpacket_traits<PacketType>::size;
    if (index + kPacketSize <= m_size.value()) {
      return m_argImpl.template packet<LoadMode, PacketType>(index);
    } else if (index < m_size.value()) {
      // Partial packet.
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[kPacketSize];
      // In this branch 0 < partial < kPacketSize, so both loops stay in bounds.
      const Index partial = m_size.value() - index;
      for (Index i = 0; i < partial; ++i) {
        values[i] = m_argImpl.coeff(index + i);
      }
      for (Index i = partial; i < kPacketSize; ++i) {
        values[i] = CoeffReturnType(0);
      }
      return pload<PacketType>(values);
    }
    return pset1<PacketType>(CoeffReturnType(0));
  }

 protected:
  // Wrapped vector expression, its evaluator, and its (unpadded) size.
  const ArgTypeNested m_arg;
  evaluator<ArgTypeNestedCleaned> m_argImpl;
  const variable_if_dynamic<Index, ArgTypeNestedCleaned::SizeAtCompileTime> m_size;
};
} // namespace internal
} // namespace Eigen
#endif  // EIGEN_COHERENT_PAD_OP_H