SoftmaxFlashV3

Applicability

Product                                                    Supported
Atlas A3 training products/Atlas A3 inference products     √
Atlas A2 training products/Atlas A2 inference products     √
Atlas 200I/500 A2 inference products                       x
Atlas inference product's AI Core                          x
Atlas inference product's Vector Core                      x
Atlas training products                                    x

Function

SoftmaxFlashV3 is an enhanced version of SoftmaxFlash and corresponds to the Softmax PASA algorithm. For an input tensor of shape [m0, m1, ..., mt, n] (t ≥ 0), let m be the product of the non-last axis lengths m0, m1, ..., mt, so that the input is treated as a tensor of shape [m, n]. The last axis of the input tensor x is split into splitMeanCnt blocks, and the i-th block is denoted x_cnti. In the formulas below, x, inmax, insum, and inmean are inputs, and M, S, E, and A are outputs.

  • If update is false, the formulas are as follows.

  • If update is true, the formulas are as follows.
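The two cases can be sketched in LaTeX by transcribing the Python pseudocode given in this section. This is a reading of the pseudocode, not the official formula: the mapping of M, S, E, and A to x_max, x_sum, exp_max, and x_mean is an assumption, as are the symbols x' (the input after the block-mean shift), \bar{x} (the row mean of x), c = alpha / (1 - alpha), \delta_cur, \delta_prev, and y (the output dst).

```latex
% Assumed notation: c = \alpha / (1 - \alpha); x' is src after the block-mean shift;
% \bar{x} is the row mean of x; y is dst. All reductions run over the last axis (index j).

% update = false
\begin{aligned}
A   &= \bar{x}, \\
M   &= \max_j x'_j, \\
y_j &= \exp\left(x'_j - M\right), \\
S   &= \sum_j y_j .
\end{aligned}

% update = true
\begin{aligned}
A &= \frac{\bar{x} + (\mathrm{loopCnt} - 1)\,\mathrm{inmean}}{\mathrm{loopCnt}}, \\
\delta_{\mathrm{cur}} &= c\,(\bar{x} - A), \qquad
\delta_{\mathrm{prev}} = c\,(\mathrm{inmean} - A), \\
M   &= \max\left(\max_j x'_j + \delta_{\mathrm{cur}},\ \mathrm{inmax} + \delta_{\mathrm{prev}}\right), \\
y_j &= \exp\left(x'_j - M + \delta_{\mathrm{cur}}\right), \\
E   &= \exp\left(\mathrm{inmax} - M + \delta_{\mathrm{prev}}\right), \\
S   &= E \cdot \mathrm{insum} + \sum_j y_j .
\end{aligned}
```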

Currently, this API supports input in ND format only. The internal reduction is performed along the last axis.

For ease of understanding, the formulas are expressed in the following Python pseudocode, in which repeatSize is 64, elementNumPerBlk and BlkcntPerRepeat are 8, and splitMeanCnt is 8. src, inmean, inmax, insum, and update are inputs; dst, x_mean, x_sum, x_max, and exp_max are outputs.
import numpy as np

# Pseudocode: BlockReduceSum and Add stand for vector operations and are not defined here.
def softmax_flash_3(src, height, width, loopCnt, alpha, baseK, inmax=None, insum=None, inmean=None, update=False):
    scalar = alpha / (1 - alpha)
    # (m, n) -> (m, 64)
    tmpbuffer0 = BlockReduceSum(repeatSize, repeatSize, elementNumPerBlk)
    remain = int(width / repeatSize - BlkcntPerRepeat)
    tmpbuffer0 = Add(tmpbuffer0, src, remain, repeatSize * elementNumPerBlk, width)
    # (m, 64) -> (m, 8)
    tmpbuffer0 = BlockReduceSum(1, elementNumPerBlk, elementNumPerBlk)
    # width = baseK * splitMeanCnt
    rowMeanLocal = tmpbuffer0 / baseK
    rowMeanGlobal = np.mean(src, axis=-1, keepdims=True)
    rowMeanGlobalTmp = (rowMeanGlobal - rowMeanLocal) * scalar
    # Shift the input by the scaled difference between the global and local means.
    src = src - rowMeanGlobalTmp

    if not update:
        x_mean = rowMeanGlobal
        maxTmp = np.max(src, axis=-1, keepdims=True)
        # Zero here, because x_mean equals rowMeanGlobal.
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, x_mean, exp_max
    else:
        # Running mean over loopCnt iterations.
        x_mean = (rowMeanGlobal + inmean * (loopCnt - 1)) / loopCnt
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        shiftPrev = (inmean - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = shiftPrev + inmax
        # New maximum: the larger of the current and the previous shifted maximum.
        x_max = np.max(np.concatenate((x_max, maxTmp), axis=-1), axis=-1, keepdims=True)
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        # Rescaling factor applied to the previous sum.
        exp_max = np.exp(inmax - x_max + shiftPrev)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        x_sum = exp_max * insum + x_sum
        return dst, x_max, x_sum, x_mean, exp_max
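
To make the stage after the mean shift concrete, here is a minimal pure-Python sketch for a single row. It mirrors the two branches of the pseudocode above but is not the library implementation: `flash_v3_row` is a hypothetical name, the row is assumed to be already shifted (the value of `src` after the `src - rowMeanGlobalTmp` step), and `row_mean` stands in for rowMeanGlobal.

```python
import math


def flash_v3_row(x_shifted, row_mean, alpha, loop_cnt=1,
                 in_mean=None, in_max=None, in_sum=None):
    """Post-shift stage for one row (hypothetical helper, illustration only).

    x_shifted: the row after the mean shift (assumption: shift already applied).
    row_mean:  mean of the original row (rowMeanGlobal in the pseudocode).
    Passing in_mean/in_max/in_sum selects the update branch.
    """
    scalar = alpha / (1.0 - alpha)
    update = in_mean is not None
    if not update:
        x_mean = row_mean
        shift_curr = 0.0                      # (row_mean - x_mean) * scalar
        x_max = max(x_shifted)
        exp_max = None
    else:
        x_mean = (row_mean + in_mean * (loop_cnt - 1)) / loop_cnt
        shift_curr = (row_mean - x_mean) * scalar
        shift_prev = (in_mean - x_mean) * scalar
        x_max = max(max(x_shifted) + shift_curr, in_max + shift_prev)
        exp_max = math.exp(in_max - x_max + shift_prev)
    # Exponentials relative to the (shift-corrected) running maximum.
    dst = [math.exp(v - (x_max - shift_curr)) for v in x_shifted]
    x_sum = sum(dst)
    if update:
        x_sum = exp_max * in_sum + x_sum      # rescale the previous sum
    return dst, x_max, x_sum, x_mean, exp_max


# First call without history, second call updating the running statistics.
dst1, max1, sum1, mean1, _ = flash_v3_row([0.0, 1.0, 2.0], 1.0, 0.9375)
dst2, max2, sum2, mean2, exp_max2 = flash_v3_row(
    [3.0, 1.0], 2.0, 0.9375, loop_cnt=2,
    in_mean=mean1, in_max=max1, in_sum=sum1)
```

Because the new maximum is at least as large as both candidates, every element of dst and the factor exp_max stay in the range (0, 1], which is what keeps the exponentials from overflowing.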

Prototype

  • Allocate the temporary space through the API framework.
    template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
    __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor, const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxParams& params)
    
  • Pass the temporary space through the sharedTmpBuffer input parameter.
    template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
    __aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor, const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxParams& params)
    

Due to the complex computation involved in the internal implementation of this API, extra temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated through the API framework or passed by developers through the sharedTmpBuffer input parameter.

  • When the API framework is used for temporary space allocation, developers do not need to allocate the space, but must reserve the required size for the space.
  • When the sharedTmpBuffer input parameter is used for passing the temporary space, the tensor serves as the temporary space. In this case, the API framework is not required for temporary space allocation. This enables developers to manage the sharedTmpBuffer space and reuse the buffer after calling the API, so that the buffer is not repeatedly allocated and deallocated, improving the flexibility and buffer utilization.

In both cases, obtain the required temporary space size (BufferSize) with the GetSoftMaxFlashV3MaxMinTmpSize API described in SoftmaxFlashV3 Tiling, which returns the minimum and maximum temporary space sizes. The minimum size guarantees correct functionality, while the maximum size is intended to improve performance.
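
As an illustration of how the two sizes might be used, the sketch below picks a buffer size between the minimum and the maximum. It is plain Python rather than Ascend C: `choose_tmp_size` and `available_bytes` are hypothetical names, and the minimum and maximum values are assumed to have been obtained beforehand from GetSoftMaxFlashV3MaxMinTmpSize.

```python
def choose_tmp_size(min_size, max_size, available_bytes):
    """Pick a temporary-buffer size in bytes (hypothetical helper).

    min_size, max_size: values assumed to come from GetSoftMaxFlashV3MaxMinTmpSize.
    available_bytes:    memory the developer can spare for the buffer (assumption).
    """
    if available_bytes < min_size:
        # Below the minimum the API cannot work correctly.
        raise ValueError("not enough memory for the minimum temporary space")
    # Use as much as is available, capped at the reported maximum.
    return min(available_bytes, max_size)
```

For example, with a minimum of 1024 bytes and a maximum of 4096 bytes, 2048 spare bytes yield a 2048-byte buffer, and 8192 spare bytes yield a 4096-byte buffer.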

Parameters

Table 1 Template parameters

Parameter

Description

T

Data type of the input srcTensor and of the outputs dstTensor and expMaxTensor.

For the Atlas A3 training products/Atlas A3 inference products, the supported data type is half.

For the Atlas A2 training products/Atlas A2 inference products, the supported data type is half.

U

Data type of the inputs inMeanTensor, inExpSumTensor, and inMaxTensor and of the outputs meanTensor, expSumTensor, and maxTensor.

For the Atlas A3 training products/Atlas A3 inference products, the supported data type is float.

For the Atlas A2 training products/Atlas A2 inference products, the supported data type is float.

isUpdate

Whether update is true in the computation, that is, which of the two sets of formulas in the Function section is applied.

isReuseSource

This parameter is reserved. Pass the default value false.

isBasicBlock

This parameter is reserved. Pass the default value false.

isDataFormatNZ

This parameter is reserved. Pass the default value false.

config

This parameter is reserved. Use the default value SOFTMAX_DEFAULT_CFG.

Table 2 API parameters

Parameter

Input/Output

Description

dstTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of dstTensor is the same as that of the source operand srcTensor.

meanTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the mean value result during softmax computation.

  • The length of the last axis of meanTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value, which is the mean value obtained after reducesum.
  • The length of each non-last axis is the same as that of dstTensor.

expSumTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the reducesum result during softmax computation.

  • The length of the last axis of expSumTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value obtained after reducesum.
  • The length of each non-last axis is the same as that of dstTensor.

maxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

It is used to store the reducemax result during softmax computation.

  • The length of the last axis of maxTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value obtained after reducemax.
  • The length of each non-last axis is the same as that of dstTensor.

srcTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The length of the last axis must be 32-byte aligned.

expMaxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The length of the last axis of expMaxTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the half data type, all 16 numbers in this data block possess an identical value.
  • The length of each non-last axis is the same as that of dstTensor.

inMeanTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

This parameter indicates the mean value required for softmax computation.

  • The length of the last axis of inMeanTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value.
  • The length of each non-last axis is the same as that of dstTensor.

inExpSumTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

This parameter indicates the sum value required for softmax computation.

  • The length of the last axis of inExpSumTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value.
  • The length of each non-last axis is the same as that of dstTensor.

inMaxTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

This parameter indicates the max value required for softmax computation.

  • The length of the last axis of inMaxTensor is fixed at 32 bytes. It is the length of a data block. All data in this data block has the same value. For example, in the float data type, all eight numbers in this data block possess an identical value.
  • The length of each non-last axis is the same as that of dstTensor.

sharedTmpBuffer

Input

Temporary space.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The data type of this operand is fixed at uint8_t.

It stores intermediate variables produced during the API's internal computation and must be provided by the developer.

For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlashV3 Tiling.

tiling

Input

Tiling information required for SoftmaxFlashV3 computation. For details about how to obtain the tiling information, see SoftmaxFlashV3 Tiling.

params

Input

Shape information and computation parameters of srcTensor. It is of the SoftMaxParams type. The specific definition is as follows:

struct SoftMaxParams {
    uint32_t srcM;         // Product of lengths of non-last axes
    uint32_t srcK;         // Length of the last axis, which must be 32-byte aligned
    uint32_t oriSrcM;      // Product of lengths of original non-last axes
    uint32_t oriSrcK;      // Length of the original last axis
    uint32_t loopCnt;      // loopCnt value in the formula when update is true. This parameter is greater than or equal to 1.
    uint32_t splitMeanCnt; // Number of blocks for calculating the mean value of each row in the formula. Currently, the value can only be 8.
    float alpha;           // Computation parameter in the formula. The recommended values are 0.9375, 0.96889, and 0.984497.
};

Note: This API does not support non-aligned shapes. Therefore, srcM must equal oriSrcM, and srcK must equal oriSrcK.
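
The 32-byte last-axis rule that Table 2 states for the mean, sum, max, and expMax tensors can be turned into element counts with a small helper. The sketch below is plain Python for illustration; `side_tensor_shape`, `tensor_bytes`, and the `DTYPE_BYTES` table are hypothetical names, not part of the API.

```python
BLOCK_BYTES = 32                        # length of one data block, per Table 2
DTYPE_BYTES = {"half": 2, "float": 4}   # element sizes of the supported types


def side_tensor_shape(m, dtype):
    """Shape [m, k] of a per-row tensor whose last axis is one 32-byte block."""
    return [m, BLOCK_BYTES // DTYPE_BYTES[dtype]]


def tensor_bytes(shape, dtype):
    """Total size in bytes of a tensor with the given shape and data type."""
    count = 1
    for dim in shape:
        count *= dim
    return count * DTYPE_BYTES[dtype]


# For m = 8 rows: float statistics tensors are [8, 8], a half expMaxTensor is [8, 16].
stats_shape = side_tensor_shape(8, "float")
exp_max_shape = side_tensor_shape(8, "half")
```

Both shapes occupy the same 256 bytes, since each row always holds exactly one 32-byte block regardless of the element type.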

Returns

None

Restrictions

  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
  • For the input srcTensor, the last axis length n must be greater than or equal to 512 and a multiple of 64. The product m of the non-last axis lengths must be a multiple of 8.
  • The following tensor pairs can share the same space: srcTensor and dstTensor, meanTensor and inMeanTensor, maxTensor and inMaxTensor, and expSumTensor and inExpSumTensor.
  • For meanTensor, expSumTensor, maxTensor, expMaxTensor, inMeanTensor, inExpSumTensor, and inMaxTensor, the length of the last axis must be fixed at 32 bytes.
  • The address of sharedTmpBuffer must not overlap that of the source or destination operand.
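
The shape and parameter rules stated on this page can be checked before a kernel is launched. The following is a minimal Python sketch of such a check; `check_softmax_flash_v3_params` is a hypothetical helper, its argument names follow the SoftMaxParams fields, and it covers only the constraints listed here (address alignment and buffer overlap are not checked).

```python
def check_softmax_flash_v3_params(src_m, src_k, ori_src_m, ori_src_k,
                                  loop_cnt, split_mean_cnt):
    """Return a list of violated constraints (empty if all checks pass)."""
    problems = []
    if src_k < 512 or src_k % 64 != 0:
        problems.append("last axis length must be >= 512 and a multiple of 64")
    if src_m % 8 != 0:
        problems.append("product of non-last axis lengths must be a multiple of 8")
    if src_m != ori_src_m or src_k != ori_src_k:
        problems.append("non-aligned shapes are not supported")
    if loop_cnt < 1:
        problems.append("loopCnt must be greater than or equal to 1")
    if split_mean_cnt != 8:
        problems.append("splitMeanCnt can only be 8")
    return problems


# The [8, 1024] shape used in the example on this page passes all checks.
ok = check_softmax_flash_v3_params(8, 1024, 8, 1024, 2, 8)
```

Returning a list of messages rather than raising on the first failure makes it easier to report every problem with a given configuration at once.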

Example

In this example, the input srcTensor and the output dstTensor have shape [8, 1024] and data type half. The inputs inMeanTensor, inExpSumTensor, and inMaxTensor have shape [8, 8] and data type float. The output expMaxTensor has shape [8, 16] and data type half. All input and output data is in ND format. srcTensor and dstTensor do not share space, and isUpdate is true.
#include "kernel_operator.h"

template <typename T, typename U>
class KernelSoftmaxFlashV3 {
public:
    __aicore__ inline KernelSoftmaxFlashV3() {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
       __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ U *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ U *)inSumGm);
        meanGlobal.SetGlobalBuffer((__gm__ U *)inMeanGm);
        // Buffer for the [height, width] source tile.
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        // The max, sum, and mean tensors hold one 32-byte data block per row
        // (8 elements when U is float).
        elementNumPerBlk1 = 32 / sizeof(U);
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(meanQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        // expMaxTensor has type T, so one 32-byte block holds 16 elements when T is half.
        elementNumPerBlk2 = 32 / sizeof(T);
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk2 * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.AllocTensor<U>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmeanLocal, meanGlobal, height * elementNumPerBlk1);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
        meanQueue.EnQue(inmeanLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.DeQue<U>();
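        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        // Field order: srcM, srcK, oriSrcM, oriSrcK, loopCnt, splitMeanCnt, alpha.
        AscendC::SoftMaxParams params = {height, width, height, width, loopCnt, splitMeanCnt, alpha};
        // isUpdate is true. The mean, sum, and max tensors are updated in place:
        // each input tensor is also passed as the corresponding output tensor.
        AscendC::SoftmaxFlashV3<T, U, true>(dstLocal, inmeanLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, inmeanLocal, insumLocal, inmaxLocal, tiling, params);

        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        meanQueue.FreeTensor(inmeanLocal);
        inQueueSrc.FreeTensor(srcLocal);
        // expMaxTensor is not copied out in this example, so release it here.
        expMaxQueue.FreeTensor(expMaxTensor);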
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> meanQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<U> meanGlobal, maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk1 = 0;
    uint32_t elementNumPerBlk2 = 0;
    uint32_t width = 1024;
    uint32_t height = 8;
    uint32_t loopCnt = 2;
    uint32_t splitMeanCnt = 8;
    float alpha = 0.9375;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flashv3_kernel(__gm__ uint8_t *srcGm,
    __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV3<half, float> op;
    op.Init(srcGm, inMaxGm, inSumGm, inMeanGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}