blob: 22cb1c93d2f8dd77348272261281b359c783c5d3 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_UNROLLS_IMPL_H
#define EIGEN_UNROLLS_IMPL_H
template <bool isARowMajor = true>
static EIGEN_ALWAYS_INLINE
int64_t idA(int64_t i, int64_t j, int64_t LDA) {
EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
else return i + j * LDA;
}
/**
* This namespace contains various classes used to generate compile-time unrolls which are
* used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
* for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
*
* Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
*
* for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
* for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
* func(startI,startJ) startJ = (startC)%(endJ)
* func(...)
*
* The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function
* with a template parameter used as a counter.
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter <= 0)> <---- tail case.
* aux_func {}
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter > 0)> <---- actual for-loop
* aux_func {
* startC = endI*endJ - counter
* startI = (startC)/(endJ)
* startJ = (startC)%(endJ)
* func(startI, startJ)
* aux_func<endI, endJ, counter-1>()
* }
*
* Note: Additional wrapper functions are provided for aux_func which hides the counter template
* parameter since counter usually depends on endI, endJ, etc...
*
* Conventions:
* 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
*
* 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
* handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
*/
namespace unrolls {
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
EIGEN_IF_CONSTEXPR( N == 16) { return 0xFFFF >> (16 - m); }
else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); }
else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); }
return 0;
}
/***
* Unrolls for tranposed C stores
*/
template <typename Scalar>
class trans {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxillary Functions for:
* - storeC
***********************************
*/
/**
* aux_storeC
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
*
* (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
*
**/
template<int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)>
aux_storeC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN-counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
}
}
else { // This block is only needed for fp32 case
// Reinterpret as __m512 for _mm512_shuffle_f32x4
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
// Swap lower and upper half of avx register.
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
}
}
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}
template<int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)>
aux_storeC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(remM_);
}
template<int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE
void storeC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0){
aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}
/**
* Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to
* "unrollN"xL ymm registers to be stored col-major into C.
*
* For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
*
* row0: zmm0 zmm1 zmm2
* row1: zmm3 zmm4 zmm5
* .
* .
* row7: zmm21 zmm22 zmm23
*
* For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
*
* row0: zmm0 zmm1
* row1: zmm2 zmm3
* .
* .
* row7: zmm14 zmm15
*
*
* In general we will have {1,2,3} groups of avx registers each of size
* EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of
* avx registers are being transposed.
*/
template<int64_t unrollN, int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE
void transpose(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
constexpr int64_t zmmStride = unrollN/PacketSize;
PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> r;
r.packet[0] = zmm.packet[packetIndexOffset + zmmStride*0];
r.packet[1] = zmm.packet[packetIndexOffset + zmmStride*1];
r.packet[2] = zmm.packet[packetIndexOffset + zmmStride*2];
r.packet[3] = zmm.packet[packetIndexOffset + zmmStride*3];
r.packet[4] = zmm.packet[packetIndexOffset + zmmStride*4];
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
ptranspose(r);
zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];
zmm.packet[packetIndexOffset + zmmStride*3] = r.packet[3];
zmm.packet[packetIndexOffset + zmmStride*4] = r.packet[4];
zmm.packet[packetIndexOffset + zmmStride*5] = r.packet[5];
zmm.packet[packetIndexOffset + zmmStride*6] = r.packet[6];
zmm.packet[packetIndexOffset + zmmStride*7] = r.packet[7];
}
};
/**
* Unrolls for copyBToRowMajor
*
* Idea:
* 1) Load a block of right-hand sides to registers (using loadB).
* 2) Convert the block from column-major to row-major (transposeLxL)
* 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
*
* We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
* used as temps for transposing.
*
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
*/
template <typename Scalar>
class transB {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxillary Functions for:
* - loadB
* - storeB
* - loadBBlock
* - storeBBlock
***********************************
*/
/**
* aux_loadB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_loadB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN-counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(remM) {
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>(
(const Scalar*)&B_arr[startN*LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar*)&B_arr[startN*LDB]);
aux_loadB<endN, counter-1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
}
template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_loadB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/**
* aux_storeB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_storeB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
constexpr int64_t counterReverse = endN-counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR( remK || remM) {
pstoreu<Scalar>(
&B_arr[startN*LDB],
ymm.packet[packetIndexOffset + startN],
remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
}
else {
pstoreu<Scalar>(&B_arr[startN*LDB], ymm.packet[packetIndexOffset + startN]);
}
aux_storeB<endN, counter-1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}
template<int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_storeB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_loadBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template<int64_t endN, int64_t counter, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN-counter;
constexpr int64_t startN = counterReverse;
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW,startN, false>(&B_temp[startN], LDB_, ymm);
aux_loadBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(
B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template<int64_t endN, int64_t counter, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp);
EIGEN_UNUSED_VARIABLE(LDB_);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/**
* aux_storeBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template<int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN-counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(toTemp) {
transB::template storeB<EIGEN_AVX_MAX_NUM_ROW,startN, remK_ != 0, false>(
&B_temp[startN], LDB_, ymm, remK_);
}
else {
transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW,endN),startN, false, remM>(
&B_arr[0 + startN*LDB], LDB, ymm, remM_);
}
aux_storeBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(
B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template<int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp);
EIGEN_UNUSED_VARIABLE(LDB_);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
template<int64_t endN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE
void loadB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
}
template<int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE
void storeB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}
template<int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE
void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
EIGEN_IF_CONSTEXPR(toTemp) {
transB::template loadB<unrollN,0,remM>(&B_arr[0],LDB, ymm, remM_);
}
else {
aux_loadBBlock<unrollN, unrollN, toTemp, remM>(
B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
}
template<int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE
void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(
B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template<int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE
void transposeLxL(PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm){
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
PacketBlock<vecHalf,EIGEN_AVX_MAX_NUM_ROW> r;
r.packet[0] = ymm.packet[packetIndexOffset + 0];
r.packet[1] = ymm.packet[packetIndexOffset + 1];
r.packet[2] = ymm.packet[packetIndexOffset + 2];
r.packet[3] = ymm.packet[packetIndexOffset + 3];
r.packet[4] = ymm.packet[packetIndexOffset + 4];
r.packet[5] = ymm.packet[packetIndexOffset + 5];
r.packet[6] = ymm.packet[packetIndexOffset + 6];
r.packet[7] = ymm.packet[packetIndexOffset + 7];
ptranspose(r);
ymm.packet[packetIndexOffset + 0] = r.packet[0];
ymm.packet[packetIndexOffset + 1] = r.packet[1];
ymm.packet[packetIndexOffset + 2] = r.packet[2];
ymm.packet[packetIndexOffset + 3] = r.packet[3];
ymm.packet[packetIndexOffset + 4] = r.packet[4];
ymm.packet[packetIndexOffset + 5] = r.packet[5];
ymm.packet[packetIndexOffset + 6] = r.packet[6];
ymm.packet[packetIndexOffset + 7] = r.packet[7];
}
template<int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE
void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
constexpr int64_t U3 = PacketSize * 3;
constexpr int64_t U2 = PacketSize * 2;
constexpr int64_t U1 = PacketSize * 1;
/**
* Unrolls needed for each case:
* - AVX512 fp32 48 32 16 8 4 2 1
* - AVX512 fp64 24 16 8 4 2 1
*
* For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
*/
EIGEN_IF_CONSTEXPR(unrollN == U3) {
// load LxU3 B col major, transpose LxU3 row major
constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U3);
transB::template loadBBlock<maxUBlock,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock,toTemp, remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR( maxUBlock < U3) {
transB::template loadBBlock<maxUBlock,toTemp, remM>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock,toTemp, remM,0>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
}
}
else EIGEN_IF_CONSTEXPR(unrollN == U2) {
// load LxU2 B col major, transpose LxU2 row major
constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U2);
transB::template loadBBlock<maxUBlock,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock,toTemp,remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR( maxUBlock < U2) {
transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp, remM>(
&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp,remM,0>(
&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_);
}
}
else EIGEN_IF_CONSTEXPR(unrollN == U1) {
// load LxU1 B col major, transpose LxU1 row major
transB::template loadBBlock<U1,toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) {
transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
}
transB::template storeBBlock<U1,toTemp,remM,0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
// load Lx4 B col major, transpose Lx4 row major
transB::template loadBBlock<8,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<8,toTemp,remM,8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
// load Lx4 B col major, transpose Lx4 row major
transB::template loadBBlock<4,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<4,toTemp,remM,4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 2) {
// load Lx2 B col major, transpose Lx2 row major
transB::template loadBBlock<2,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<2,toTemp,remM,2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 1) {
// load Lx1 B col major, transpose Lx1 row major
transB::template loadBBlock<1,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<1,toTemp,remM,1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
}
};
/**
* Unrolls for triSolveKernel
*
* Idea:
* 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
* 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
* stored in AInPacket (using triSolveMicroKernel).
* 3) Store final results (in avx registers) back into memory (using storeRHS).
*
* RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
* EIGEN_AVX_MAX_NUM_ROW registers.
*/
template <typename Scalar>
class trsm {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value,
vecFullFloat,
vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxillary Functions for:
* - loadRHS
* - storeRHS
* - divRHSByDiag
* - updateRHS
* - triSolveMicroKernel
************************************/
/**
* aux_loadRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM*endK-counter;
constexpr int64_t startM = counterReverse/(endK);
constexpr int64_t startK = counterReverse%endK;
constexpr int64_t packetIndex = startM*endK + startK;
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB;
EIGEN_IF_CONSTEXPR(krem) {
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
}
else {
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
}
aux_loadRHS<isFWDSolve,endM, endK, counter-1, krem>(B_arr, LDB, RHSInPacket, rem);
}
template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(rem);
}
/**
* aux_storeRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM*endK-counter;
constexpr int64_t startM = counterReverse/(endK);
constexpr int64_t startK = counterReverse%endK;
constexpr int64_t packetIndex = startM*endK + startK;
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB;
EIGEN_IF_CONSTEXPR(krem) {
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
}
else {
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
}
aux_storeRHS<isFWDSolve,endM, endK, counter-1, krem>(B_arr, LDB, RHSInPacket, rem);
}
template<bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(rem);
}
/**
* aux_divRHSByDiag
*
* currM may be -1, (currM >=0) in enable_if checks for this
*
* 1-D unroll
* for(startK = 0; startK < endK; startK++)
**/
template<int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)>
aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endK-counter;
constexpr int64_t startK = counterReverse;
constexpr int64_t packetIndex = currM*endK + startK;
RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
aux_divRHSByDiag<currM, endK, counter-1>(RHSInPacket, AInPacket);
}
template<int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)>
aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/**
* aux_updateRHS
*
* 2-D unroll
* for(startM = initM; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = (endM-initM)*endK-counter;
constexpr int64_t startM = initM + counterReverse/(endK);
constexpr int64_t startK = counterReverse%endK;
// For each row of A, first update all corresponding RHS
constexpr int64_t packetIndex = startM*endK + startK;
EIGEN_IF_CONSTEXPR(currentM > 0) {
RHSInPacket.packet[packetIndex] =
pnmadd(AInPacket.packet[startM],
RHSInPacket.packet[(currentM-1)*endK+startK],
RHSInPacket.packet[packetIndex]);
}
EIGEN_IF_CONSTEXPR(startK == endK - 1) {
// Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
// If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM].
// This will be used in divRHSByDiag
EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[currentM] = pset1<vec>(Scalar(1)/A_arr[idA<isARowMajor>(currentM,currentM,LDA)]);
else
AInPacket.packet[currentM] = pset1<vec>(Scalar(1)/A_arr[idA<isARowMajor>(-currentM,-currentM,LDA)]);
}
else {
// Broadcast next off diagonal element of A
EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM,currentM,LDA)]);
else
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM,-currentM,LDA)]);
}
}
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(A_arr, LDA, RHSInPacket, AInPacket);
}
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/**
* aux_triSolverMicroKernel
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endM-counter;
constexpr int64_t startM = counterReverse;
constexpr int64_t currentM = startM;
// Divides the right-hand side in row startM, by digonal value of A
// broadcasted to AInPacket.packet[startM-1] in the previous iteration.
//
// Without "if constexpr" the compiler instantiates the case <-1, numK>
// this is handled with enable_if to prevent out-of-bound warnings
// from the compiler
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
trsm::template divRHSByDiag<startM-1, numK>(RHSInPacket, AInPacket);
// After division, the rhs corresponding to subsequent rows of A can be partially updated
// We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
// to be used in the next iteration.
trsm::template
updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(
A_arr, LDA, RHSInPacket, AInPacket);
// Handle division for the RHS corresponding to the final row of A.
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM-1)
trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket, AInPacket);
}
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket)
{
EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
template<bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE
void loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_loadRHS<isFWDSolve, endM, endK, endM*endK, krem>(B_arr, LDB, RHSInPacket, rem);
}
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
template<bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE
void storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_storeRHS<isFWDSolve, endM, endK, endM*endK, krem>(B_arr, LDB, RHSInPacket, rem);
}
/**
* Only used if Triangular matrix has non-unit diagonal values
*/
template<int64_t currM, int64_t endK>
static EIGEN_ALWAYS_INLINE
void divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
}
/**
* Update right-hand sides (stored in avx registers)
* Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
**/
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK, int64_t currentM>
static EIGEN_ALWAYS_INLINE
void updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM-startM)*endK, currentM>(
A_arr, LDA, RHSInPacket, AInPacket);
}
/**
* endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
* numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3.
* isFWDSolve: true => forward substitution, false => backwards substitution
* isUnitDiag: true => triangular matrix has unit diagonal.
*/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
static EIGEN_ALWAYS_INLINE
void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
static_assert( numK >= 1 && numK <= 3, "numK out of range" );
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(
A_arr, LDA, RHSInPacket, AInPacket);
}
};
/**
* Unrolls for gemm kernel
*
* isAdd: true => C += A*B, false => C -= A*B
*/
template <typename Scalar, bool isAdd>
class gemm {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxillary Functions for:
* - setzero
* - updateC
* - storeC
* - startLoadB
* - triSolveMicroKernel
************************************/
/**
* aux_setzero
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template<int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endM*endN-counter;
constexpr int64_t startM = counterReverse/(endN);
constexpr int64_t startN = counterReverse%endN;
zmm.packet[startN*endM + startM] = pzero(zmm.packet[startN*endM + startM]);
aux_setzero<endM, endN, counter-1>(zmm);
}
template<int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm)
{
EIGEN_UNUSED_VARIABLE(zmm);
}
/**
* aux_updateC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template<int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM*endN-counter;
constexpr int64_t startM = counterReverse/(endN);
constexpr int64_t startN = counterReverse%endN;
EIGEN_IF_CONSTEXPR(rem)
zmm.packet[startN*endM + startM] =
padd(ploadu<vec>(&C_arr[(startN) * LDC + startM*PacketSize], remMask<PacketSize>(rem_)),
zmm.packet[startN*endM + startM],
remMask<PacketSize>(rem_));
else
zmm.packet[startN*endM + startM] =
padd(ploadu<vec>(&C_arr[(startN) * LDC + startM*PacketSize]), zmm.packet[startN*endM + startM]);
aux_updateC<endM, endN, counter-1, rem>(C_arr, LDC, zmm, rem_);
}
template<int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_storeC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template<int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM*endN-counter;
constexpr int64_t startM = counterReverse/(endN);
constexpr int64_t startN = counterReverse%endN;
EIGEN_IF_CONSTEXPR(rem)
pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM], remMask<PacketSize>(rem_));
else
pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM]);
aux_storeC<endM, endN, counter-1, rem>(C_arr, LDC, zmm, rem_);
}
template<int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_startLoadB
*
* 1-D unroll
* for(startL = 0; startL < endL; startL++)
**/
template<int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endL-counter;
constexpr int64_t startL = counterReverse;
EIGEN_IF_CONSTEXPR(rem)
zmm.packet[unrollM*unrollN+startL] =
ploadu<vec>(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize], remMask<PacketSize>(rem_));
else
zmm.packet[unrollM*unrollN+startL] = ploadu<vec>(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize]);
aux_startLoadB<unrollM, unrollN, endL, counter-1, rem>(B_t, LDB, zmm, rem_);
}
template<int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_startLoadB(
Scalar *B_t, int64_t LDB,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_startBCastA
*
* 1-D unroll
* for(startB = 0; startB < endB; startB++)
**/
template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endB-counter;
constexpr int64_t startB = counterReverse;
zmm.packet[unrollM*unrollN+numLoad+startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0,LDA)]);
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter-1, numLoad>(A_t, LDA, zmm);
}
template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm)
{
EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(zmm);
}
/**
* aux_loadB
* currK: current K
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_loadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
if ((numLoad/endM + currK < unrollK)) {
constexpr int64_t counterReverse = endM-counter;
constexpr int64_t startM = counterReverse;
EIGEN_IF_CONSTEXPR(rem) {
zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] =
ploadu<vec>(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize], remMask<PacketSize>(rem_));
}
else {
zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] =
ploadu<vec>(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize]);
}
aux_loadB<endM, counter-1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
}
template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_loadB(
Scalar *B_t, int64_t LDB,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_microKernel
*
* 3-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
* for(startK = 0; startK < endK; startK++)
**/
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)>
aux_microKernel(
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM*endN*endK-counter;
constexpr int startK = counterReverse/(endM*endN);
constexpr int startN = (counterReverse/(endM))%endN;
constexpr int startM = counterReverse%endM;
EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
gemm:: template
startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
gemm:: template
startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
}
{
// Interleave FMA and Bcast
EIGEN_IF_CONSTEXPR(isAdd) {
zmm.packet[startN*endM + startM] =
pmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
}
else {
zmm.packet[startN*endM + startM] =
pnmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
}
// Bcast
EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK*endN < endK*endN)) {
zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast] =
pload1<vec>(&A_t[idA<isARowMajor>((numBCast + startN + startK*endN)%endN,
(numBCast + startN + startK*endN)/endN, LDA)]);
}
}
// We have updated all accumlators, time to load next set of B's
EIGEN_IF_CONSTEXPR( (startN == endN - 1) && (startM == endM - 1) ) {
gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
aux_microKernel<isARowMajor, endM, endN, endK, counter-1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
}
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
aux_microKernel(
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
template<int64_t endM, int64_t endN>
static EIGEN_ALWAYS_INLINE
void setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){
aux_setzero<endM, endN, endM*endN>(zmm);
}
/**
* Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
*/
template<int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE
void updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
EIGEN_UNUSED_VARIABLE(rem_);
aux_updateC<endM, endN, endM*endN, rem>(C_arr, LDC, zmm, rem_);
}
template<int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE
void storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
EIGEN_UNUSED_VARIABLE(rem_);
aux_storeC<endM, endN, endM*endN, rem>(C_arr, LDC, zmm, rem_);
}
/**
* Use numLoad registers for loading B at start of microKernel
*/
template<int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
static EIGEN_ALWAYS_INLINE
void startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
EIGEN_UNUSED_VARIABLE(rem_);
aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
}
/**
* Use numBCast registers for broadcasting A at start of microKernel
*/
template<bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
static EIGEN_ALWAYS_INLINE
void startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
}
/**
* Loads next set of B into vector registers between each K unroll.
*/
template<int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE
void loadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
EIGEN_UNUSED_VARIABLE(rem_);
aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
/**
* Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B.
* A matrix can be row/col-major. B matrix is assumed row-major.
*
* isARowMajor: is A row major
* endM: Number registers per row
* endN: Number of rows
* endK: Loop unroll for K.
* numLoad: Number of registers for loading B.
* numBCast: Number of registers for broadcasting A.
*
* Ex: microkernel<isARowMajor,0,3,0,4,0,4,6,2>: 8x48 unroll (24 accumulators), k unrolled 4 times,
* 6 register for loading B, 2 for broadcasting A.
*
* Note: Ideally the microkernel should not have any register spilling.
* The avx instruction counts should be:
* - endK*endN vbroadcasts{s,d}
* - endK*endM vmovup{s,d}
* - endK*endN*endM FMAs
*
* From testing, there are no register spills with clang. There are register spills with GNU, which
* causes a performance hit.
*/
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast, bool rem = false>
static EIGEN_ALWAYS_INLINE
void microKernel(
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){
EIGEN_UNUSED_VARIABLE(rem_);
aux_microKernel<isARowMajor,endM, endN, endK, endM*endN*endK, numLoad, numBCast, rem>(
B_t, A_t, LDB, LDA, zmm, rem_);
}
};
} // namespace unrolls
#endif //EIGEN_UNROLLS_IMPL_H