blob: 98223fe79ddc7c548f5415cae2b3cd5a617e4e98 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
/** \internal
 * Type-erased interface to a lazily evaluated tensor expression.
 *
 * Instances carry an intrusive reference count (m_refcount). This class only
 * stores and adjusts the count; whoever decrements it to zero is responsible
 * for deleting the object. The count is a plain int with no synchronization,
 * so sharing one instance across threads requires external locking.
 */
template <typename Dimensions, typename Scalar>
class TensorLazyBaseEvaluator {
 public:
  TensorLazyBaseEvaluator() : m_refcount(0) {}
  virtual ~TensorLazyBaseEvaluator() {}

  // Non-copyable and non-assignable: instances are shared by pointer and
  // lifetime is governed by the reference count, so a copy would duplicate
  // ownership bookkeeping.
  TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other) = delete;
  TensorLazyBaseEvaluator& operator=(const TensorLazyBaseEvaluator& other) = delete;

  /** Extents of the evaluated expression. */
  EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
  /** Raw data pointer of the evaluated expression (may be null for non-materialized expressions). */
  EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
  /** Value of the coefficient at linear position \p index. */
  EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
  /** Mutable reference to the coefficient at linear position \p index. */
  EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;

  void incrRefCount() { ++m_refcount; }
  void decrRefCount() { --m_refcount; }
  int refCount() const { return m_refcount; }

 private:
  int m_refcount;
};
/** \internal
 * Read-only lazy evaluator: owns a TensorEvaluator for \p Expr and exposes it
 * through the type-erased TensorLazyBaseEvaluator interface.
 *
 * The constructor copies the expression's extents into m_dims and calls
 * evalSubExprsIfNeeded(); the destructor calls cleanup() on the wrapped evaluator.
 */
template <typename Dimensions, typename Expr, typename Device>
class TensorLazyEvaluatorReadOnly
    : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
 public:
  // Note: Dimensions is this template's parameter (the dimension type of the
  // TensorRef), not EvalType::Dimensions; extents are copied across element-wise.
  typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
  typedef StorageMemory<Scalar, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  typedef TensorEvaluator<Expr, Device> EvalType;

  TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
    EIGEN_STATIC_ASSERT(
        internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
        "Dimension sizes must match.");
    const auto& other_dims = m_impl.dimensions();
    for (std::size_t i = 0; i < m_dims.size(); ++i) {
      m_dims[i] = other_dims[i];
    }
    // No destination buffer is supplied.
    m_impl.evalSubExprsIfNeeded(nullptr);
  }

  ~TensorLazyEvaluatorReadOnly() override { m_impl.cleanup(); }

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const override { return m_dims; }
  EIGEN_DEVICE_FUNC const Scalar* data() const override { return m_impl.data(); }
  EIGEN_DEVICE_FUNC const Scalar coeff(DenseIndex index) const override { return m_impl.coeff(index); }

  /** Not supported for rvalue expressions: asserts, then returns a reference to a dummy scalar. */
  EIGEN_DEVICE_FUNC Scalar& coeffRef(DenseIndex /*index*/) override {
    eigen_assert(false && "can't reference the coefficient of a rvalue");
    return m_dummy;
  }

 protected:
  TensorEvaluator<Expr, Device> m_impl;
  Dimensions m_dims;
  // Target of coeffRef() on read-only expressions, so a reference can still be returned.
  Scalar m_dummy;
};
/** \internal
 * Writable lazy evaluator: identical to the read-only variant except that
 * coeffRef() forwards to the wrapped evaluator instead of asserting.
 */
template <typename Dimensions, typename Expr, typename Device>
class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
 public:
  using Base = TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device>;
  using Scalar = typename Base::Scalar;
  using Storage = StorageMemory<Scalar, Device>;
  using EvaluatorPointerType = typename Storage::Type;

  TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {}
  ~TensorLazyEvaluatorWritable() override {}

  /** Mutable access to the coefficient at linear position \p index of the wrapped expression. */
  EIGEN_DEVICE_FUNC Scalar& coeffRef(DenseIndex index) override {
    Scalar& target = Base::m_impl.coeffRef(index);
    return target;
  }
};
/** \internal
 * Selects the writable or read-only lazy evaluator at compile time.
 * When \p IsWritable is false the expression is wrapped as `const Expr`.
 */
template <typename Dimensions, typename Expr, typename Device, bool IsWritable>
class TensorLazyEvaluator
    : public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
                                TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
 public:
  using Base = std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
                                  TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>;
  using Scalar = typename Base::Scalar;

  TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {}
  ~TensorLazyEvaluator() override {}
};
/** \internal
 * Shared implementation of TensorRef<T> and TensorRef<const T>.
 *
 * Holds a pointer to a type-erased, intrusively reference-counted lazy
 * evaluator. Copies share the evaluator; whichever owner drops the count to
 * zero deletes it. A default-constructed reference holds no evaluator
 * (m_evaluator == nullptr); copying or assigning such an empty reference is
 * well-defined and yields another empty reference.
 */
template <typename Derived>
class TensorRefBase : public TensorBase<Derived> {
 public:
  typedef typename traits<Derived>::PlainObjectType PlainObjectType;
  typedef typename PlainObjectType::Base Base;
  typedef typename Eigen::internal::nested<Derived>::type Nested;
  typedef typename traits<PlainObjectType>::StorageKind StorageKind;
  typedef typename traits<PlainObjectType>::Index Index;
  typedef typename traits<PlainObjectType>::Scalar Scalar;
  typedef typename NumTraits<Scalar>::Real RealScalar;
  typedef typename Base::CoeffReturnType CoeffReturnType;
  typedef Scalar* PointerType;
  typedef PointerType PointerArgType;
  static constexpr Index NumIndices = PlainObjectType::NumIndices;
  typedef typename PlainObjectType::Dimensions Dimensions;
  static constexpr int Layout = PlainObjectType::Layout;
  enum {
    IsAligned = false,
    PacketAccess = false,
    BlockAccess = false,
    PreferBlockAccess = false,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
  typedef TensorBlockNotImplemented TensorBlock;
  //===------------------------------------------------------------------===//

  /** Creates an empty reference that is not bound to any expression. */
  EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(nullptr) {}

  /** Shares \p other's evaluator. Copying an empty reference yields an empty reference. */
  TensorRefBase(const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
    if (m_evaluator) {
      eigen_assert(m_evaluator->refCount() > 0);
      m_evaluator->incrRefCount();
    }
  }

  /** Rebinds to \p other's evaluator (possibly none), releasing the current one. */
  TensorRefBase& operator=(const TensorRefBase& other) {
    // Acquire the incoming evaluator before releasing ours. Self-assignment and
    // assignment between references that share one evaluator then need no special
    // case: the count never transiently reaches zero.
    TensorLazyBaseEvaluator<Dimensions, Scalar>* incoming = other.m_evaluator;
    if (incoming) {
      eigen_assert(incoming->refCount() > 0);
      incoming->incrRefCount();
    }
    unrefEvaluator();
    m_evaluator = incoming;
    return *this;
  }

  /** Binds to a new lazy evaluator for \p expr (writable only for mutable lvalue expressions). */
  template <typename Expression,
            typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
  EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr)
      : m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
                                            /*IsWritable=*/!std::is_const<PlainObjectType>::value &&
                                                bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
    m_evaluator->incrRefCount();
  }

  /** Rebinds to a new lazy evaluator for \p expr, releasing the current one. */
  template <typename Expression,
            typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
  EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) {
    // Build the new evaluator BEFORE releasing the current one. expr may refer to
    // *this (e.g. `ref = ref + ref`); the evaluator for a TensorRef operand stores a
    // copy of that reference, which needs the current evaluator to still be alive.
    // This ordering also leaves *this unchanged if allocation or construction throws.
    TensorLazyBaseEvaluator<Dimensions, Scalar>* fresh =
        new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
                                /*IsWritable=*/!std::is_const<PlainObjectType>::value &&
                                    bool(is_lvalue<Expression>::value)>(expr, DefaultDevice());
    fresh->incrRefCount();
    unrefEvaluator();
    m_evaluator = fresh;
    return *this;
  }

  ~TensorRefBase() { unrefEvaluator(); }

  // The accessors below require a bound (non-empty) reference.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(Index index) const { return m_evaluator->coeff(index); }

  template <typename... IndexTypes>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const {
    const std::size_t num_indices = (sizeof...(otherIndices) + 1);
    const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
    return coeff(indices);
  }

  /** Reads the coefficient at multi-index \p indices, linearized per the storage order. */
  template <std::size_t NumIndices>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const {
    const Dimensions& dims = this->dimensions();
    Index index = 0;
    if (PlainObjectType::Options & RowMajor) {
      index += indices[0];
      for (std::size_t i = 1; i < NumIndices; ++i) {
        index = index * dims[i] + indices[i];
      }
    } else {
      index += indices[NumIndices - 1];
      // Unsigned countdown: avoids converting `NumIndices - 2` (a size_t) to int.
      for (std::size_t i = NumIndices - 1; i > 0; --i) {
        index = index * dims[i - 1] + indices[i - 1];
      }
    }
    return m_evaluator->coeff(index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }

 protected:
  TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() { return m_evaluator; }

 private:
  /** Drops this object's reference (deleting the evaluator on the last one) and leaves *this empty. */
  EIGEN_STRONG_INLINE void unrefEvaluator() {
    if (m_evaluator) {
      m_evaluator->decrRefCount();
      if (m_evaluator->refCount() == 0) {
        delete m_evaluator;
      }
      // Never keep a pointer we no longer own.
      m_evaluator = nullptr;
    }
  }

  TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
};
} // namespace internal
/**
* \ingroup CXX11_Tensor_Module
*
 * \brief A reference to a tensor expression.
 * The expression will be evaluated lazily (as much as possible).
*
*/
template <typename PlainObjectType>
class TensorRef : public internal::TensorRefBase<TensorRef<PlainObjectType>> {
  using Base = internal::TensorRefBase<TensorRef<PlainObjectType>>;

 public:
  using Scalar = typename Base::Scalar;
  using Dimensions = typename Base::Dimensions;

  EIGEN_STRONG_INLINE TensorRef() : Base() {}
  EIGEN_STRONG_INLINE TensorRef(const TensorRef& other) : Base(other) {}

  /** Binds a mutable reference; \p expr must be an lvalue expression. */
  template <typename Expression>
  EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {
    EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
                        "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
                        "TensorRef<const Expression>?)");
  }

  TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }

  /** Rebinds this mutable reference; \p expr must be an lvalue expression. */
  template <typename Expression>
  EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
    EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
                        "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
                        "TensorRef<const Expression>?)");
    return Base::operator=(expr).derived();
  }

  /** Mutable access by coordinates, e.g. `ref.coeffRef(i, j, k)`. */
  template <typename... IndexTypes>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
    const array<Index, sizeof...(otherIndices) + 1> coords{{firstIndex, otherIndices...}};
    return coeffRef(coords);
  }

  /** Mutable access by a multi-index, linearized according to the storage order. */
  template <std::size_t NumIndices>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) {
    const Dimensions& extents = this->dimensions();
    Index linear;
    if (PlainObjectType::Options & RowMajor) {
      // Row-major: the last coordinate varies fastest.
      linear = indices[0];
      for (std::size_t d = 1; d < NumIndices; ++d) {
        linear = linear * extents[d] + indices[d];
      }
    } else {
      // Column-major: the first coordinate varies fastest; fold from the last one down.
      linear = indices[NumIndices - 1];
      for (std::size_t d = NumIndices - 1; d > 0; --d) {
        linear = linear * extents[d - 1] + indices[d - 1];
      }
    }
    return Base::evaluator()->coeffRef(linear);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); }
};
/**
* \ingroup CXX11_Tensor_Module
*
 * \brief A reference to a constant tensor expression.
 * The expression will be evaluated lazily (as much as possible).
*
*/
template <typename PlainObjectType>
class TensorRef<const PlainObjectType> : public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
  using Base = internal::TensorRefBase<TensorRef<const PlainObjectType>>;

 public:
  EIGEN_STRONG_INLINE TensorRef() : Base() {}
  EIGEN_STRONG_INLINE TensorRef(const TensorRef& other) : Base(other) {}

  /** Binds a read-only reference to any tensor expression (lvalue or rvalue). */
  template <typename Expression>
  EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {}

  TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }

  /** Rebinds this read-only reference to \p expr. */
  template <typename Expression>
  EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
    return Base::operator=(expr).derived();
  }
};
// evaluator for rvalues
/** Evaluator for a TensorRef used as an rvalue; it stores its own copy of the reference. */
template <typename Derived, typename Device>
struct TensorEvaluator<const TensorRef<Derived>, Device> {
  using Index = typename Derived::Index;
  using Scalar = typename Derived::Scalar;
  using CoeffReturnType = typename Derived::Scalar;
  using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;
  using Dimensions = typename Derived::Dimensions;
  using Storage = StorageMemory<CoeffReturnType, Device>;
  using EvaluatorPointerType = typename Storage::Type;

  static constexpr int Layout = TensorRef<Derived>::Layout;
  enum {
    IsAligned = false,
    PacketAccess = false,
    BlockAccess = false,
    PreferBlockAccess = false,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  using TensorBlock = internal::TensorBlockNotImplemented;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) : m_ref(m) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }

  // Nothing to materialize: the referenced expression is evaluated lazily on access.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) { return true; }
  EIGEN_STRONG_INLINE void cleanup() {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_ref.coeff(index); }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_ref.coeffRef(index); }
  EIGEN_DEVICE_FUNC const Scalar* data() const { return m_ref.data(); }

 protected:
  TensorRef<Derived> m_ref;
};
// evaluator for lvalues
/** Evaluator for a TensorRef used as an lvalue; adds mutable coefficient access. */
template <typename Derived, typename Device>
struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device> {
  using Index = typename Derived::Index;
  using Scalar = typename Derived::Scalar;
  using CoeffReturnType = typename Derived::Scalar;
  using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;
  using Dimensions = typename Derived::Dimensions;
  using Base = TensorEvaluator<const TensorRef<Derived>, Device>;

  enum { IsAligned = false, PacketAccess = false, BlockAccess = false, PreferBlockAccess = false, RawAccess = false };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  using TensorBlock = internal::TensorBlockNotImplemented;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::m_ref.coeffRef(index); }
};
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H