SoftmaxFlashV2

功能说明

SoftmaxFlash增强版本,对应FlashAttention-2算法。当输入shape为ND格式时,内部的reduce过程按last轴进行;当输入shape为NZ格式时,内部的reduce过程按照last轴和first轴进行,reduce过程可以参考SoftMax中的图示说明。

为方便理解,通过python脚本实现的方式,表达其计算公式如下,其中src、inmax、 insum、update为输入,dst、x_sum、x_max、exp_max为输出。

def softmax_flash_2(src, inmax=None, insum=None, update=None):
if update == None:
    x_max = np.max(src, axis=-1, keepdims=True)
    x_sub = src - x_max   
    dst = np.exp(x_sub) 
    x_sum = np.sum(dst, axis=-1, keepdims=True)
    exp_max = None
    return dst, x_max, x_sum, exp_max
else:
    x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True)
    x_exp = np.exp(src - x_max)
    exp_max = np.exp(inmax - x_max)
    x_sum = np.sum(x_exp, axis=-1, keepdims=True)
    x_sum = exp_max * insum +  x_sum
    return x_exp, x_max, x_sum, exp_max

函数原型

由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。

接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过SoftmaxFlashV2 Tiling接口中提供的GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize接口获取所需最小和最大临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。

参数说明

表1 接口参数说明

参数名

输入/输出

描述

dstTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

dstTensor的shape和源操作数srcTensor一致。

sumTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

用于保存softmax计算过程中reducesum的结果。

  • sumTensor的last轴长度固定为32Byte,即一个block长度。该block中的所有数据为同一个值,比如half数据类型下,该block中的16个数均为相同的reducesum的值。
  • 非last轴的长度与dstTensor保持一致。

maxTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

用于保存softmax计算过程中reducemax的结果。

  • maxTensor的last轴长度固定为32Byte,即一个block长度。该block中的所有数据为同一个值。比如half数据类型下,该block中的16个数均为相同的reducemax的值。
  • 非last轴的长度与dstTensor保持一致。

srcTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

last轴长度需要32Byte对齐。

expMaxTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

  • expMaxTensor的last轴长度固定为32Byte,即一个block长度。该block中的所有数据为同一个值,比如half数据类型下,该block中的16个数均为相同的值。
  • 非last轴的长度需要与dstTensor保持一致。

inSumTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

softmax计算所需要的sum值。

  • inSumTensor的last轴长度固定为32Byte,即一个block长度。该block中的所有数据为同一个值,比如half数据类型下,该block中的16个数均为相同的值。
  • 非last轴的长度需要与dstTensor保持一致。

inMaxTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half/float

Atlas推理系列产品AI Core,支持的数据类型为:half/float

softmax计算所需要的max值。

  • inMaxTensor的last轴长度固定为32Byte,即一个block长度。该block中的所有数据为同一个值,比如half数据类型下,该block里的16个数均为相同的值。
  • 非last轴的长度需要与dstTensor保持一致。

tiling

输入

softmaxflashv2接口计算所需tiling信息,Tiling信息的获取请参考SoftmaxFlashV2 Tiling接口

softmaxShapeInfo

输入

srcTensor的shape信息。SoftMaxShapeInfo类型,具体定义如下:

struct SoftMaxShapeInfo {
uint32_t srcM; // 非尾轴长度的乘积
uint32_t srcK; // 尾轴长度,必须32Byte对齐
uint32_t oriSrcM; // 原始非尾轴长度的乘积
uint32_t oriSrcK;  // 原始尾轴长度
};

需要注意,当输入输出的数据格式为NZ格式时,尾轴长度为reduce轴长度即图2中的W0*W1,非尾轴为H0*H1。

isUpdate

输入

是否使能update部分中的计算。

isReuseSource

输入

dstTensor是否复用srcTensor的空间。

isBasicBlock

输入

srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下,可以使能该参数用于提升性能,默认不使能。基本块要求如下:

  • srcTensor和dstTensor的shape信息需要满足如下条件:last轴长度小于2048并且是64的倍数,非last轴长度的乘积为8的倍数。
  • 通过调用IsBasicBlockInSoftMax判断Tiling切分策略是否满足基本块的切分要求。

isDataFormatNZ

输入

当前输入输出的数据格式是否为NZ格式,默认数据格式为ND,即默认取值为false。

返回值

支持的型号

Atlas A2训练系列产品

Atlas推理系列产品AI Core

注意事项

调用示例

本样例中输入srcTensor和输出dstTensor的shape大小为[320,64],输入inSumTensor、inMaxTensor的shape大小为[320,16],输出expMaxTensor的shape大小为[320,16],数据类型均为half,输入输出的数据排布格式为ND,srcTensor和dstTensor空间不复用,不使能基本块,isUpdate为true。
#include "kernel_operator.h

namespace AscendC {
template <typename T> class KernelSoftmaxFlashV2 {
public:
    __aicore__ inline KernelSoftmaxFlashV2() {}
    __aicore__ inline void Init(__gm__ uint8_t* srcGm, __gm__ uint8_t* inMaxGm, __gm__ uint8_t* inSumGm,
        __gm__ uint8_t* dstGm, const SoftMaxTiling& tilingData)
    {
        elementNumPerBlk = 32 / sizeof(T);
        srcGlobal.SetGlobalBuffer((__gm__ T*)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ T*)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ T*)inSumGm);

        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        LocalTensor<T> insumLocal = sumQueue.AllocTensor<T>();
        LocalTensor<T> inmaxLocal = maxQueue.AllocTensor<T>();
        DataCopy(srcLocal, srcGlobal, height * width);
        DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk);
        DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
    }
    __aicore__ inline void Compute()
    {
        LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        LocalTensor<T> insumLocal = sumQueue.DeQue<T>();
        LocalTensor<T> inmaxLocal = maxQueue.DeQue<T>();
        LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();

        SoftMaxShapeInfo srcShape = { height, width, height, width };
        SoftmaxFlashV2<T, true>(dstLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, insumLocal, inmaxLocal, tiling, srcShape);

        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        inQueueSrc.FreeTensor(srcLocal);
    }
    __aicore__ inline void CopyOut()
    {
        LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, 1> inQueueSrc;
    TQue<QuePosition::VECIN, 1> maxQueue;
    TQue<QuePosition::VECIN, 1> sumQueue;
    TQue<QuePosition::VECIN, 1> expMaxQueue;
    TQue<QuePosition::VECOUT, 1> outQueueDst;
    GlobalTensor<T> srcGlobal, dstGlobal;
    GlobalTensor<T> maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk = 0;
    uint32_t width = 64;
    uint32_t height = 320;
    SoftMaxTiling tiling;
};
} // namespace AscendC

extern "C" __global__ __aicore__ void softmax_flashv2_generic_kernel_half(__gm__ uint8_t* srcGm,
    __gm__ uint8_t* inMaxGm, __gm__ uint8_t* inSumGm, __gm__ uint8_t* dstGm, __gm__ uint8_t* tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    AscendC::KernelSoftmaxFlashV2<half> op;
    op.Init(srcGm, inMaxGm, inSumGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}