SoftmaxFlash

功能说明

softMax增强版本,除了可以对输入tensor做softmaxflash计算,同时还可以根据上一次softmax计算的expSum和max来更新本次的softmax计算结果,适用于last轴切轴的情况,因对last轴求reduce是需要全轴计算,切last轴的话每次计算的reduce结果并非是全轴的。当前仅支持传入shape为ND格式,内部的reduce过程都是按last轴进行。

函数原型

template <typename T, bool isReuseSource = false, bool isBasicBlock = false>>

void SoftmaxFlash(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,

const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,

const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, SoftMaxTiling& tiling,

bool isUpdate = false)

参数说明

表1 接口参数说明

参数名

输入/输出

描述

dstTensor

输出

目的操作数,类型为LocalTensor,last轴长度需要32B对齐

expSumTensor

输出

目的操作数,类型为LocalTensor,softmaxflash计算后,刷新的新的expSum值,last轴长度固定32B

maxTensor

输出

目的操作数,类型为LocalTensor,softmaxflash计算后,刷新的新的max值,last轴长度固定32B

srcTensor

输入

源操作数,类型为LocalTensor,last轴长度需要32B对齐

expMaxTensor

输出

目的操作数,类型为LocalTensor,softmaxflash计算后,输出的expMax值,last轴长度固定32B

inExpSumTensor

输入

源操作数,类型为LocalTensor,softmaxflash计算需要的expSum值,last轴长度固定32B

inMaxTensor

输入

源操作数,类型为LocalTensor,softmaxflash计算需要的max值,last轴长度固定32B

tiling

输入

softmax计算所需tiling信息

isUpdate

输入

是否使能flash计算,若不使能,只计算softmax结果,功能类似softmax接口

isReuseSource

输入

是否复用src的空间

isBasicBlock

输入

srcTensor是否是基本块,可以使能该参数用于提升性能,基本块的定义为:满足last轴长度为128的倍数的输入tensor都可以称为基本块。

返回值

支持的型号

Atlas A2训练系列产品

注意事项

调用示例

本样例输入src的Shape大小为[80,144],输出Shape大小dst=[80,144],输入inExpSumTensor=[80,16],输入inMaxTensor=[80,16],输出expMaxTensor=[80,16],数据类型均为half。第一次flash接口调用是为了计算第二次flash接口所需的inExpSumTensor和inMaxTensor。
#include "kernel_operator.h"
using namespace AscendC;

namespace AscendC {

template <typename T> class KernelSoftmaxFlash {
public:
    __aicore__ inline KernelSoftmaxFlash() {}
    __aicore__ inline void Init(__gm__ uint8_t* src1Gm,__gm__ uint8_t* src2Gm, __gm__ uint8_t* dstGm)
    {
        elementNumPerBlk = 32 / sizeof(T);
        src1Global.SetGlobalBuffer((__gm__ T*)src1Gm);
        src2Global.SetGlobalBuffer((__gm__ T*)src2Gm);
        dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm);
        pipe.InitBuffer(inQueueSrc1, 1, height*width * sizeof(T));
        pipe.InitBuffer(inQueueSrc2, 1, height*width * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height*width * sizeof(T));
        pipe.InitBuffer(inMaxQueue, 1, height*elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(inSumQueue, 1, height*elementNumPerBlk * sizeof(T));
        pipe.InitBuffer(expMaxQueue, 1, height*elementNumPerBlk * sizeof(T));

    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:

    __aicore__ inline void CopyIn()
    {
        LocalTensor<T> srcLocal1 = inQueueSrc1.AllocTensor<T>();
        LocalTensor<T> srcLocal2 = inQueueSrc2.AllocTensor<T>();
        DataCopy(srcLocal1, src1Global, height*width);
        DataCopy(srcLocal2, src2Global, height*width);
        inQueueSrc1.EnQue(srcLocal1);
        inQueueSrc2.EnQue(srcLocal2);
    }
    __aicore__ inline void Compute()
    {
        LocalTensor<T> srcLocal1 = inQueueSrc1.DeQue<T>();
        LocalTensor<T> srcLocal2 = inQueueSrc2.DeQue<T>();
        LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();

        LocalTensor<T> inmaxLocal = inMaxQueue.AllocTensor<T>();
        LocalTensor<T> insumLocal = inSumQueue.AllocTensor<T>();
        LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();

        const uint32_t shapeDim = 2;
        uint32_t array[2] = {height, width};
        srcLocal1.SetShapeInfo(ShapeInfo(shapeDim, array));
        srcLocal2.SetShapeInfo(ShapeInfo(shapeDim, array));
        dstLocal.SetShapeInfo(ShapeInfo(shapeDim, array));

        array[0] = height;
        array[1] = elementNumPerBlk;
        insumLocal.SetShapeInfo(ShapeInfo(shapeDim, array));
        inmaxLocal.SetShapeInfo(ShapeInfo(shapeDim, array));
        expMaxTensor.SetShapeInfo(ShapeInfo(shapeDim, array));

        SoftMaxTiling tiling1; // 本示例tiling为演示用 实际内容需要通过Tiling Api获取
        SoftmaxFlash<T, false>(srcLocal1, insumLocal, inmaxLocal, srcLocal1, expMaxTensor, insumLocal, inmaxLocal, tiling1, false);
        
        SoftMaxTiling tiling2; // 本示例tiling为演示用 实际内容需要通过Tiling Api获取
        SoftmaxFlash<T, false>(srcLocal2, insumLocal, inmaxLocal, srcLocal2, expMaxTensor, insumLocal, inmaxLocal, tiling2, true);

        DataCopy(dstLocal, srcLocal2, height*width);

        outQueueDst.EnQue<T>(dstLocal);
        inMaxQueue.FreeTensor(inmaxLocal);
        inSumQueue.FreeTensor(insumLocal);
        inQueueSrc1.FreeTensor(srcLocal1);
        inQueueSrc2.FreeTensor(srcLocal2);
        expMaxQueue.FreeTensor(expMaxTensor);
    }
    __aicore__ inline void CopyOut()
    {
        LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        DataCopy(dstGlobal, dstLocal, height*width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, 1> inQueueSrc1;
    TQue<QuePosition::VECIN, 1> inQueueSrc2;
    TQue<QuePosition::VECOUT, 1> outQueueDst;
    TQue<QuePosition::VECIN, 1> inMaxQueue;
    TQue<QuePosition::VECIN, 1> inSumQueue;
    TQue<QuePosition::VECIN, 1> expMaxQueue;

    GlobalTensor<T> src1Global,src2Global, dstGlobal;
    uint32_t elementNumPerBlk = 0;
    uint32_t width = 144;
    uint32_t height = 80;
};

}  // namespace AscendC


extern "C" __global__ __aicore__ void softmax_flash_kernel_half(__gm__ uint8_t *src1Gm, __gm__ uint8_t *src2Gm, __gm__ uint8_t *dstGm)
{
    AscendC::KernelSoftmaxFlash<half> op;
    op.Init(src1Gm,src2Gm, dstGm);
    op.Process();
}