// blob: 35f931c5326f7054b20ba416748f518178f7610d [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
/*****************************************************************
 * TensorContractionSycl.h
 *
 * \brief:
 *  Tensor contraction evaluator and kernel launcher for the SYCL device.
 *
*****************************************************************/
#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
namespace Eigen {
// Forward declaration of the kernel launcher (defined later in this file).
// It is declared here because the SyclDevice TensorEvaluator specialization
// below calls LaunchSyclKernels<...>::Run from evalTyped().
template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
/**
 * \brief Evaluator of a TensorContractionOp on an Eigen::SyclDevice.
 *
 * The evaluator derives from TensorContractionEvaluatorBase and only adds the
 * SYCL specific parts: picking the compile-time kernel variant from the
 * run-time contiguity / reordering flags held by the base class, and handing
 * the work over to LaunchSyclKernels.
 *
 * Output kernels are not supported; this is enforced with a static_assert.
 */
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > {

  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
                "SYCL tensor contraction does not support output kernels.");

  using Device = const Eigen::SyclDevice;
  using Self = TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>;
  using Base = TensorContractionEvaluatorBase<Self>;
  using XprType = TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>;
  using Scalar = typename internal::remove_const<typename XprType::Scalar>::type;
  using Index = typename XprType::Index;
  using CoeffReturnType = typename XprType::CoeffReturnType;
  using PacketReturnType = typename PacketType<CoeffReturnType, Device>::type;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // The contraction code is written for ColMajor inputs. For RowMajor inputs
  // the operands are exchanged instead: to compute A * B = C (A is LHS, B is
  // RHS) the code treats B as the LHS and A as the RHS.
  using EvalLeftArgType = typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type;
  using EvalRightArgType = typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  using left_dim_mapper_t = array<Index, LDims>;
  using right_dim_mapper_t = array<Index, RDims>;
  using contract_t = array<Index, ContractDims>;
  using left_nocontract_t = array<Index, LDims - ContractDims>;
  using right_nocontract_t = array<Index, RDims - ContractDims>;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
  using Dimensions = DSizes<Index, NumDims>;

  // Aliases used by evalTo / evalTyped.
  using LhsScalar = typename internal::remove_const<typename EvalLeftArgType::Scalar>::type;
  using RhsScalar = typename internal::remove_const<typename EvalRightArgType::Scalar>::type;
  using LeftEvaluator = TensorEvaluator<EvalLeftArgType, Device>;
  using RightEvaluator = TensorEvaluator<EvalRightArgType, Device>;
  using LeftDimensions = typename LeftEvaluator::Dimensions;
  using RightDimensions = typename RightEvaluator::Dimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  // We need to redefine this method to make nvcc happy
  //
  // Evaluates both operand sub-expressions, then the contraction itself.
  // Returns false when the result was written straight into `data`; returns
  // true when `data` is null and the result was written into a buffer
  // allocated here and stored in m_result.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    }
    this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
    evalTo(this->m_result);
    return true;
  }

  // Accessor for the SYCL device this evaluator was built with.
  const Eigen::SyclDevice& device() const {return this->m_device;}

  // Turns the three run-time flags of the base class into the matching
  // compile-time instantiation of evalTyped. The flags are packed into a
  // 3-bit selector: bit 2 = lhs inner dim contiguous, bit 1 = rhs inner dim
  // contiguous, bit 0 = rhs inner dim reordered.
  void evalTo(Scalar* buffer) const {
    const int selector = (this->m_lhs_inner_dim_contiguous ? 4 : 0) |
                         (this->m_rhs_inner_dim_contiguous ? 2 : 0) |
                         (this->m_rhs_inner_dim_reordered ? 1 : 0);
    switch (selector) {
      case 7:  evalTyped<true, true, true, Unaligned>(buffer);    break;
      case 6:  evalTyped<true, true, false, Unaligned>(buffer);   break;
      case 5:  evalTyped<true, false, true, Unaligned>(buffer);   break;
      case 4:  evalTyped<true, false, false, Unaligned>(buffer);  break;
      case 3:  evalTyped<false, true, true, Unaligned>(buffer);   break;
      case 2:  evalTyped<false, true, false, Unaligned>(buffer);  break;
      case 1:  evalTyped<false, false, true, Unaligned>(buffer);  break;
      default: evalTyped<false, false, false, Unaligned>(buffer); break;
    }
  }

  // Zeroes the m x n result buffer and launches the SYCL contraction kernel
  // for the selected variant.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // Zero out the result buffer (which must hold at least m * n * sizeof(Scalar) bytes).
    // NOTE(review): the kernel in this file assigns (not accumulates) every
    // in-range output element, so this memset may be redundant - confirm
    // before removing.
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef LaunchSyclKernels<Index, LhsScalar, RhsScalar, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered> Launcher;
    Launcher::Run(*this, buffer, m, n, k,
                  this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
                  this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
  }

  // Required by sycl to construct the expr on the device. Returns the
  // evaluator of the original (unswapped) left argument.
  const TensorEvaluator<LeftArgType, Device>& left_impl() const {
    typedef Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)> IsColMajor;
    return choose(IsColMajor(), this->m_leftImpl, this->m_rightImpl);
  }

  // Required by sycl to construct the expr on the device. Returns the
  // evaluator of the original (unswapped) right argument.
  const TensorEvaluator<RightArgType, Device>& right_impl() const {
    typedef Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)> IsColMajor;
    return choose(IsColMajor(), this->m_rightImpl, this->m_leftImpl);
  }
};
/**
 * \brief Kernel functor computing a tiled M x K by K x N product on the device.
 *
 * The functor is handed to parallel_for with a 2-D nd_range. Each work-group
 * produces one TileSizeDimM x TileSizeDimN tile of the output; each work-item
 * accumulates WorkLoadPerThreadM x WorkLoadPerThreadN output elements in
 * private arrays. The K dimension is walked in steps of TileSizeDimK.
 *
 * localLhs / localRhs each hold two tiles interleaved element by element
 * (index = parity + 2 * element), so that the tile for step t+1 can be written
 * while the tile for step t is being consumed.
 *
 * Compile-time parameters:
 *  - TileSizeDimM / TileSizeDimN / TileSizeDimK: tile extents.
 *  - WorkLoadPerThreadM / WorkLoadPerThreadN: output elements per work-item.
 *  - LocalThreadSizeM / LocalThreadSizeN: work-group extents used to linearise
 *    the local id.
 *  - LoadPerThreadLhs / LoadPerThreadRhs: tile elements loaded per work-item
 *    and per tile for each operand.
 *
 * Run-time members:
 *  - M, N, K: problem sizes; roundUpK is used to derive the number of K tiles.
 *  - out_res / out_offset: output accessor and the offset applied to it.
 *  - the stride arrays and accessor tuples are forwarded to the input mappers
 *    and device expressions built inside operator().
 */
template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename LHSFunctorExpr, typename RHSFunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
typename HostExpr::Index TileSizeDimM, typename HostExpr::Index TileSizeDimN,typename HostExpr::Index TileSizeDimK, typename HostExpr::Index WorkLoadPerThreadM,typename HostExpr::Index WorkLoadPerThreadN,
typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadSizeN, typename HostExpr::Index LoadPerThreadLhs, typename HostExpr::Index LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{
  typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
  typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<RHSHostExpr>::Type RHSPlaceHolderExpr;
  LHSFunctorExpr lhs_functors;
  RHSFunctorExpr rhs_functors;
  LhsLocalAcc localLhs;
  RhsLocalAcc localRhs;
  OutAccessor out_res;
  size_t out_offset;
  Index roundUpK, M, N, K;
  ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides;
  LeftNocontractT m_i_strides, m_left_nocontract_strides;
  RightNocontractT m_j_strides, m_right_nocontract_strides;
  LHSTupleType left_tuple_of_accessors;
  RHSTupleType right_tuple_of_accessors;
  Device dev;
  // Stores every argument by value; no other work is done at construction.
  KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_,
  Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_,
  ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_,
  LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_)
  :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_),
  out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
  m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_),
  m_right_contracting_strides(m_right_contracting_strides_),
  m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_),
  m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_),
  left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_){}

  // Kernel body, executed once per work-item.
  void operator()(cl::sycl::nd_item<2> itemID) {
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<LHSHostExpr>::Type LHSDevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<RHSHostExpr>::Type RHSDevExpr;
    // Rebuild the operand expressions on the device from the functors and accessors.
    auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<LHSDevExpr, LHSPlaceHolderExpr>(lhs_functors, left_tuple_of_accessors);
    auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<RHSDevExpr, RHSPlaceHolderExpr>(rhs_functors, right_tuple_of_accessors);
    typedef decltype(lhs_dev_expr.expr) LeftArgType;
    typedef decltype(rhs_dev_expr.expr) RightArgType;
    // Same operand swap as on the host side: for a non-ColMajor layout the
    // right argument is evaluated as the left operand and vice versa.
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
    LeftEvaluator, LeftNocontractT,
    ContractT, 1,
    lhs_inner_dim_contiguous,
    false, Unaligned, MakeGlobalPointer> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
    RightEvaluator, RightNocontractT,
    ContractT, 1,
    rhs_inner_dim_contiguous,
    rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper;
    // initialize data mappers must happen inside the kernel for device eval
    LhsMapper lhs(LeftEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
    lhs_dev_expr.expr, rhs_dev_expr.expr), dev), m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(RightEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
    rhs_dev_expr.expr, lhs_dev_expr.expr),dev), m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
    auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res);
    // Matmul Kernel
    // Thread identifiers
    const Index mLocalThreadId = itemID.get_local(0); // Local ID row
    const Index nLocalThreadId = itemID.get_local(1); // Local ID col
    const Index mGroupId = itemID.get_group(0); // Work-group ID row
    const Index nGroupId = itemID.get_group(1); // Work-group ID col
    const Index linearLocalThreadId = nLocalThreadId*LocalThreadSizeM + mLocalThreadId; // linear local thread ID
    // Allocate register space
    LhsScalar privateLhs;
    RhsScalar privateRhs[WorkLoadPerThreadN];
    OutScalar privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN];
    // Initialise the accumulation registers
    for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
      for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
        privateRes[wLPTM][wLPTN] = static_cast<OutScalar>(0);
      }
    }
    // Load tile 0 of Lhs into the even slots of localLhs. Elements outside
    // the M x K range are padded with zero. The zero is cast to LhsScalar so
    // that both arms of the conditional have the element type of the tile.
    for (Index lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
      Index localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
      Index localLhsRow = localLhsLinearId% TileSizeDimM;
      Index localLhsCol = localLhsLinearId/TileSizeDimM;
      // global K id of the element (tile 0)
      Index GlobalLhsColId = TileSizeDimK*0 + localLhsCol;
      localLhs[0 + ((localLhsCol*TileSizeDimM + localLhsRow)*2)] =((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId):static_cast<LhsScalar>(0);
    }
    // Load tile 0 of Rhs into the even slots of localRhs, zero-padding
    // elements outside the K x N range (zero cast to RhsScalar, see above).
    for (Index lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
      Index localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
      Index localRhsRow = localRhsLinearId% TileSizeDimN;
      Index localRhsCol = localRhsLinearId/TileSizeDimN;
      // global K id of the element (tile 0)
      Index GlobalRhsRowId = TileSizeDimK*0 + localRhsCol;
      localRhs[0 + ((localRhsCol*TileSizeDimN + localRhsRow) *2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow): static_cast<RhsScalar>(0);
    }
    // Loop over all tiles
    const Index numTiles = roundUpK/TileSizeDimK;
    Index firstHalf=0;
    do {
      // Synchronise: the tile about to be consumed has been fully written,
      // and the slots about to be overwritten are no longer being read.
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      // Load the next tile of Lhs and Rhs into the other half of local memory
      Index nextHalf = firstHalf + 1;
      if (nextHalf < numTiles) {
        // Tile A
        for (Index lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
          Index localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
          Index localLhsRow = localLhsLinearId% TileSizeDimM;
          Index localLhsCol = localLhsLinearId/TileSizeDimM;
          // global K id
          Index GlobalLhsColId = TileSizeDimK*nextHalf + localLhsCol;
          // Store the loaded value into local memory
          localLhs[(nextHalf%2) + ((localLhsCol*TileSizeDimM + localLhsRow) *2)] = ((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId): static_cast<LhsScalar>(0);
        }
        // Tile B
        for (Index lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
          Index localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
          Index localRhsRow = localRhsLinearId% TileSizeDimN;
          Index localRhsCol = localRhsLinearId/TileSizeDimN;
          // global K id
          Index GlobalRhsRowId = TileSizeDimK*nextHalf + localRhsCol;
          // Store the loaded value into local memory
          localRhs[(nextHalf%2) +((localRhsCol*TileSizeDimN + localRhsRow)*2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow):static_cast<RhsScalar>(0);
        }
      }
      // Loop over the values of a single tile
      for (Index k=0; k<TileSizeDimK; k++) {
        // Cache the values of localRhs in registers
        for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
          Index localRhsCol = nLocalThreadId + wLPTN*LocalThreadSizeN;
          privateRhs[wLPTN] = localRhs[(firstHalf%2) +((k*TileSizeDimN + localRhsCol)*2)];
        }
        // Perform the computation
        for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
          Index localLhsRow = mLocalThreadId + wLPTM*LocalThreadSizeM;
          privateLhs = localLhs[(firstHalf%2)+ ((k*TileSizeDimM + localLhsRow)*2)];
          for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
            privateRes[wLPTM][wLPTN] += privateLhs * privateRhs[wLPTN];
          }
        }
      }
      // Next tile
      firstHalf++;
    } while (firstHalf<numTiles);
    // Store the final results in C: element (globalRow, globalCol) goes to
    // index globalCol*M + globalRow, shifted by the output offset.
    for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
      Index globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM;
      if (globalRow< M){
        for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
          Index globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN;
          if(globalCol<N)
            out_ptr[globalCol*M + globalRow +ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN];
        }
      }
    }
  }
};
template <typename Index, typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels {
static const Index TileSizeDimM = 32ul; // Tile size for dimension M
static const Index TileSizeDimN = 32ul; // Tile size for dimension N
static const Index TileSizeDimK = 16ul; // Tile size for dimension K
static const Index WorkLoadPerThreadM = 4ul; // Work load per thread in dimension M
static const Index WorkLoadPerThreadN = 4ul; // work load per thread in dimension N
static const Index LocalThreadSizeM = (TileSizeDimM/WorkLoadPerThreadM); // Local thread size for the first dimension (M here)
static const Index LocalThreadSizeN = (TileSizeDimN/WorkLoadPerThreadN); // Local thread size for the second dimension (N here)
static const Index LoadPerThreadLhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimN)); // workload per thread for Lhs expression
static const Index LoadPerThreadRhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimM)); // workload per thread for Rhs expression
// RoundUp function to make sure that the global threadId is divisable by local threadId
static Index RoundUp(Index x, Index y) {
return ((((x) + (y) - 1) / (y))*(y));
}
template< typename Self, typename OutScalar, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K,
ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides){
typedef typename Self::XprType HostExpr;
typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
typedef TensorEvaluator<LHSHostExpr, const Eigen::SyclDevice> OrigLHSExpr;
typedef TensorEvaluator<RHSHostExpr, const Eigen::SyclDevice> OrigRHSExpr;
typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigLHSExpr> LHSFunctorExpr;
typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigRHSExpr> RHSFunctorExpr;
// extract lhs functor list
LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
// extract rhs functor list
RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl());
Index roundUpK = RoundUp(K, TileSizeDimK);
Index roundUpM = RoundUp(M, TileSizeDimM);
Index roundUpN = RoundUp(N, TileSizeDimN);
ptrdiff_t out_offset = self.device().get_offset(buffer);
self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) {
/// work-around for gcc bug
typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType;
/// work-around for gcc bug
typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl())) RHSTupleType;
// create lhs tuple of accessors
LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl());
// create rhs tuple of accessors
RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl());
// Local memory for elements of Lhs
typedef cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LhsLocalAcc;
LhsLocalAcc localLhs(cl::sycl::range<1>(2* TileSizeDimM * TileSizeDimK), cgh);
// Local memory for elements of Rhs
typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc;
RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh);
typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor;
//OutScalar memory
OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer);
// sycl parallel for
cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN),
cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)),
KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK,
WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::SyclKernelDevice>(lhs_functors, rhs_functors,
localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides,
m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::SyclKernelDevice()));
});
self.device().asynchronousExec();
}
};
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H