Gemm

Function Usage

This API has been deprecated and will be removed in later versions. Do not use this API.

Multiplies two matrices and outputs the result: matrix A is multiplied by matrix B to obtain matrix C, which is written to the destination operand.

Prototype

  • Functional APIs
    template <typename dst_T, typename src0_T, typename src1_T>
    __aicore__ inline void Gemm(const LocalTensor<dst_T>& dstLocal, const LocalTensor<src0_T>& src0Local, const LocalTensor<src1_T>& src1Local, const uint32_t m, const uint32_t k, const uint32_t n, GemmTiling tilling, bool partialsum = true, int32_t initValue = 0)
    
  • Tiling compute API
    template <typename T>
    __aicore__ inline GemmTiling GetGemmTiling(uint32_t m, uint32_t k, uint32_t n)
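
The two APIs are typically used together: GetGemmTiling derives a GemmTiling from the matrix dimensions, and the result is passed to Gemm. The following minimal sketch assumes half inputs, a float output, and LocalTensors named featureMapA1, weightB1, and dstCO1 that are already staged at QuePosition A1, B1, and CO1 (these names are illustrative); the complete kernel is shown in the Example section.

// Minimal call sketch (illustrative; see the Example section for a complete kernel).
uint32_t m = 32, k = 64, n = 32;
AscendC::GemmTiling tilling = GetGemmTiling<half>(m, k, n);  // compute the tiling rule
tilling.loopMode = AscendC::LoopMode::MODE_NM;               // optionally adjust the traversal mode
// partialsum = false: move the result out; initValue = 0: dstCO1 is not pre-initialized.
AscendC::Gemm(dstCO1, featureMapA1, weightB1, m, k, n, tilling, false, 0);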
    

Parameters

Table 1 Parameters

dstLocal (Output): Destination operand.
The Atlas Training Series Product supports QuePosition values CO1 and CO2.

src0Local (Input): Source operand. QuePosition is A1.

src1Local (Input): Source operand. QuePosition is B1.

m (Input): Valid height of the left matrix src0Local. The value range is [1, 4096].
Note: m does not need to be rounded up to a multiple of 16.

k (Input): Valid width of the left matrix src0Local and valid height of the right matrix src1Local.
  • If src0Local is of type float, the value range is [1, 8192].
  • If src0Local is of type half, the value range is [1, 16384].
  • If src0Local is of type int8_t, the value range is [1, 32768].
Note: k does not need to be rounded up to a multiple of 16.

n (Input): Valid width of the right matrix src1Local. The value range is [1, 4096].
Note: n does not need to be rounded up to a multiple of 16.

tilling (Input): Tiling rule of type GemmTiling. The structure is defined as follows:

struct GemmTiling {
    const uint32_t blockSize = 16;
    LoopMode loopMode = LoopMode::MODE_NM;
    uint32_t mNum = 0;
    uint32_t nNum = 0;
    uint32_t kNum = 0;
    uint32_t roundM = 0;
    uint32_t roundN = 0;
    uint32_t roundK = 0;
    uint32_t c0Size = 32;
    uint32_t dtypeSize = 1;
    uint32_t mBlockNum = 0;
    uint32_t nBlockNum = 0;
    uint32_t kBlockNum = 0;
    uint32_t mIterNum = 0;
    uint32_t nIterNum = 0;
    uint32_t kIterNum = 0;
    uint32_t mTileBlock = 0;
    uint32_t nTileBlock = 0;
    uint32_t kTileBlock = 0;
    uint32_t kTailBlock = 0;
    uint32_t mTailBlock = 0;
    uint32_t nTailBlock = 0;
    bool kHasTail = false;
    bool mHasTail = false;
    bool nHasTail = false;
    bool kHasTailEle = false;
    uint32_t kTailEle = 0;
    uint32_t kThreadNum = 0;
};

For details about the parameter description, see Table 3.

partialsum (Input): When the QuePosition of dstLocal is CO2, this parameter controls whether the computation result is moved out.
  • false: The computation result is moved out.
  • true: The computation result is not moved out and is kept for subsequent computation.

initValue (Input): Indicates whether dstLocal needs to be initialized.
  • 0: dstLocal is not initialized. Its initial contents (for example, a previous GEMM result) are accumulated with the new compute result.
  • 1: dstLocal is initialized. Its initial contents are overwritten by the compute result.

Table 2 Data type combinations of src0Local, src1Local, and dstLocal

src0Local.dtype    src1Local.dtype    dstLocal.dtype
int8_t             int8_t             int32_t
half               half               float
half               half               half

Table 3 Parameters in the GemmTiling structure

blockSize (uint32_t): Number of elements stored in a dimension. The value is fixed at 16.

loopMode (LoopMode): Traversal mode. The enumeration is defined as follows:

enum class LoopMode {
    MODE_NM = 0,
    MODE_MN = 1,
    MODE_KM = 2,
    MODE_KN = 3
};

mNum (uint32_t): Equivalent data length of the M axis. The value range is [1, 4096].

nNum (uint32_t): Equivalent data length of the N axis. The value range is [1, 4096].

kNum (uint32_t): Equivalent data length of the K axis.
  • If src0Local is of type float, the value range is [1, 8192].
  • If src0Local is of type half, the value range is [1, 16384].
  • If src0Local is of type int8_t, the value range is [1, 32768].

roundM (uint32_t): Equivalent data length of the M axis, rounded up to an integer multiple of blockSize. The value range is [1, 4096].

roundN (uint32_t): Equivalent data length of the N axis, rounded up to an integer multiple of blockSize. The value range is [1, 4096].

roundK (uint32_t): Equivalent data length of the K axis, rounded up to an integer multiple of c0Size.
  • If src0Local is of type float, the value range is [1, 8192].
  • If src0Local is of type half, the value range is [1, 16384].
  • If src0Local is of type int8_t, the value range is [1, 32768].

c0Size (uint32_t): Number of elements in one block. The value can be 16 or 32.

dtypeSize (uint32_t): Size of one input element, in bytes. The value range is [1, 2].

mBlockNum (uint32_t): Number of blocks on the M axis. mBlockNum = mNum / blockSize.

nBlockNum (uint32_t): Number of blocks on the N axis. nBlockNum = nNum / blockSize.

kBlockNum (uint32_t): Number of blocks on the K axis. kBlockNum = kNum / blockSize.

mIterNum (uint32_t): Number of loop iterations along the M axis. The value range is [1, 4096].

nIterNum (uint32_t): Number of loop iterations along the N axis. The value range is [1, 4096].

kIterNum (uint32_t): Number of loop iterations along the K axis. The value range is [1, 4096].

mTileBlock (uint32_t): Number of split blocks on the M axis. The value range is [1, 4096].

nTileBlock (uint32_t): Number of split blocks on the N axis. The value range is [1, 4096].

kTileBlock (uint32_t): Number of split blocks on the K axis. The value range is [1, 4096].

kTailBlock (uint32_t): Number of tail blocks on the K axis. The value range is [1, 4096].

mTailBlock (uint32_t): Number of tail blocks on the M axis. The value range is [1, 4096].

nTailBlock (uint32_t): Number of tail blocks on the N axis. The value range is [1, 4096].

kHasTail (bool): Indicates whether a tail block exists on the K axis.

mHasTail (bool): Indicates whether a tail block exists on the M axis.

nHasTail (bool): Indicates whether a tail block exists on the N axis.

kHasTailEle (bool): Indicates whether tail elements exist on the K axis.

kTailEle (uint32_t): Number of tail elements on the K axis. The value range is [1, 4096].
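
As a purely illustrative reading of the formulas above (this is not the implementation of GetGemmTiling), the block counts for the half-precision GEMM used in the Example section (m = 32, k = 64, n = 32) work out as follows:

// Illustrative only: evaluates the block-count formulas from Table 3 for half inputs
// with mNum = 32, kNum = 64, nNum = 32 (already multiples of blockSize).
constexpr uint32_t blockSize = 16;
uint32_t mNum = 32, kNum = 64, nNum = 32;
uint32_t mBlockNum = mNum / blockSize;   // 2 blocks on the M axis
uint32_t nBlockNum = nNum / blockSize;   // 2 blocks on the N axis
uint32_t kBlockNum = kNum / blockSize;   // 4 blocks on the K axis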

Availability

Atlas Training Series Product

Precautions

  • The m, k, and n arguments do not need to be multiples of 16 elements. However, due to hardware restrictions, the shapes of the dstLocal, src0Local, and src1Local operands must meet the following alignment requirements: the m and n dimensions must be rounded up to multiples of 16 elements, and the k dimension must be rounded up to a multiple of 16 or 32 elements, depending on the operand data type (see the sketch after this list).
  • For details about the alignment requirements of the operand address offset, see General Restrictions.
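
For instance, the Example section rounds the logical dimensions with AscendC::DivCeil before sizing its local buffers. A minimal sketch of that alignment step, assuming half inputs (so k aligns to 16), is shown below; the pipe and queue names follow the Example and are otherwise illustrative.

// Alignment sketch: round m/n up to 16 and k up to the data-type-dependent value (16 for half).
uint32_t roundm = AscendC::DivCeil(m, 16) * 16;
uint32_t roundn = AscendC::DivCeil(n, 16) * 16;
uint32_t roundk = AscendC::DivCeil(k, 16) * 16;
// The local buffers then cover the aligned shapes [roundm, roundk], [roundk, roundn], and [roundm, roundn].
pipe.InitBuffer(inQueueFmA1, 1, roundm * roundk * sizeof(half));
pipe.InitBuffer(inQueueWeB1, 1, roundk * roundn * sizeof(half));
pipe.InitBuffer(outQueueCO1, 1, roundm * roundn * sizeof(float));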

Example

In this example, the shape of the left matrix is [m,k], and the shape of the right matrix is [k,n]. The computation result is moved to GM, and the destination matrix does not need to be initialized.
#include "kernel_operator.h"

class KernelCubeGEMM {
public:
    __aicore__ inline KernelCubeGEMM() {}
    __aicore__ inline void Init(__gm__ uint8_t* fmGm, __gm__ uint8_t* weGm, __gm__ uint8_t* dstGm, uint32_t mInput,
        uint32_t kInput, uint32_t nInput, bool initVal, AscendC::LoopMode mode)
    {
        m = mInput;
        k = kInput;
        n = nInput;

        initValue = initVal;
        loopMode = mode;

        featureMapA1Size = m * k;
        weightA1Size = k * n;
        dstCO1Size = m * n;

        roundm = AscendC::DivCeil(m, 16) * 16;
        roundn = AscendC::DivCeil(n, 16) * 16;
        roundk = AscendC::DivCeil(k, c0Size) * c0Size;

        fmGlobal.SetGlobalBuffer((__gm__ half*)fmGm);
        weGlobal.SetGlobalBuffer((__gm__ half*)weGm);
        dstGlobal.SetGlobalBuffer((__gm__ float*)dstGm);

        pipe.InitBuffer(inQueueFmA1, 1, featureMapA1Size * sizeof(half));
        pipe.InitBuffer(inQueueWeB1, 1, weightA1Size * sizeof(half));
        pipe.InitBuffer(outQueueCO1, 1, dstCO1Size * sizeof(float));
        pipe.InitBuffer(outQueueUB, 1, dstCO1Size * sizeof(float));
    }
    __aicore__ inline void Process()
    {
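        // CopyIn:  GM -> A1/B1 for the feature map and weight.
        // Compute: Gemm on A1/B1, result written to CO1.
        // CopyUB:  CO1 -> UB in matrix block mode.
        // CopyOut: UB -> GM.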
        CopyIn();
        Compute();
        CopyUB();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<half> featureMapA1 = inQueueFmA1.AllocTensor<half>();
        AscendC::LocalTensor<half> weightB1 = inQueueWeB1.AllocTensor<half>();

        AscendC::DataCopy(featureMapA1, fmGlobal, featureMapA1Size);
        AscendC::DataCopy(weightB1, weGlobal, weightA1Size);

        inQueueFmA1.EnQue(featureMapA1);
        inQueueWeB1.EnQue(weightB1);
    }

    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<half> featureMapA1 = inQueueFmA1.DeQue<half>();
        AscendC::LocalTensor<half> weightB1 = inQueueWeB1.DeQue<half>();
        AscendC::LocalTensor<float> dstCO1 = outQueueCO1.AllocTensor<float>();

        AscendC::GemmTiling tilling = GetGemmTiling<half>(m, k, n);
        tilling.loopMode = loopMode;
        // The shape of the left matrix is [m,k], and the shape of the right matrix is [k,n]. The computation result is moved to GM, and the destination matrix does not need to be initialized.
        AscendC::Gemm(dstCO1, featureMapA1, weightB1, m, k, n, tilling, false, initValue);

        outQueueCO1.EnQue<float>(dstCO1);
        inQueueFmA1.FreeTensor(featureMapA1);
        inQueueWeB1.FreeTensor(weightB1);
    }

    __aicore__ inline void CopyUB()
    {
        AscendC::LocalTensor<float> dstCO1 = outQueueCO1.DeQue<float>();
        AscendC::LocalTensor<float> dstUB = outQueueUB.AllocTensor<float>();

        AscendC::DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = 1;
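        // In matrix block mode the copy length is counted in 16 x 16 fractals; one float fractal is 16 * 16 * sizeof(float) = 1024 bytes.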
        dataCopyParams.blockLen = roundm * roundn * sizeof(float) / 1024;
        AscendC::DataCopyEnhancedParams enhancedParams;
        enhancedParams.blockMode = AscendC::BlockMode::BLOCK_MODE_MATRIX;

        AscendC::DataCopy(dstUB, dstCO1, dataCopyParams, enhancedParams);

        outQueueUB.EnQue<float>(dstUB);
        outQueueCO1.FreeTensor(dstCO1);
    }

    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<float> dstUB = outQueueUB.DeQue<float>();
        AscendC::DataCopy(dstGlobal, dstUB, roundm * roundn);
        outQueueUB.FreeTensor(dstUB);
    }

private:
    AscendC::TPipe pipe;
    // feature map queue
    AscendC::TQue<AscendC::QuePosition::A1, 1> inQueueFmA1;
    // weight queue
    AscendC::TQue<AscendC::QuePosition::B1, 1> inQueueWeB1;
    // dst queue
    AscendC::TQue<AscendC::QuePosition::CO1, 1> outQueueCO1;

    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueUB;

    AscendC::GlobalTensor<half> fmGlobal, weGlobal;
    AscendC::GlobalTensor<float> dstGlobal;

    uint16_t m;
    uint16_t k;
    uint16_t n;
    uint32_t roundm, roundk, roundn;

    uint32_t c0Size = 16;
    bool initValue = false;
    AscendC::LoopMode loopMode = AscendC::LoopMode::MODE_NM;

    uint32_t featureMapA1Size, weightA1Size, dstCO1Size;
};

extern "C" __global__ __aicore__ void cube_gemm_simple_kernel(__gm__ uint8_t* fmGm, __gm__ uint8_t* weGm,
    __gm__ uint8_t* dstGm, uint32_t m, uint32_t k, uint32_t n, bool initValue, AscendC::LoopMode mode)
{
    KernelCubeGEMM op;
    // In the preceding example, the input parameters are: m = 32, k = 64, n = 32, initValue = false, mode = LoopMode::MODE_NM.
    op.Init(fmGm, weGm, dstGm, m, k, n, initValue, mode);
    op.Process();
}