SoftmaxFlashV3

Function Usage

Serves as an enhanced version of SoftmaxFlash and implements the Softmax PASA algorithm. For an input tensor of shape [m0, m1, ..., mt, n] (t ≥ 0), let m be the product of the non-last axis lengths m0, m1, ..., mt; the input is then treated as a tensor of shape [m, n]. This API performs the following computation on each row of the [m, n] tensor. The formulas depend on the value of update; x, inmax, insum, and inmean are inputs, and M, S, and E are outputs.

  • If update is false, the formulas are as follows.

  • If update is true, the formulas are as follows.
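The per-row computation can be read off the Python pseudocode below. As a sketch (symbols chosen here for illustration, which may differ from the original formula figures), let n be the row length, s = α/(1 − α), x̄ the row mean, x̄_j the mean of the j-th of the splitMeanCnt segments of baseK = n/splitMeanCnt elements, and j(k) the segment containing element k. M, S, and E stand for x_max, x_sum, and x_mean in the pseudocode; y and expmax stand for dst and exp_max.

$$
\begin{aligned}
&\text{Both cases:} && \tilde{x}_k = x_k - s\,(\bar{x} - \bar{x}_{j(k)}) \\
&\text{update = false:} && E = \bar{x},\quad M = \max_k \tilde{x}_k,\quad y_k = e^{\tilde{x}_k - M},\quad S = \sum_k y_k \\
&\text{update = true:} && E = \frac{\bar{x} + (\mathrm{loopCnt} - 1)\,\mathrm{inmean}}{\mathrm{loopCnt}},\quad c = s\,(\bar{x} - E),\quad p = s\,(\mathrm{inmean} - E) \\
& && M = \max\!\left(c + \max_k \tilde{x}_k,\; p + \mathrm{inmax}\right),\quad y_k = e^{\tilde{x}_k - (M - c)} \\
& && \mathrm{expmax} = e^{\mathrm{inmax} - M + p},\quad S = \mathrm{expmax}\cdot\mathrm{insum} + \sum_k y_k
\end{aligned}
$$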

Currently, this API supports input in ND format only. Reductions are performed along the last axis.
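The [m, n] view described above is just the product of the leading axes. A minimal sketch of that shape arithmetic follows; the helper name flatten_to_2d is hypothetical and not part of the API.

```python
from math import prod


def flatten_to_2d(shape):
    """Collapse [m0, m1, ..., mt, n] (t >= 0) into the [m, n] view the API works on."""
    if len(shape) < 2:
        raise ValueError("need at least one non-last axis")
    *outer, n = shape
    # m is the product of all non-last axis lengths; reductions then run along n.
    return prod(outer), n


print(flatten_to_2d([2, 4, 1024]))  # (8, 1024)
```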

For ease of understanding, the formulas are also expressed in the following Python pseudocode, where repeatSize is 64, elementNumPerBlk and BlkcntPerRepeat are both 8, and splitMeanCnt is 8. src, inmean, inmax, insum, loopCnt, alpha, and update are inputs; dst, x_mean, x_sum, x_max, and exp_max are outputs.
import numpy as np

def softmax_flash_3(src, height, width, loopCnt, alpha, baseK, inmax=None, insum=None, inmean=None, update=False):
    scalar = alpha / (1 - alpha)
    # BlockReduceSum and Add stand for the device vector operations. Afterwards, tmpbuffer0
    # holds, for each row, the sums of its splitMeanCnt segments of baseK elements.
    #(m,n)->(m,64)
    tmpbuffer0 = BlockReduceSum(repeatSize, repeatSize, elementNumPerBlk)
    remain = int(width / repeatSize - BlkcntPerRepeat)
    tmpbuffer0 = Add(tmpbuffer0, src, remain, repeatSize * elementNumPerBlk, width)
    #(m,64)->(m,8)
    tmpbuffer0 = BlockReduceSum(1, relementNumPerBlk, elementNumPerBlk)
    #width = baseK * splitMeanCnt
    rowMeanLocal = tmpbuffer0 / baseK                               # (m, 8) segment means
    rowMeanGlobal = np.mean(src, axis=-1, keepdims=True)            # (m, 1) row mean
    rowMeanGlobalTmp = (rowMeanGlobal - rowMeanLocal) * scalar      # (m, 8)
    # Each element of src is shifted by the rowMeanGlobalTmp value of its own segment.
    src = src - rowMeanGlobalTmp

    if not update:
        x_mean = rowMeanGlobal
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar  # always 0 in this branch
        x_max = shiftCurr + maxTmp
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, x_mean, exp_max
    else:
        # Running mean over the loopCnt calls made so far
        x_mean = (rowMeanGlobal + inmean * (loopCnt - 1)) / loopCnt
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        shiftPrev = (inmean - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = shiftPrev + inmax
        x_max = np.max(np.concatenate((x_max, maxTmp), axis=-1), axis=-1, keepdims=True)
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        # Rescaling factor applied to the previously accumulated sum
        exp_max = np.exp(inmax - x_max + shiftPrev)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        x_sum = exp_max * insum + x_sum
        return dst, x_max, x_sum, x_mean, exp_max
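To illustrate why the update = true branch lets a long row be processed in pieces, the pure-Python sketch below applies the same per-row math to one chunk at a time and compares the result with a one-shot computation. The function names are hypothetical, each chunk is assumed to be split into contiguous segments of baseK elements, and all chunks are assumed to have equal length so that the running mean equals the overall mean. Under those assumptions, after the last call x_max and x_sum equal the max and the sum of exp(z − max) over all elements, where z_k = x_k + s·(segment mean − overall mean).

```python
import math


def pasa_step(x, alpha, loop_cnt=1, state=None, split_mean_cnt=8):
    """One SoftmaxFlashV3-style step on a single row (illustrative sketch).

    state is None for update == false, otherwise (in_max, in_sum, in_mean)
    from the previous step. Returns (dst, x_max, x_sum, x_mean, exp_max).
    """
    n = len(x)
    base_k = n // split_mean_cnt
    if base_k * split_mean_cnt != n:
        raise ValueError("row length must be a multiple of split_mean_cnt")
    scalar = alpha / (1 - alpha)
    mean_global = sum(x) / n
    # Assumption: segment j is the contiguous slice [j * base_k, (j + 1) * base_k).
    shifted = []
    for j in range(split_mean_cnt):
        seg = x[j * base_k:(j + 1) * base_k]
        mean_local = sum(seg) / base_k
        shifted.extend(v - (mean_global - mean_local) * scalar for v in seg)
    cur_max = max(shifted)
    if state is None:  # update == false
        dst = [math.exp(v - cur_max) for v in shifted]
        return dst, cur_max, sum(dst), mean_global, None
    in_max, in_sum, in_mean = state
    x_mean = (mean_global + in_mean * (loop_cnt - 1)) / loop_cnt
    shift_curr = (mean_global - x_mean) * scalar
    shift_prev = (in_mean - x_mean) * scalar
    x_max = max(shift_curr + cur_max, shift_prev + in_max)
    dst = [math.exp(v - (x_max - shift_curr)) for v in shifted]
    exp_max = math.exp(in_max - x_max + shift_prev)
    return dst, x_max, exp_max * in_sum + sum(dst), x_mean, exp_max


def run_chunks(chunks, alpha):
    """Feed equal-length chunks of one row through pasa_step; return (x_max, x_sum, x_mean)."""
    state = None
    for loop_cnt, chunk in enumerate(chunks, start=1):
        _, x_max, x_sum, x_mean, _ = pasa_step(chunk, alpha, loop_cnt, state)
        state = (x_max, x_sum, x_mean)
    return state


def one_shot(chunks, alpha, split_mean_cnt=8):
    """Max and sum of exp over z_k = x_k + s * (segment mean - overall mean)."""
    scalar = alpha / (1 - alpha)
    overall = sum(sum(c) / len(c) for c in chunks) / len(chunks)
    z = []
    for c in chunks:
        base_k = len(c) // split_mean_cnt
        for j in range(split_mean_cnt):
            seg = c[j * base_k:(j + 1) * base_k]
            z.extend(v + scalar * (sum(seg) / base_k - overall) for v in seg)
    z_max = max(z)
    return z_max, sum(math.exp(v - z_max) for v in z)


row = [math.sin(0.37 * i) for i in range(48)]
chunks = [row[0:16], row[16:32], row[32:48]]
print(run_chunks(chunks, 0.9375)[:2])
print(one_shot(chunks, 0.9375))
```

The two printed pairs agree up to rounding. With alpha = 0 the shifts vanish and the sketch reduces to an ordinary online softmax.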

Prototype

  • Allocate the temporary space through the API framework.
    template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
    __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor, const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxParams& params)
    
  • Pass the temporary space through the sharedTmpBuffer input parameter.
    template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
    __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor, const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxParams& params)
    

Due to the complex computation involved in the internal implementation of this API, additional temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated through the API framework or passed by developers through the sharedTmpBuffer input parameter.

  • When the API framework is used for temporary space allocation, developers do not need to allocate the space, but must reserve the required size for the space.
  • When the sharedTmpBuffer input parameter is used, the tensor passed in serves as the temporary space and the API framework does not allocate any. Developers manage the sharedTmpBuffer space themselves and can reuse the buffer after the API call, which avoids repeated allocation and deallocation and improves flexibility and buffer utilization.

If the API framework is used, developers must reserve the temporary space; if sharedTmpBuffer is used, developers must allocate space for the tensor. In either case, obtain the temporary space size (BufferSize) with the GetSoftMaxFlashV3MaxMinTmpSize API described in SoftmaxFlashV3 Tiling, which returns the minimum and maximum required sizes. The minimum size guarantees correct functionality, while the maximum size improves performance.
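As a sketch of how the two sizes might be used, the hypothetical helper below prefers the maximum size when it fits in the remaining local memory and falls back to the minimum otherwise. It assumes both sizes are in bytes and that the caller knows how much local (Unified Buffer) space is still free; neither the helper nor its parameters are part of the Ascend C API.

```python
def choose_tmp_buffer_size(min_size, max_size, free_bytes):
    """Pick a temporary buffer size in bytes (hypothetical helper).

    min_size / max_size: values reported by GetSoftMaxFlashV3MaxMinTmpSize.
    free_bytes: local memory still available to the kernel.
    """
    if max_size <= free_bytes:
        return max_size  # best performance
    if min_size <= free_bytes:
        return min_size  # still functionally correct
    raise ValueError("not enough space for the minimum temporary buffer")


print(choose_tmp_buffer_size(1024, 4096, 8192))  # 4096
print(choose_tmp_buffer_size(1024, 4096, 2048))  # 1024
```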

Parameters

Table 1 Parameters in the template

Parameter

Description

T

Data type of the input srcTensor and the outputs dstTensor and expMaxTensor.

U

Data type of the inputs inMeanTensor, inExpSumTensor, and inMaxTensor and the outputs meanTensor, expSumTensor, and maxTensor.

isUpdate

Whether update is true in the formulas. When it is false, inMeanTensor, inExpSumTensor, and inMaxTensor are not used in the computation.

isReuseSource

Reserved for future use. The default value false must be used.

isBasicBlock

Reserved for future use. The default value false must be used.

isDataFormatNZ

Reserved for future use. The default value false must be used.

config

Reserved for future use. The default value SOFTMAX_DEFAULT_CFG must be used.

Table 2 API parameters

Parameter

Input/Output

Description

dstTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of dstTensor is the same as that of the source operand srcTensor.

meanTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the mean value result during softmax computation.

  • The last axis of meanTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical values, each equal to the mean obtained through reducesum.
  • The length of each non-last axis is the same as that of dstTensor.
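The 32-byte block layout shared by meanTensor and the other per-row statistics can be sketched as follows: each row contributes one value, replicated across 32 / sizeof(dtype) elements. The function names are hypothetical and only model the layout, not the API.

```python
BLOCK_BYTES = 32  # length of one data block


def elems_per_block(dtype_size):
    """Number of elements in one 32-byte block (8 for float, 16 for half)."""
    return BLOCK_BYTES // dtype_size


def broadcast_rows_to_blocks(row_values, dtype_size):
    """Lay out one value per row as an [m, 32 / dtype_size] table with the value repeated."""
    k = elems_per_block(dtype_size)
    return [[v] * k for v in row_values]


print(elems_per_block(4), elems_per_block(2))  # 8 16
print(len(broadcast_rows_to_blocks([0.5] * 8, 4)[0]))  # 8
```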

expSumTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the reducesum result during softmax computation.

  • The last axis of expSumTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical copies of the reducesum result.
  • The length of each non-last axis is the same as that of dstTensor.

maxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the reducemax result during softmax computation.

  • The last axis of maxTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical copies of the reducemax result.
  • The length of each non-last axis is the same as that of dstTensor.

srcTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The length of the last axis must be 32-byte aligned.

expMaxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the exp_max result of the formulas, which is computed only when isUpdate is true.

  • The last axis of expMaxTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the half data type, the block holds 16 identical values.
  • The length of each non-last axis is the same as that of dstTensor.

inMeanTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Mean value required for softmax computation.

  • The last axis of inMeanTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical values.
  • The length of each non-last axis is the same as that of dstTensor.

inExpSumTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Sum value required for softmax computation.

  • The last axis of inExpSumTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical values.
  • The length of each non-last axis is the same as that of dstTensor.

inMaxTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Max value required for softmax computation.

  • The last axis of inMaxTensor is fixed at 32 bytes, that is, the length of one data block, and every element in the block holds the same value. For example, with the float data type, the block holds eight identical values.
  • The length of each non-last axis is the same as that of dstTensor.

sharedTmpBuffer

Input

Temporary space.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The data type of this operand is fixed at uint8_t.

This parameter is used to store intermediate variables during complex internal API computation and is provided by developers.

For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlashV3 Tiling.

tiling

Input

Tiling information required for SoftmaxFlashV3 computation. For details about how to obtain the tiling information, see SoftmaxFlashV3 Tiling.

params

Input

Shape information and computation parameters of srcTensor. It is of the SoftMaxParams type. The specific definition is as follows:

struct SoftMaxParams {
    uint32_t srcM;          // Product of the lengths of the non-last axes.
    uint32_t srcK;          // Length of the last axis, which must be 32-byte aligned.
    uint32_t oriSrcM;       // Product of the lengths of the original non-last axes.
    uint32_t oriSrcK;       // Length of the original last axis.
    uint32_t loopCnt;       // loopCnt in the formula when update is true. Must be greater than or equal to 1.
    uint32_t splitMeanCnt;  // Number of blocks each row is split into to calculate the mean. Currently, only 8 is supported.
    float alpha;            // alpha in the formula.
};

Note: This API does not support unaligned shapes. Therefore, srcM must equal oriSrcM, and srcK must equal oriSrcK.
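The constraints on SoftMaxParams, together with the quantities the pseudocode derives from it (baseK = srcK / splitMeanCnt and scalar = alpha / (1 − alpha)), can be sketched with a Python mirror of the struct. This class is for illustration only; it is not a Python binding of the C++ type.

```python
from dataclasses import dataclass


@dataclass
class SoftMaxParams:
    """Python mirror of the C++ SoftMaxParams struct (illustration only)."""
    srcM: int
    srcK: int
    oriSrcM: int
    oriSrcK: int
    loopCnt: int
    splitMeanCnt: int
    alpha: float

    def validate(self):
        # Unaligned shapes are not supported, so padded and original sizes must match.
        if self.srcM != self.oriSrcM or self.srcK != self.oriSrcK:
            raise ValueError("srcM/srcK must equal oriSrcM/oriSrcK")
        if self.loopCnt < 1:
            raise ValueError("loopCnt must be >= 1")
        if self.splitMeanCnt != 8:
            raise ValueError("splitMeanCnt can currently only be 8")
        if self.alpha == 1:
            raise ValueError("the pseudocode divides by 1 - alpha")

    @property
    def baseK(self):
        # Pseudocode comment: width = baseK * splitMeanCnt.
        return self.srcK // self.splitMeanCnt

    @property
    def shift_scale(self):
        # scalar = alpha / (1 - alpha) in the pseudocode.
        return self.alpha / (1 - self.alpha)


params = SoftMaxParams(8, 1024, 8, 1024, 2, 8, 0.9375)  # values from the example below
params.validate()
print(params.baseK, params.shift_scale)  # 128 15.0
```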

Returns

None

Availability

Precautions

  • For details about the alignment requirements of the operand address offset, see General Restrictions.
  • For the input srcTensor, the last axis length n must be at least 512 and a multiple of 64, and the product m of the non-last axis lengths must be a multiple of 8.
  • The following tensor pairs can share the same space: srcTensor and dstTensor, meanTensor and inMeanTensor, maxTensor and inMaxTensor, and expSumTensor and inExpSumTensor.
  • For meanTensor, expSumTensor, maxTensor, expMaxTensor, inMeanTensor, inExpSumTensor, and inMaxTensor, the last axis must be 32 bytes long.
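The shape rules for srcTensor can be checked up front. The sketch below is a hypothetical host-side helper that returns the list of violated rules; dtype_size is the element size in bytes (2 for half, 4 for float).

```python
def check_src_shape(m, n, dtype_size):
    """Return the SoftmaxFlashV3 srcTensor shape rules that [m, n] violates (empty if OK)."""
    problems = []
    if (n * dtype_size) % 32 != 0:
        problems.append("last axis must be 32-byte aligned")
    if n < 512 or n % 64 != 0:
        problems.append("n must be >= 512 and a multiple of 64")
    if m % 8 != 0:
        problems.append("m must be a multiple of 8")
    return problems


print(check_src_shape(8, 1024, 2))  # []
print(check_src_shape(6, 1024, 4))  # ['m must be a multiple of 8']
```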

Example

In this example, srcTensor and dstTensor have shape [8, 1024] and data type half. inMeanTensor, inExpSumTensor, and inMaxTensor have shape [8, 8] and data type float. The output expMaxTensor has shape [8, 16] and data type half. The input and output data are in ND format. srcTensor and dstTensor do not share space, while the mean, sum, and max outputs reuse the space of inMeanTensor, inExpSumTensor, and inMaxTensor. isUpdate is true.
#include "kernel_operator.h"

template <typename T, typename U>
class KernelSoftmaxFlashV3 {
public:
    __aicore__ inline KernelSoftmaxFlashV3() {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
       __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ U *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ U *)inSumGm);
        meanGlobal.SetGlobalBuffer((__gm__ U *)inMeanGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        elementNumPerBlk1 = 32 / sizeof(U);
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(meanQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        elementNumPerBlk2 = 32 / sizeof(T);
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk2 * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.AllocTensor<U>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmeanLocal, meanGlobal, height * elementNumPerBlk1);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
        meanQueue.EnQue(inmeanLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.DeQue<U>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
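        // Field order: {srcM, srcK, oriSrcM, oriSrcK, loopCnt, splitMeanCnt, alpha}
        AscendC::SoftMaxParams params = {height, width, height, width, loopCnt, splitMeanCnt, alpha};
        // isUpdate = true. The mean, sum, and max outputs are written in place over inmeanLocal, insumLocal, and inmaxLocal.
        AscendC::SoftmaxFlashV3<T, U, true>(dstLocal, inmeanLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, inmeanLocal, insumLocal, inmaxLocal, tiling, params);

        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        meanQueue.FreeTensor(inmeanLocal);
        inQueueSrc.FreeTensor(srcLocal);
        // expMaxTensor is not copied out in this example, so release it here.
        expMaxQueue.FreeTensor(expMaxTensor);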
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> meanQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<U> meanGlobal, maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk1 = 0;
    uint32_t elementNumPerBlk2 = 0;
    uint32_t width = 1024;
    uint32_t height = 8;
    uint32_t loopCnt = 2;
    uint32_t splitMeanCnt = 8;
    float alpha = 0.9375;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flashv3_kernel(__gm__ uint8_t *srcGm,
    __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV3<half, float> op;
    op.Init(srcGm, inMaxGm, inSumGm, inMeanGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}
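The buffer sizes that Init requests for this example follow directly from the shapes and data types (T = half, 2 bytes; U = float, 4 bytes). The sketch below reproduces that arithmetic; it covers only the pipe.InitBuffer calls and does not include the temporary space reserved for the API itself. The function name is hypothetical.

```python
def example_buffer_bytes(height=8, width=1024, t_size=2, u_size=4):
    """Bytes passed to pipe.InitBuffer in the example's Init (T = half, U = float)."""
    elem_per_blk1 = 32 // u_size  # elementNumPerBlk1 in the example
    elem_per_blk2 = 32 // t_size  # elementNumPerBlk2 in the example
    stat = height * elem_per_blk1 * u_size  # maxQueue, sumQueue, meanQueue
    return {
        "inQueueSrc": height * width * t_size,
        "maxQueue": stat,
        "sumQueue": stat,
        "meanQueue": stat,
        "expMaxQueue": height * elem_per_blk2 * t_size,
        "outQueueDst": height * width * t_size,
    }


sizes = example_buffer_bytes()
print(sizes["inQueueSrc"], sizes["maxQueue"], sizes["expMaxQueue"])  # 16384 256 256
print(sum(sizes.values()))  # 33792
```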