SoftmaxFlash

Applicability

Product                                                   Supported
Atlas A3 training products/Atlas A3 inference products    √
Atlas A2 training products/Atlas A2 inference products    √
Atlas 200I/500 A2 inference products                      x
Atlas inference product's AI Core                         √
Atlas inference product's Vector Core                     x
Atlas training products                                   x

Function

Note: This API will be deprecated in the future. Use SoftmaxFlashV2 instead, which offers better precision and performance.

SoftmaxFlash is an enhanced version of SoftMax. In addition to computing softmax on the input tensor, it can update the result of the current softmax computation using the sum and max values obtained from the previous softmax computation. This matters when the last axis is tiled. Each reduce then covers only one tile rather than the entire axis, so this API is used to combine the current result with the sum and max values from the previous computation.

Only input shapes in ND format are supported; the NZ format is not supported. The internal reduce is performed along the last axis. When isUpdate is false, this API is equivalent to SoftMax.
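The update path (the branch of the Python script that follows where update is true) can be restated for a single row. Let $x$ be the current tile, $m_{\text{old}}$ be inmax, and $l_{\text{old}}$ be insum. The symbols $m$ and $l$ are notation introduced here for the new x_max and x_sum.

```latex
\begin{aligned}
m &= \max\bigl(m_{\text{old}},\ \max_j x_j\bigr) \\
l &= e^{\,m_{\text{old}} - m}\, l_{\text{old}} + \sum_j e^{\,x_j - m} \\
\mathrm{dst}_j &= \frac{e^{\,x_j - m}}{l} \\
\mathrm{exp\_max} &= \frac{e^{\,m_{\text{old}} - m}\, l_{\text{old}}}{l}
\end{aligned}
```

Consider an earlier result $e^{y_k - m_{\text{old}}}/l_{\text{old}}$. Multiplying it by exp_max gives $e^{y_k - m}/l$, which rescales earlier tiles onto the same max and sum as the current one. Without the update, the $m_{\text{old}}$ and $l_{\text{old}}$ terms drop out and the expressions reduce to ordinary softmax.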

For ease of understanding, the computation is expressed below as a Python script. src, inmax, insum, and update are inputs. dst, x_max, x_sum, and exp_max are outputs, returned in that order.

import numpy as np

def softmax_flash(src, inmax=None, insum=None, update=False):
    if not update:
        # Perform rowmax (maximum of each row) along the last axis.
        x_max = np.max(src, axis=-1, keepdims=True)
        x_sub = src - x_max
        x_exp = np.exp(x_sub)
        # Perform rowsum (sum of each row) along the last axis.
        x_sum = np.sum(x_exp, axis=-1, keepdims=True)
        dst = x_exp / x_sum
        exp_max = None
        return dst, x_max, x_sum, exp_max
    else:
        # Combine inmax from the previous computation with src to obtain the new rowmax.
        x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True)
        x_exp = np.exp(src - x_max)
        x_sum = np.sum(x_exp, axis=-1, keepdims=True)
        # Rescale the previous sum to the new max, then add the current partial sum.
        exp_max = np.exp(inmax - x_max)
        x_sum = exp_max * insum + x_sum
        # exp_max becomes the factor that rescales the previous result to the new max and sum.
        exp_max = exp_max * insum / x_sum
        dst = x_exp / x_sum
        return dst, x_max, x_sum, exp_max
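The update path is what makes tiling along the last axis work. The sketch below is a pure-Python, single-row transcription of the script above that uses no NumPy. It also includes a hypothetical driver, tiled_softmax, that feeds a row in tiles and rescales earlier output with exp_max. Both function names and the tiling loop are illustrative assumptions, not part of the Ascend C API.

```python
import math


def softmax_flash_row(src, inmax=None, insum=None, update=False):
    """Single-row version of the reference script: returns (dst, x_max, x_sum, exp_max)."""
    if not update:
        x_max = max(src)
        x_exp = [math.exp(v - x_max) for v in src]
        x_sum = sum(x_exp)
        return [e / x_sum for e in x_exp], x_max, x_sum, None
    # Combine the previous max with the current tile.
    x_max = max(inmax, max(src))
    x_exp = [math.exp(v - x_max) for v in src]
    scale = math.exp(inmax - x_max)
    # Previous sum rescaled to the new max, plus the current partial sum.
    x_sum = scale * insum + sum(x_exp)
    exp_max = scale * insum / x_sum
    return [e / x_sum for e in x_exp], x_max, x_sum, exp_max


def tiled_softmax(row, tile):
    """Hypothetical driver: softmax of one row, processed `tile` elements at a time."""
    out, x_max, x_sum = [], None, None
    for start in range(0, len(row), tile):
        chunk = row[start:start + tile]
        update = start > 0
        dst, x_max, x_sum, exp_max = softmax_flash_row(chunk, x_max, x_sum, update)
        if update:
            # Bring the output of earlier tiles onto the new max and sum.
            out = [v * exp_max for v in out]
        out.extend(dst)
    return out


row = [0.5, 2.0, -1.0, 3.0, 1.5, 0.0]
full = softmax_flash_row(row)[0]
tiled = tiled_softmax(row, tile=2)
print(max(abs(a - b) for a, b in zip(full, tiled)))
```

The tiled result matches the full-row softmax up to floating-point rounding. Earlier tiles are only rescaled, never recomputed.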

Prototype

  • Allocate the temporary space through the API framework:
    template <typename T, bool isReuseSource = false, bool isBasicBlock = false>
    __aicore__ inline void SoftmaxFlash(const LocalTensor<T>& dstTensor, const LocalTensor<T>& sumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling, bool isUpdate = false, const SoftMaxShapeInfo& softmaxShapeInfo = {})

    template <typename T, bool isReuseSource = false, bool isBasicBlock = false>
    __aicore__ inline void SoftmaxFlash(const LocalTensor<half>& dstTensor, const LocalTensor<float>& sumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inSumTensor, const LocalTensor<float>& inMaxTensor, const SoftMaxTiling& tiling, bool isUpdate = false, const SoftMaxShapeInfo& softmaxShapeInfo = {})

  • Pass the temporary space through the sharedTmpBuffer input parameter:
    template <typename T, bool isReuseSource = false, bool isBasicBlock = false>
    __aicore__ inline void SoftmaxFlash(const LocalTensor<T>& dstTensor, const LocalTensor<T>& sumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inSumTensor, const LocalTensor<T>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, bool isUpdate = false, const SoftMaxShapeInfo& softmaxShapeInfo = {})

    template <typename T, bool isReuseSource = false, bool isBasicBlock = false>
    __aicore__ inline void SoftmaxFlash(const LocalTensor<half>& dstTensor, const LocalTensor<float>& sumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inSumTensor, const LocalTensor<float>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, bool isUpdate = false, const SoftMaxShapeInfo& softmaxShapeInfo = {})

In each form, the first overload uses T for every tensor. The second overload takes half for dstTensor, srcTensor, and expMaxTensor, and float for sumTensor, maxTensor, inSumTensor, and inMaxTensor.

Because the internal implementation of this API involves complex computation, extra temporary space is required to store intermediate variables. The temporary space can be allocated by the API framework or passed in by the developer through the sharedTmpBuffer input parameter.

  • When the API framework allocates the temporary space, developers do not need to allocate it themselves, but must reserve enough space for it.
  • When the temporary space is passed through sharedTmpBuffer, that tensor serves as the temporary space and the API framework allocates none. Developers manage the sharedTmpBuffer space themselves and can reuse the buffer after the API call. This avoids repeated allocation and deallocation and improves flexibility and buffer utilization.

If the API framework is used, developers must reserve the temporary space. If sharedTmpBuffer is used, developers must allocate space for the tensor. In both cases, obtain the required size (BufferSize) with the GetSoftMaxFlashMaxTmpSize and GetSoftMaxFlashMinTmpSize APIs provided in SoftmaxFlash Tiling, which return the maximum and minimum temporary space sizes. The minimum size ensures correct functionality; the maximum size is used to improve performance.
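One possible policy for sizing sharedTmpBuffer is sketched below in Python for illustration: take as much space as is free, never less than the minimum and never more than the maximum. The sketch makes three assumptions. First, min_size and max_size hold the values returned by GetSoftMaxFlashMinTmpSize and GetSoftMaxFlashMaxTmpSize. Second, any size between them is accepted. Third, free_bytes (a hypothetical input) is the space left after the operator's other buffers. The byte values in the usage lines are arbitrary.

```python
def choose_tmp_buffer_size(min_size, max_size, free_bytes):
    """Hypothetical helper: pick a temporary-space size in bytes between min and max."""
    if min_size > max_size:
        raise ValueError("min_size must not exceed max_size")
    if free_bytes < min_size:
        # Below the minimum, correct results are not guaranteed.
        raise ValueError("free space is below the minimum temporary size")
    # Space up to the maximum may improve performance; more than the maximum is not requested.
    return min(max_size, free_bytes)


print(choose_tmp_buffer_size(1024, 4096, 8192))  # capped at the maximum: 4096
print(choose_tmp_buffer_size(1024, 4096, 2048))  # limited by free space: 2048
```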

Parameters

Table 1 Template parameters

Parameter

Description

T

Data type of the operand.

For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.

For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.

For the Atlas inference product's AI Core, the supported data types are half and float.

isReuseSource

This parameter is reserved. Pass the default value false.

isBasicBlock

If the shapes and tiling strategy of srcTensor and dstTensor meet the basic block requirements, enable this parameter to improve performance. It is disabled by default. Use either of the following methods to determine whether the basic block requirements are met:

  • The shape [m, n] of srcTensor and dstTensor meets both of the following requirements:
    • The last axis length n is greater than or equal to 256/sizeof(T) and less than 2048, and n is a multiple of 64. That is, the minimum value of n is 128 for half and 64 for float.
    • m, the product of the non-last axis lengths, is a multiple of 8.
  • Call IsBasicBlockInSoftMax to check whether the tiling strategy meets the basic block tiling requirements.
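The shape conditions in the first bullet can be checked on the host with a small helper like the Python sketch below. meets_basic_block_shape is a hypothetical name. It covers only the shape rules listed above, not the tiling check performed by IsBasicBlockInSoftMax.

```python
def meets_basic_block_shape(m, n, dtype_size):
    """Hypothetical shape check for isBasicBlock; dtype_size is sizeof(T): 2 for half, 4 for float."""
    min_n = 256 // dtype_size  # 128 for half, 64 for float
    return min_n <= n < 2048 and n % 64 == 0 and m % 8 == 0


print(meets_basic_block_shape(80, 144, 2))  # False: 144 is not a multiple of 64
print(meets_basic_block_shape(80, 128, 2))  # True
```

The [80, 144] half shape used in the Example section fails this check. That is consistent with the example leaving isBasicBlock at its default of false.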
Table 2 API parameters

Parameter

Input/Output

Description

dstTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of dstTensor is the same as that of the source operand srcTensor.

sumTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Stores the reducesum result of the softmax computation.

  • The last axis of sumTensor is fixed at 32 bytes, that is, one data block. All elements in the block hold the same value. For example, with half, all 16 elements in the block hold the same reducesum value (8 elements with float).
  • The length of each non-last axis is the same as that of dstTensor.

maxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Stores the reducemax result of the softmax computation.

  • The last axis of maxTensor is fixed at 32 bytes, that is, one data block. All elements in the block hold the same value. For example, with half, all 16 elements in the block hold the same reducemax value (8 elements with float).
  • The length of each non-last axis is the same as that of dstTensor.

srcTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The length of the last axis must be 32-byte aligned.

expMaxTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Stores exp_max as computed in the Python script in Function. It is meaningful only when isUpdate is true.

  • The last axis of expMaxTensor is fixed at 32 bytes, that is, one data block. All elements in the block hold the same value. For example, with half, all 16 elements in the block are identical.
  • The length of each non-last axis is the same as that of dstTensor.

inSumTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Sum values obtained from the previous softmax computation, used when isUpdate is true.

  • The last axis of inSumTensor is fixed at 32 bytes, that is, one data block. All elements in the block hold the same value. For example, with half, all 16 elements in the block are identical (8 elements with float).
  • The length of each non-last axis is the same as that of dstTensor.

inMaxTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Max values obtained from the previous softmax computation, used when isUpdate is true.

  • The last axis of inMaxTensor is fixed at 32 bytes, that is, one data block. All elements in the block hold the same value. For example, with half, all 16 elements in the block are identical (8 elements with float).
  • The length of each non-last axis is the same as that of dstTensor.

sharedTmpBuffer

Input

Temporary space.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Stores intermediate variables produced during the API's internal computation. The developer allocates this tensor.

For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlash Tiling.

tiling

Input

Tiling information required for API computation. For details about how to obtain the tiling information, see SoftmaxFlash Tiling.

isUpdate

Input

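Whether to enable the update algorithm. If true, the current result is updated using inSumTensor and inMaxTensor from the previous computation. If false, the API behaves like SoftMax and the values in inSumTensor and inMaxTensor are not used.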

softmaxShapeInfo

Input

Shape of srcTensor, which is of the SoftMaxShapeInfo type. The specific definition is as follows:

struct SoftMaxShapeInfo {
    uint32_t srcM;    // Product of the lengths of the non-last axes
    uint32_t srcK;    // Length of the last axis, which must be 32-byte aligned
    uint32_t oriSrcM; // Product of the lengths of the original non-last axes
    uint32_t oriSrcK; // Length of the original last axis
};
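When the original last axis is not 32-byte aligned, srcK and oriSrcK differ. The Python sketch below shows one way to derive the four fields. It assumes that only the last axis is padded, up to the next multiple of 32 bytes, and that the non-last axes are unpadded, so srcM equals oriSrcM. make_shape_info is a hypothetical helper, not an Ascend C function.

```python
def make_shape_info(ori_m, ori_k, dtype_size):
    """Hypothetical helper: SoftMaxShapeInfo fields for an [ori_m, ori_k] input (dtype_size = sizeof(T))."""
    elems_per_block = 32 // dtype_size  # 16 for half, 8 for float
    # Round the last axis up to a whole number of 32-byte blocks.
    src_k = (ori_k + elems_per_block - 1) // elems_per_block * elems_per_block
    return {"srcM": ori_m, "srcK": src_k, "oriSrcM": ori_m, "oriSrcK": ori_k}


print(make_shape_info(80, 144, 2))  # already aligned: srcK == oriSrcK == 144
print(make_shape_info(80, 140, 2))  # padded: srcK == 144, oriSrcK == 140
```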

Returns

None

Restrictions

  • srcTensor and dstTensor can share the same space, as can maxTensor and inMaxTensor, and sumTensor and inSumTensor.
  • The last axis of sumTensor, maxTensor, expMaxTensor, inSumTensor, and inMaxTensor must be exactly 32 bytes long.
  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
  • The address of sharedTmpBuffer must not overlap that of the source or destination operand.
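Because the last axis is fixed at 32 bytes, a sum, max, or expMax tensor for m rows holds m * 32 / sizeof(T) elements, with each row's value repeated across its block. The Python sketch below models that layout on the host. It could be used, for example, to prepare inSumTensor and inMaxTensor data or to read back sumTensor and maxTensor. All three helper names are hypothetical.

```python
def row_stat_elems(m, dtype_size):
    """Element count of a per-row statistics tensor: one 32-byte block per row."""
    return m * (32 // dtype_size)


def expand_row_stats(values, dtype_size):
    """Repeat each row's value across its 32-byte block."""
    per_block = 32 // dtype_size
    return [v for v in values for _ in range(per_block)]


def collapse_row_stats(flat, dtype_size):
    """Read back one value per row; every element of a block holds the same value."""
    per_block = 32 // dtype_size
    return flat[::per_block]


print(row_stat_elems(80, 2))  # 1280 elements, i.e. the [80, 16] half tensors in the example
```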

Example

In this example, the input srcTensor and the output dstTensor have shape [80, 144]. inSumTensor, inMaxTensor, and expMaxTensor have shape [80, 16], and so do the sumTensor and maxTensor outputs, which reuse the inSumTensor and inMaxTensor space. The data type is half, and isUpdate is false. For more operator examples, see the softmaxflash operator sample.
#include "kernel_operator.h"

template <typename T>
class KernelSoftmaxFlash {
public:
    __aicore__ inline KernelSoftmaxFlash()
    {}
    __aicore__ inline void Init(
        GM_ADDR srcGm, GM_ADDR inMaxGm, GM_ADDR inSumGm, GM_ADDR dstGm, const SoftMaxTiling &tilingData)
    {
        elementNumPerBlk = 32 / sizeof(T);
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        maxGlobal.SetGlobalBuffer((__gm__ T *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ T *)inSumGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        pipe.InitBuffer(inMaxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(inSumQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<T> inSumLocal = inSumQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> inMaxLocal = inMaxQueue.AllocTensor<T>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(inSumLocal, sumGlobal, height * elementNumPerBlk);
        AscendC::DataCopy(inMaxLocal, maxGlobal, height * elementNumPerBlk);
        inQueueSrc.EnQue(srcLocal);
        inSumQueue.EnQue(inSumLocal);
        inMaxQueue.EnQue(inMaxLocal);
    }
    __aicore__ inline void Compute()
    {
        // Dequeue the tensors filled in CopyIn(); AllocTensor here would return new, unfilled buffers.
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<T> inMaxLocal = inMaxQueue.DeQue<T>();
        AscendC::LocalTensor<T> inSumLocal = inSumQueue.DeQue<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::SoftMaxShapeInfo srcShape = {height, width, height, width};
        // isUpdate = false: plain softmax written straight into dstLocal.
        // The sumTensor/maxTensor outputs reuse the inSum/inMax buffers, as the Restrictions allow.
        AscendC::SoftmaxFlash<T, false>(dstLocal,
            inSumLocal,
            inMaxLocal,
            srcLocal,
            expMaxTensor,
            inSumLocal,
            inMaxLocal,
            tiling,
            false,
            srcShape);

        outQueueDst.EnQue<T>(dstLocal);
        inMaxQueue.FreeTensor(inMaxLocal);
        inSumQueue.FreeTensor(inSumLocal);
        inQueueSrc.FreeTensor(srcLocal);
        expMaxQueue.FreeTensor(expMaxTensor);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueDst;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inMaxQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inSumQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> expMaxQueue;

    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<T> maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk = 0;
    uint32_t width = 144;
    uint32_t height = 80;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flash_kernel_half(GM_ADDR srcGm, GM_ADDR inMaxGm, GM_ADDR inSumGm, GM_ADDR dstGm, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlash<half> op;
    op.Init(srcGm, inMaxGm, inSumGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}
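To check that the example's buffers fit in local memory, the sizes passed to pipe.InitBuffer in Init can be totaled. The Python sketch below reproduces that arithmetic for height = 80, width = 144, and half (2 bytes). It counts only the five queues. It does not include the temporary space that must be reserved for the API framework, whose size comes from SoftmaxFlash Tiling. example_buffer_bytes is a hypothetical helper.

```python
def example_buffer_bytes(height, width, dtype_size):
    """Bytes requested by the InitBuffer calls in KernelSoftmaxFlash::Init."""
    elem_per_blk = 32 // dtype_size
    full = height * width * dtype_size          # inQueueSrc, outQueueDst
    stats = height * elem_per_blk * dtype_size  # inMaxQueue, inSumQueue, expMaxQueue
    return {"src": full, "dst": full, "inMax": stats, "inSum": stats,
            "expMax": stats, "total": 2 * full + 3 * stats}


print(example_buffer_bytes(80, 144, 2)["total"])  # 53760
```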