| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Mehdi Goli Codeplay Software Ltd. |
| // Ralph Potter Codeplay Software Ltd. |
| // Luke Iwanski Codeplay Software Ltd. |
| // Contact: <eigen@codeplay.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| /***************************************************************** |
| * TensorSyclextractFunctors.h |
| * |
| * \brief: |
| * Used to extract all the functors allocated to each node of the expression |
| *tree. |
| * |
| *****************************************************************/ |
| |
| #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP |
| #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP |
| |
| namespace Eigen { |
| namespace TensorSycl { |
| namespace internal { |
| /// struct FunctorExtractor: This struct is used to extract the functors |
| /// constructed on |
| /// the host-side, to pack them and reuse them in reconstruction of the |
| /// expression on the device. |
| /// We have to do that as in Eigen the functors are not stateless so we cannot |
| /// re-instantiate them on the device. |
| /// We have to pass instantiated functors to the device. |
| // This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval). |
| template <typename Evaluator> struct FunctorExtractor{ |
| typedef typename Evaluator::Dimensions Dimensions; |
| const Dimensions m_dimensions; |
| const Dimensions& dimensions() const { return m_dimensions; } |
| FunctorExtractor(const Evaluator& expr) |
| : m_dimensions(expr.dimensions()) {} |
| |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp |
| template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; |
| OP func; |
| FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr) |
| : rhsExpr(expr.impl()), func(expr.functor()) {} |
| }; |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp |
| template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> > |
| : FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{}; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorCwiseBinaryOp |
| template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; |
| FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; |
| OP func; |
| FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr) |
| : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {} |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorCwiseBinaryOp |
| template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > |
| : FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{}; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorCwiseTernaryOp |
| template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev> |
| struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr; |
| FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr; |
| FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr; |
| OP func; |
| FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr) |
| : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {} |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// TensorCwiseTernaryOp |
| template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > |
| :FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{}; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated. |
| template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> |
| struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr; |
| FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr; |
| FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr; |
| FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr) |
| : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {} |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated |
| template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > |
| :FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {}; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorAssignOp. This is an specialisation without OP so it has to be separated. |
| template <typename LHSExpr, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; |
| FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; |
| FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr) |
| : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// TensorAssignOp. This is an specialisation without OP so it has to be separated. |
| template <typename LHSExpr, typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> > |
| :FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{}; |
| |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// const TensorEvalToOp, This is an specialisation without OP so it has to be separated. |
| template <typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > { |
| FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; |
| FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr) |
| : rhsExpr(expr.impl()) {} |
| }; |
| |
| /// specialisation of the \ref FunctorExtractor struct when the node type is |
| /// TensorEvalToOp. This is a specialisation without OP so it has to be separated. |
| template <typename RHSExpr, typename Dev> |
| struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> > |
| : FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {}; |
| |
| template<typename Dim, size_t NumOutputDim> struct DimConstr { |
| template<typename InDim> |
| static inline Dim getDim(InDim dims ) {return dims;} |
| }; |
| |
| template<typename Dim> struct DimConstr<Dim, 0> { |
| template<typename InDim> |
| static inline Dim getDim(InDim dims ) {return Dim(dims.TotalSize());} |
| }; |
| |
| template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device> |
| struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{ |
| typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator; |
| typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions; |
| const Dimensions m_dimensions; |
| const Dimensions& dimensions() const { return m_dimensions; } |
| FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr) |
| : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {} |
| }; |
| |
| |
| template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device> |
| struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>> |
| : FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{}; |
| /// template deduction function for FunctorExtractor |
| template <typename Evaluator> |
| auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> { |
| return FunctorExtractor<Evaluator>(evaluator); |
| } |
| } // namespace internal |
| } // namespace TensorSycl |
| } // namespace Eigen |
| |
| #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP |