SoftmaxFlashV2

Applicability

Product | Supported
Atlas A3 training products / Atlas A3 inference products | √
Atlas A2 training products / Atlas A2 inference products | √
Atlas 200I/500 A2 inference products | √
Atlas inference product's AI Core | √
Atlas inference product's Vector Core | x
Atlas training products | x

Function

Serves as the enhanced version of SoftmaxFlash, corresponding to the FlashAttention-2 algorithm. If the product of non-last axis lengths of the input tensor [m0, m1, ..., mt, n] (t ≥ 0) is considered as m, the shape of the input tensor is [m, n]. This API performs the following computation on the input tensor [m, n] by row. Different values of update correspond to different formulas, where x, inmax, and insum are inputs, and M, S, and E are outputs.

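  • If update is false, the formulas are as follows, where y is the exponentiated result written to the destination (dst in the Python script in this section) and E is not computed:

    $$M = \max_{j}(x_{j}), \quad y = \exp(x - M), \quad S = \sum_{j} \exp(x_{j} - M)$$

  • If update is true, the formulas are as follows:

    $$M = \max\bigl(inmax, \max_{j}(x_{j})\bigr), \quad E = \exp(inmax - M), \quad y = \exp(x - M), \quad S = E \cdot insum + \sum_{j} \exp(x_{j} - M)$$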

When the input shape is in ND format, the internal reduce process is performed along the last axis. When the input shape is in NZ format, the internal reduce process is performed along the last and first axes. For details about the reduce process, see the figures in SoftMax.

For ease of understanding, the formula expressed through a Python script is as follows, where src, inmax, insum, and update are inputs, and dst, x_sum, x_max, and exp_max are outputs.

import numpy as np

def softmax_flash_2(src, inmax=None, insum=None, update=False):
    if not update:
        # No history to merge: reduce over the last axis of src only.
        x_max = np.max(src, axis=-1, keepdims=True)
        x_sub = src - x_max
        dst = np.exp(x_sub)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, exp_max
    else:
        # Merge the running max (inmax) and running sum (insum) of earlier blocks.
        x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True)
        dst = np.exp(src - x_max)
        exp_max = np.exp(inmax - x_max)  # rescale factor for the old sum
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        x_sum = exp_max * insum + x_sum
        return dst, x_max, x_sum, exp_max
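
The update branch is what lets a long row be processed in chunks: the second call rescales the earlier partial sum by exp_max so that the final sum is relative to the overall maximum. The following is a minimal pure-Python sketch for a single row, not the API implementation. The helper names flash_step and full_softmax are hypothetical, and the final rescale-and-normalize step is done by the caller, as the script above returns unnormalized values.

```python
import math

def flash_step(src, inmax=None, insum=None):
    """One step on a single row (list of floats).

    Without inmax/insum this mirrors the non-update branch of the script;
    with them it mirrors the update branch.
    """
    if inmax is None:
        x_max = max(src)
        dst = [math.exp(v - x_max) for v in src]
        return dst, x_max, sum(dst), None
    x_max = max(inmax, max(src))
    dst = [math.exp(v - x_max) for v in src]
    exp_max = math.exp(inmax - x_max)
    x_sum = exp_max * insum + sum(dst)
    return dst, x_max, x_sum, exp_max

def full_softmax(row):
    """Ordinary numerically stable softmax over the whole row."""
    x_max = max(row)
    e = [math.exp(v - x_max) for v in row]
    total = sum(e)
    return [v / total for v in e]

row = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
first, second = row[:3], row[3:]

dst1, max1, sum1, _ = flash_step(first)
dst2, max2, sum2, exp_max = flash_step(second, inmax=max1, insum=sum1)

# Rescale the first chunk to the new maximum, then normalize by the final sum.
chunked = [v * exp_max / sum2 for v in dst1] + [v / sum2 for v in dst2]
reference = full_softmax(row)
```

Here chunked and reference agree to floating-point precision, which is the property the update formulas are designed to preserve.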

Principles

The following figure shows the internal algorithm diagram of the SoftmaxFlashV2 high-level APIs by taking the input tensor of the float type, in ND format, and with shape [m, k] as an example.

Figure 1 Diagram of the SoftmaxFlashV2 algorithm

The computation process is divided into two branches based on whether isUpdate is enabled, which are both performed on vectors.

  • When isUpdate is set to False, the process is as follows:
    1. reducemax: Compute the maximum value of each row of input x to obtain [m, 1]. The computation result is saved to the temporary space temp.
    2. broadcast: Pad the data [m, 1] in temp by data block. For example, for the float type, extend [m, 1] to [m, 8] and output max.
    3. sub: Subtract max from all data of input x by row.
    4. exp: Compute exp for all data after sub and output y.
    5. reducesum: Sum up each row of data after exp is performed to obtain [m, 1]. The computation result is saved to temp.
    6. broadcast: Pad [m, 1] in temp by data block. For example, for the float type, extend [m, 1] to [m, 8] and output sum.
  • When isUpdate is set to True, the process is as follows:
    1. reducemax: Compute the maximum value of each row of input x to obtain [m, 1]. The computation result is saved to the temporary space temp.
    2. broadcast: Pad the data [m, 1] in temp by data block. For example, for the float type, extend [m, 1] to [m, 8] and save the result as max.
    3. max: Perform the max operation on the input inmax and the max obtained in the previous step to obtain a new max and output it.
    4. sub: Subtract the new max from the input inmax, perform exp to obtain expmax, and output it.
    5. subs: Subtract the new max from the input x by row.
    6. exp: Compute exp for all data after sub and output y.
    7. reducesum: Sum up each row of data after exp is performed to obtain [m, 1]. The computation result is saved to temp.
    8. broadcast: Pad the data [m, 1] in temp by data block. For example, for the float type, extend [m, 1] to [m, 8] and save the result to sum.
    9. mul: Multiply insum and expmax.
    10. add: Add the multiplication result and sum, save the result to sum, and output the result.
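
The isUpdate = True steps above can be traced for a single row with the following sketch. It is an illustration in plain Python under stated assumptions, not the vector implementation: the function name update_branch_steps and its dtype_size argument are hypothetical, and the data-block broadcast is modeled simply as repeating the per-row scalar 32 / sizeof(T) times.

```python
import math

def update_branch_steps(x_row, inmax, insum, dtype_size=4):
    """Trace the isUpdate=True branch for one row; comments give step numbers."""
    block = 32 // dtype_size                     # elements in one 32-byte data block
    row_max = max(x_row)                         # 1. reducemax
    new_max = max(inmax, row_max)                # 3. max with the input inmax
    expmax = math.exp(inmax - new_max)           # 4. sub + exp
    y = [math.exp(v - new_max) for v in x_row]   # 5-6. subs + exp
    row_sum = sum(y)                             # 7. reducesum
    new_sum = insum * expmax + row_sum           # 9-10. mul + add
    # 2. and 8. broadcast: each per-row scalar fills one data block on output.
    return y, [new_max] * block, [new_sum] * block, [expmax] * block

y, out_max, out_sum, out_expmax = update_branch_steps([1.0, 3.0], inmax=2.0, insum=1.5)
```

For the float case (dtype_size=4) each of the three reduced outputs holds 8 identical values per row; with dtype_size=2 (half) it holds 16.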

Prototype

  • Allocate the temporary space through the API framework.
    • The data types of LocalTensor are the same, and ReduceMax is not output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      
    • The data types of LocalTensor are the same, and ReduceMax is output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax, const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor, const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      

      Atlas 200I/500 A2 inference products: This API is not supported.

      Atlas inference product's AI Core: This API is not supported.

    • The data types of LocalTensor are not the same, and ReduceMax is not output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      
  • Pass the temporary space through the sharedTmpBuffer input parameter.
    • The data types of LocalTensor are the same, and ReduceMax is not output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor, const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      
    • The data types of LocalTensor are the same, and ReduceMax is output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      

      Atlas 200I/500 A2 inference products: This API is not supported.

      Atlas inference product's AI Core: This API is not supported.

    • The data types of LocalTensor are not the same, and ReduceMax is not output.
      template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
      __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
      

Due to the complex computation involved in the internal implementation of this API, extra temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated through the API framework or passed by developers through the sharedTmpBuffer input parameter.

  • When the API framework is used for temporary space allocation, developers do not need to allocate the space, but must reserve the required size for the space.
  • When the sharedTmpBuffer input parameter is used for passing the temporary space, the tensor serves as the temporary space. In this case, the API framework is not required for temporary space allocation. This enables developers to manage the sharedTmpBuffer space and reuse the buffer after calling the API, so that the buffer is not repeatedly allocated and deallocated, improving the flexibility and buffer utilization.

In both cases, obtain the required temporary space size (BufferSize) with the GetSoftMaxFlashV2MinTmpSize and GetSoftMaxFlashV2MaxTmpSize APIs provided in SoftmaxFlashV2 Tiling. The minimum size guarantees correct functionality, while a size up to the maximum improves performance.

In addition, an API for tiling computation in the kernel is provided. When the input shape in the kernel differs from the shape passed through TilingData on the host, this API can be used to recompute the tiling in the kernel. For details about the parameters of this API, see SoftmaxFlashV2 Tiling.

  • Tiling computation API in the kernel
    __aicore__ inline constexpr SoftMaxTiling SoftMaxFlashV2TilingFunc(const SoftMaxShapeInfo& shapeInfo, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize, const bool isUpdate = false, const bool isBasicBlock = false, const bool isDataFormatNZ = false, const bool isFlashOutputBrc = false)
    

Parameters

Table 1 Template parameters

Parameter

Description

T

Data type of the operand.

For the Atlas A3 training products / Atlas A3 inference products, the supported data types are half and float.

For the Atlas A2 training products / Atlas A2 inference products, the supported data types are half and float.

For the Atlas 200I/500 A2 inference products, the supported data types are half and float.

For the Atlas inference product's AI Core, the supported data types are half and float.

isUpdate

Whether to enable computation in the update part.

isReuseSource

This parameter is reserved. Pass the default value false.

isBasicBlock

If the shape information and tiling strategy of both srcTensor and dstTensor meet the base block requirements, this parameter can be enabled to improve performance. By default, this parameter is disabled. Use either of the following methods to determine whether the base block requirements are met:

  • The shape information [m, n] of srcTensor and dstTensor must meet the following requirements:
    • The last axis length n is less than 2048 and greater than or equal to 256/sizeof(T). That is, the minimum value of n is 128 when the data type is half and 64 when the data type is float. In addition, n is a multiple of 64.
    • The product m of non-last axis lengths is a multiple of 8.
  • You can call IsBasicBlockInSoftMax to check whether the tiling strategy meets the tiling requirements of base blocks.
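
The shape requirements above can be checked with a few lines of arithmetic. The following sketch is illustrative only: meets_basic_block_shape is a hypothetical helper, it covers only the shape conditions, and it does not replace the tiling check performed by IsBasicBlockInSoftMax.

```python
def meets_basic_block_shape(m, n, dtype_size):
    """Return True if shape [m, n] satisfies the base block shape conditions.

    dtype_size is sizeof(T) in bytes: 2 for half, 4 for float.
    """
    min_n = 256 // dtype_size          # 128 for half, 64 for float
    last_axis_ok = min_n <= n < 2048 and n % 64 == 0
    rows_ok = m % 8 == 0
    return last_axis_ok and rows_ok

half_320x64 = meets_basic_block_shape(320, 64, 2)    # n is below the half minimum of 128
float_320x64 = meets_basic_block_shape(320, 64, 4)
half_320x128 = meets_basic_block_shape(320, 128, 2)
```

With these inputs, a half tensor of shape [320, 64] does not qualify because n is below 128, while the same shape in float and a half tensor of shape [320, 128] both qualify.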

For the Atlas 200I/500 A2 inference products, this parameter is reserved for future use. Retain the default value.

isDataFormatNZ

Whether the current input and output data is in NZ format. The default data format is ND, that is, the default value of this parameter is false.

For the Atlas 200I/500 A2 inference products, the NZ format is not supported.

config

(Optional) structure template parameter, which is of the SoftmaxConfig type. The definition is as follows:

struct SoftmaxConfig {
    bool isCheckTiling = true; // Whether to check the consistency between the shape and tiling. If they are inconsistent, the API re-computes the required tiling based on the shape. The default value is true, indicating that the API checks the consistency internally.
    uint32_t oriSrcM = 0; // Product of the original non-last axis lengths. After this parameter is set, the shape is turned into a constant value, and the constant shape is used at compile time.
    uint32_t oriSrcK = 0; // Original last axis length. After this parameter is set, the shape is turned into a constant value, and the constant shape is used at compile time.
    SoftmaxMode mode = SoftmaxMode::SOFTMAX_NORMAL; // Processing mode of the output shape
};

The mode parameter, which is of the SoftmaxMode type, indicates the processing mode of the output shape. It is not supported when the input and output data is in NZ format. The values are as follows:

  • SOFTMAX_NORMAL (default): normal mode. In this mode, the output data is broadcast, so that the output shape is extended from (m, 1) to (m, 8) (float data type) or (m, 16) (half data type).
  • SOFTMAX_OUTPUT_WITHOUT_BRC: non-extended mode. In this mode, the output data is not broadcast. The output shape is always (m, 1), and the shape of the corresponding input parameters (such as inExpSumTensor and inMaxTensor) is also (m, 1).
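
As a quick check of the two modes, the shape of each reduced output can be derived from the row count and the element size. This is a sketch under the assumption that the extended last axis is exactly one 32-byte data block, as the description above states; reduce_output_shape and its without_brc flag are hypothetical names.

```python
def reduce_output_shape(m, dtype_size, without_brc=False):
    """Shape of the max/sum/expmax tensors for m rows.

    dtype_size is sizeof(T) in bytes; without_brc=True models
    SOFTMAX_OUTPUT_WITHOUT_BRC, False models SOFTMAX_NORMAL.
    """
    if without_brc:
        return (m, 1)
    return (m, 32 // dtype_size)   # one 32-byte data block per row

normal_float = reduce_output_shape(320, 4)
normal_half = reduce_output_shape(320, 2)
no_brc_half = reduce_output_shape(320, 2, without_brc=True)
```

For 320 rows this gives (320, 8) for float and (320, 16) for half in normal mode, and (320, 1) in non-extended mode.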

A configuration example is as follows:

constexpr SoftmaxConfig SOFTMAX_DEFAULT_CFG = {true, 0, 0, SoftmaxMode::SOFTMAX_NORMAL};

This parameter is used together with the tiling computation API in the kernel.

Note: After oriSrcM and oriSrcK are set, isBasicBlock does not take effect. In this case, whether the computation data is a base block is determined and processed by the API.

For the Atlas A3 training products / Atlas A3 inference products, this parameter is supported.

For the Atlas A2 training products / Atlas A2 inference products, this parameter is supported.

For the Atlas 200I/500 A2 inference products, this parameter is reserved for future use. Retain the default value.

For the Atlas inference product's AI Core, this parameter is supported, but mode cannot be configured.

Table 2 API parameters

Parameter

Input/Output

Description

dstTensor

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of dstTensor is the same as that of the source operand srcTensor.

outReduceMax

Output

Destination operand. It is used to store the result of the first reducemax computation in the softmax computation.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of outReduceMax is the same as that of the destination operand maxTensor.

For the API that outputs the result:

  • If the template parameter isUpdate is set to false, the result is not output.
  • Only the ND format is supported for the input and output. The template parameter isDataFormatNZ is reserved. You can pass the default value false.
  • The template parameter config.isCheckTiling is reserved. Retain the default value (true in SOFTMAX_DEFAULT_CFG).
  • The template parameter config.mode can only be set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC.

expSumTensor, outExpSum

Output

Destination operand. It is used to store the reducesum result during softmax computation.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The last axis length of expSumTensor is fixed at 32 bytes, which is the length of a data block, except when the template parameter config is set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC. All data in this data block has the same value. For example, for the half data type, all 16 numbers in the data block hold the same reducesum value.
  • The length of each non-last axis is the same as that of dstTensor.

maxTensor, outMax

Output

Destination operand. It is used to store the reducemax result during softmax computation.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The last axis length of maxTensor is fixed at 32 bytes, which is the length of a data block, except when the template parameter config is set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC. All data in this data block has the same value. For example, for the half data type, all 16 numbers in the data block hold the same reducemax value.
  • The length of each non-last axis is the same as that of dstTensor.

srcTensor

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The length of the last axis must be 32-byte aligned.

expMaxTensor, outExpMax

Output

Destination operand. It is used to store the exponentiation (base e) of the difference between inmax and reducemax.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The last axis length of expMaxTensor is fixed at 32 bytes, which is the length of a data block, except when the template parameter config is set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC. All data in this data block has the same value. For example, for the half data type, all 16 numbers in the data block hold the same value.
  • The length of each non-last axis is the same as that of dstTensor.

inExpSumTensor, inExpSum

Input

Source operand. This parameter indicates the sum value required for softmax computation.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The last axis length of inExpSumTensor is fixed at 32 bytes, which is the length of a data block, except when the template parameter config is set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC. All data in this data block has the same value. For example, for the half data type, all 16 numbers in the data block hold the same value.
  • The length of each non-last axis is the same as that of dstTensor.

inMaxTensor, inMax

Input

Source operand. This parameter indicates the max value required for softmax computation.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • The last axis length of inMaxTensor is fixed at 32 bytes, which is the length of a data block, except when the template parameter config is set to the non-extended mode SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC. All data in this data block has the same value. For example, for the half data type, all 16 numbers in the data block hold the same value.
  • The length of each non-last axis is the same as that of dstTensor.

sharedTmpBuffer

Input

Temporary space.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The data type of this operand is fixed at uint8_t.

This parameter is used to store intermediate variables during complex internal API computation and is provided by developers.

For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlashV2 Tiling.

tiling

Input

Tiling information required for SoftmaxFlashV2 computation. For details about how to obtain the tiling information, see SoftmaxFlashV2 Tiling.

softmaxShapeInfo

Input

Shape of srcTensor, which is of the SoftMaxShapeInfo type. The specific definition is as follows:

struct SoftMaxShapeInfo {
    uint32_t srcM;    // Product of lengths of non-last axes.
    uint32_t srcK;    // Length of the last axis, which must be 32-byte aligned.
    uint32_t oriSrcM; // Product of lengths of original non-last axes.
    uint32_t oriSrcK; // Length of the original last axis.
};

Note that when the input and output data is in NZ format, the last axis length is the length of the reduce axis, that is, W0 × W1 in Figure 2 and the length of each non-last axis is H0 × H1.

Returns

None

Restrictions

  • The tensor space of srcTensor and dstTensor, maxTensor and inMaxTensor, and expSumTensor and inExpSumTensor can be reused.
  • Except when the template parameter config is set to a non-extended mode (SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC), the length of the last axis must be fixed at 32 bytes for the tensor space of expSumTensor, maxTensor, expMaxTensor, inExpSumTensor, and inMaxTensor.
  • For the API that outputs ReduceMax:
    • The template parameters isReuseSource, isDataFormatNZ, and config.isCheckTiling are reserved.
    • config.mode can only be set to the non-extended mode SOFTMAX_OUTPUT_WITHOUT_BRC. If it is set to the SOFTMAX_NORMAL mode, the API function is not executed, and the outputs are not saved.
    • When the template parameter isUpdate is set to false, outReduceMax is not output.
    • Except outReduceMax, the computation result of each output is the same as that of the API that does not output ReduceMax.
  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
  • The address of sharedTmpBuffer must not overlap that of the source or destination operand.
  • When srcM != oriSrcM or srcK != oriSrcK in softmaxShapeInfo, the original input (oriSrcM, oriSrcK) on the GM must be padded to (srcM, srcK) in the M or K direction. The padded data takes part in some of the computation. If the input and output are reused, the computation result of the API overwrites the padded data in srcTensor. If the input and output are not reused, the computation result of the API overwrites the data in dstTensor at the positions corresponding to the padding in srcTensor.
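
For the padding case in the last restriction, the padded last axis length follows from the 32-byte alignment rule on srcK. The sketch below assumes the smallest aligned length is chosen; the helper name padded_last_axis is hypothetical, and the value written into the padded elements is not covered here.

```python
def padded_last_axis(ori_src_k, dtype_size):
    """Smallest srcK (in elements) that is >= ori_src_k and 32-byte aligned.

    dtype_size is sizeof(T) in bytes: 2 for half, 4 for float.
    """
    block = 32 // dtype_size                    # elements per 32-byte data block
    return -(-ori_src_k // block) * block       # round up to a whole number of blocks

src_k_half = padded_last_axis(60, 2)    # original last axis of 60 half elements
src_k_float = padded_last_axis(10, 4)   # original last axis of 10 float elements
```

An original last axis of 60 half elements is padded to 64, and 10 float elements are padded to 16; a length that is already aligned is returned unchanged.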

Example

In this example, the shape of the input srcTensor and the output dstTensor is [320, 64], and the shape of the input inExpSumTensor, the input inMaxTensor, and the output expMaxTensor is [320, 16]. The data type is half, and the input and output data is in ND format. srcTensor and dstTensor do not share space, the base block is disabled, and isUpdate is set to true.
#include "kernel_operator.h"

// For a constant shape, define a static config such as:
// constexpr AscendC::SoftmaxConfig static_config = {true, 320, 64};
template <typename T>
class KernelSoftmaxFlashV2 {
public:
    __aicore__ inline KernelSoftmaxFlashV2()
    {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
        __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        elementNumPerBlk = 32 / sizeof(T);
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ T *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ T *)inSumGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<T> insumLocal = sumQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> inmaxLocal = maxQueue.AllocTensor<T>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<T> insumLocal = sumQueue.DeQue<T>();
        AscendC::LocalTensor<T> inmaxLocal = maxQueue.DeQue<T>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        AscendC::SoftMaxShapeInfo srcShape = {height, width, height, width};
        AscendC::SoftmaxFlashV2<T, true>(
            dstLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, insumLocal, inmaxLocal, tiling, srcShape);
        // To turn the shape into a compile-time constant, pass a SoftmaxConfig (static_config) as a template parameter:
        // AscendC::SoftmaxFlashV2<T, true, false, false, false, static_config>(
        //     dstLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, insumLocal, inmaxLocal, tiling, srcShape);
        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        inQueueSrc.FreeTensor(srcLocal);
        expMaxQueue.FreeTensor(expMaxTensor); // expMax is not copied out in this example, so release it here
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<T> maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk = 0;
    uint32_t width = 64;
    uint32_t height = 320;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flashv2_generic_kernel_half(__gm__ uint8_t *srcGm,
    __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV2<half> op;
    op.Init(srcGm, inMaxGm, inSumGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}