SoftmaxFlashV2
功能说明
SoftmaxFlash增强版本,对应FlashAttention-2算法。将输入tensor[m0, m1, ...mt, n](t大于等于0)的非尾轴长度相乘的结果看作m,则输入tensor的shape看作[m, n]。对输入tensor[m,n]按行做如下计算,不同的update值对应不同的计算公式,其中M\S\E均为输出,inmax和insum为输入。
当输入shape为ND格式时,内部的reduce过程按last轴进行;当输入shape为NZ格式时,内部的reduce过程按照last轴和first轴进行,reduce过程可以参考SoftMax中的图示说明。
为方便理解,通过python脚本实现的方式,表达其计算公式如下,其中src、inmax、 insum、update为输入,dst、x_sum、x_max、exp_max为输出。
def softmax_flash_2(src, inmax=None, insum=None, update=None):
    if update == None:
        x_max = np.max(src, axis=-1, keepdims=True)
        x_sub = src - x_max   
        dst = np.exp(x_sub) 
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, exp_max
    else:
        x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True)
        x_exp = np.exp(src - x_max)
        exp_max = np.exp(inmax - x_max)
        x_sum = np.sum(x_exp, axis=-1, keepdims=True)
        x_sum = exp_max * insum +  x_sum
        return x_exp, x_max, x_sum, exp_max
   实现原理
以float类型,ND格式,shape为[m, k]的输入Tensor为例,描述SoftmaxFlashV2高阶API内部算法框图,如下图所示。
 
    计算过程根据isUpdate是否使能分为两个分支处理,均在Vector上进行。
- 当isUpdate为False时,分为如下几步:
- reducemax步骤:对输入x的每一行数据求最大值得到[m, 1],计算结果会保存到一个临时空间temp中;
- broadcast步骤:对temp中的数据([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出max;
- sub步骤:对输入x的所有数据按行减去max;
- exp步骤:对sub之后的所有数据求exp,并且输出y;
- reducesum步骤:对exp结果的每一行数据求和得到[m, 1],计算结果会保存到临时空间temp中;
- broadcast步骤:对temp([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出sum。
- 当isUpdate为True时,分为如下几步:
- reducemax步骤:对输入x的每一行数据求最大值得到[m, 1],计算结果会保存到一个临时空间temp中;
- broadcast步骤:对temp中的数据([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],保存为max;
- max步骤:对输入inmax和上一步计算的max做max操作,得到新的max并输出;
- sub步骤:将输入inmax和新的max相减,然后做exp,计算得到expmax并输出;
- sub步骤:将输入x和新的max按行相减;
- exp步骤:对sub之后的所有数据求exp,并且输出y;
- reducesum步骤:对exp结果的每一行数据求和得到[m, 1],计算结果会保存到临时空间temp中;
- broadcast步骤:对temp数据([m, 1])做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],保存到sum中;
- mul步骤:将输入insum和expmax结果相乘;
- add步骤:将相乘结果和sum相加,保存到sum并输出。
函数原型
- 接口框架申请临时空间
      1 2 template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {}) 1 2 template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {}) 
- 通过sharedTmpBuffer入参传入临时空间
      1 2 template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {}) 1 2 template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {}) 
由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
- 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。
- 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过SoftmaxFlashV2 Tiling接口中提供的GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize接口获取所需最小和最大临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。
另外提供了一个kernel侧tiling计算的接口,当kernel侧的输入shape与通过host侧TilingData传入的shape不一致时,可使用该接口在kernel侧重新计算tiling。
- kernel侧tiling计算接口
      1__aicore__ inline constexpr SoftMaxTiling SoftMaxFlashV2TilingFunc(const SoftMaxShapeInfo& shapeInfo, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize, const bool isUpdate = false, const bool isBasicBlock = false, const bool isDataFormatNZ = false) 
参数说明
| 参数名 | 描述 | 
|---|---|
| T | 操作数的数据类型。 | 
| isUpdate | 是否使能update部分中的计算。 | 
| isReuseSource | 预留参数,暂未启用,为后续的功能扩展做保留,必须使用默认值。 | 
| isBasicBlock | srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下,可以使能该参数用于提升性能,默认不使能。基本块要求如下: 
 针对Atlas 200/500 A2推理产品,该参数为预留参数,暂未启用,为后续的功能扩展做保留,保持默认值即可。 | 
| isDataFormatNZ | 当前输入输出的数据格式是否为NZ格式,默认数据格式为ND,即默认取值为false。 针对Atlas 200/500 A2推理产品,不支持配置为NZ格式。 | 
| SoftmaxConfig | 结构体模板参数,此参数可选配,具体定义如下: struct SoftmaxConfig{
bool isCheckTiling = true; // 是否需要检查shape和tiling的一致性;若不一致,API内会根据shape重新计算所需tiling。默认取值true:API内部会检查一致性
uint32_t oriSrcM = 0; // 预留参数,保持默认值即可
uint32_t oriSrcK = 0; // 预留参数,保持默认值即可
};配置示例如下: constexpr SoftmaxConfig SOFTMAX_DEFAULT_CFG = {true, 0, 0};此参数一般用于配合kernel侧tiling计算的接口使用。 | 
| 参数名 | 输入/输出 | 描述 | 
|---|---|---|
| dstTensor | 输出 | 目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float dstTensor的shape和源操作数srcTensor一致。 | 
| expSumTensor | 输出 | 目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float 用于保存softmax计算过程中reducesum的结果。 
 | 
| maxTensor | 输出 | 目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float 用于保存softmax计算过程中reducemax的结果。 
 | 
| srcTensor | 输入 | 源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float last轴长度需要32Byte对齐。 | 
| expMaxTensor | 输出 | 目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float 
 | 
| inExpSumTensor | 输入 | 源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float softmax计算所需要的sum值。 
 | 
| inMaxTensor | 输入 | 源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float Atlas 200/500 A2推理产品,支持的数据类型为:half/float softmax计算所需要的max值。 
 | 
| tiling | 输入 | softmaxflashv2接口计算所需tiling信息,Tiling信息的获取请参考SoftmaxFlashV2 Tiling接口。 | 
| softmaxShapeInfo | 输入 | srcTensor的shape信息。SoftMaxShapeInfo类型,具体定义如下: struct SoftMaxShapeInfo {
uint32_t srcM; // 非尾轴长度的乘积
uint32_t srcK; // 尾轴长度,必须32Byte对齐
uint32_t oriSrcM; // 原始非尾轴长度的乘积
uint32_t oriSrcK;  // 原始尾轴长度
};需要注意,当输入输出的数据格式为NZ格式时,尾轴长度为reduce轴长度即图2中的W0*W1,非尾轴为H0*H1。 | 
返回值
无
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
Atlas推理系列产品AI Core
Atlas 200/500 A2推理产品
注意事项
- srcTensor和dstTensor的Tensor的空间可以复用,maxTensor和inMaxTensor的空间可以复用,sumTensor和inSumTensor的空间可以复用。
- sumTensor、maxTensor、expMaxTensor、inSumTensor、inMaxTensor的Tensor空间,last轴长度必须固定32Byte。
- 操作数地址偏移对齐要求请参见通用约束。
调用示例
#include "kernel_operator.h"
template <typename T>
class KernelSoftmaxFlashV2 {
public:
    __aicore__ inline KernelSoftmaxFlashV2()
    {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
        __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        elementNumPerBlk = 32 / sizeof(T);
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ T *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ T *)inSumGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }
private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<T> insumLocal = sumQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> inmaxLocal = maxQueue.AllocTensor<T>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<T> insumLocal = sumQueue.DeQue<T>();
        AscendC::LocalTensor<T> inmaxLocal = maxQueue.DeQue<T>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        AscendC::SoftMaxShapeInfo srcShape = {height, width, height, width};
        AscendC::SoftmaxFlashV2<T, true>(
            dstLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, insumLocal, inmaxLocal, tiling, srcShape);
        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        inQueueSrc.FreeTensor(srcLocal);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }
private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<T> maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk = 0;
    uint32_t width = 64;
    uint32_t height = 320;
    SoftMaxTiling tiling;
};
extern "C" __global__ __aicore__ void softmax_flashv2_generic_kernel_half(__gm__ uint8_t *srcGm,
    __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV2<half> op;
    op.Init(srcGm, inMaxGm, inSumGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}
    
