softMax增强版本,除了可以对输入tensor做softmaxflash计算,同时还可以根据上一次softmax计算的expSum和max来更新本次的softmax计算结果,适用于last轴切轴的情况,因对last轴求reduce是需要全轴计算,切last轴的话每次计算的reduce结果并非是全轴的。当前仅支持传入shape为ND格式,内部的reduce过程都是按last轴进行。
template <typename T, bool isReuseSource = false, bool isBasicBlock = false>>
void SoftmaxFlash(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,
const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,
const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, SoftMaxTiling& tiling,
bool isUpdate = false)
参数名 |
输入/输出 |
描述 |
---|---|---|
dstTensor |
输出 |
目的操作数,类型为LocalTensor,last轴长度需要32B对齐 |
expSumTensor |
输出 |
目的操作数,类型为LocalTensor,softmaxflash计算后,刷新的新的expSum值,last轴长度固定32B |
maxTensor |
输出 |
目的操作数,类型为LocalTensor,softmaxflash计算后,刷新的新的max值,last轴长度固定32B |
srcTensor |
输入 |
源操作数,类型为LocalTensor,last轴长度需要32B对齐 |
expMaxTensor |
输出 |
目的操作数,类型为LocalTensor,softmaxflash计算后,输出的expMax值,last轴长度固定32B |
inExpSumTensor |
输入 |
源操作数,类型为LocalTensor,softmaxflash计算需要的expSum值,last轴长度固定32B |
inMaxTensor |
输入 |
源操作数,类型为LocalTensor,softmaxflash计算需要的max值,last轴长度固定32B |
tiling |
输入 |
softmax计算所需tiling信息 |
isUpdate |
输入 |
是否使能flash计算,若不使能,只计算softmax结果,功能类似softmax接口 |
isReuseSource |
输入 |
是否复用src的空间 |
isBasicBlock |
输入 |
srcTensor是否是基本块,可以使能该参数用于提升性能,基本块的定义为:满足last轴长度为128的倍数的输入tensor都可以称为基本块。 |
无
Atlas A2训练系列产品
#include "kernel_operator.h" using namespace AscendC; namespace AscendC { template <typename T> class KernelSoftmaxFlash { public: __aicore__ inline KernelSoftmaxFlash() {} __aicore__ inline void Init(__gm__ uint8_t* src1Gm,__gm__ uint8_t* src2Gm, __gm__ uint8_t* dstGm) { elementNumPerBlk = 32 / sizeof(T); src1Global.SetGlobalBuffer((__gm__ T*)src1Gm); src2Global.SetGlobalBuffer((__gm__ T*)src2Gm); dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm); pipe.InitBuffer(inQueueSrc1, 1, height*width * sizeof(T)); pipe.InitBuffer(inQueueSrc2, 1, height*width * sizeof(T)); pipe.InitBuffer(outQueueDst, 1, height*width * sizeof(T)); pipe.InitBuffer(inMaxQueue, 1, height*elementNumPerBlk * sizeof(T)); pipe.InitBuffer(inSumQueue, 1, height*elementNumPerBlk * sizeof(T)); pipe.InitBuffer(expMaxQueue, 1, height*elementNumPerBlk * sizeof(T)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { LocalTensor<T> srcLocal1 = inQueueSrc1.AllocTensor<T>(); LocalTensor<T> srcLocal2 = inQueueSrc2.AllocTensor<T>(); DataCopy(srcLocal1, src1Global, height*width); DataCopy(srcLocal2, src2Global, height*width); inQueueSrc1.EnQue(srcLocal1); inQueueSrc2.EnQue(srcLocal2); } __aicore__ inline void Compute() { LocalTensor<T> srcLocal1 = inQueueSrc1.DeQue<T>(); LocalTensor<T> srcLocal2 = inQueueSrc2.DeQue<T>(); LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>(); LocalTensor<T> inmaxLocal = inMaxQueue.AllocTensor<T>(); LocalTensor<T> insumLocal = inSumQueue.AllocTensor<T>(); LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>(); const uint32_t shapeDim = 2; uint32_t array[2] = {height, width}; srcLocal1.SetShapeInfo(ShapeInfo(shapeDim, array)); srcLocal2.SetShapeInfo(ShapeInfo(shapeDim, array)); dstLocal.SetShapeInfo(ShapeInfo(shapeDim, array)); array[0] = height; array[1] = elementNumPerBlk; insumLocal.SetShapeInfo(ShapeInfo(shapeDim, array)); inmaxLocal.SetShapeInfo(ShapeInfo(shapeDim, array)); expMaxTensor.SetShapeInfo(ShapeInfo(shapeDim, array)); SoftMaxTiling tiling1; // 本示例tiling为演示用 实际内容需要通过Tiling Api获取 SoftmaxFlash<T, false>(srcLocal1, insumLocal, inmaxLocal, srcLocal1, expMaxTensor, insumLocal, inmaxLocal, tiling1, false); SoftMaxTiling tiling2; // 本示例tiling为演示用 实际内容需要通过Tiling Api获取 SoftmaxFlash<T, false>(srcLocal2, insumLocal, inmaxLocal, srcLocal2, expMaxTensor, insumLocal, inmaxLocal, tiling2, true); DataCopy(dstLocal, srcLocal2, height*width); outQueueDst.EnQue<T>(dstLocal); inMaxQueue.FreeTensor(inmaxLocal); inSumQueue.FreeTensor(insumLocal); inQueueSrc1.FreeTensor(srcLocal1); inQueueSrc2.FreeTensor(srcLocal2); expMaxQueue.FreeTensor(expMaxTensor); } __aicore__ inline void CopyOut() { LocalTensor<T> dstLocal = outQueueDst.DeQue<T>(); DataCopy(dstGlobal, dstLocal, height*width); outQueueDst.FreeTensor(dstLocal); } private: TPipe pipe; TQue<QuePosition::VECIN, 1> inQueueSrc1; TQue<QuePosition::VECIN, 1> inQueueSrc2; TQue<QuePosition::VECOUT, 1> outQueueDst; TQue<QuePosition::VECIN, 1> inMaxQueue; TQue<QuePosition::VECIN, 1> inSumQueue; TQue<QuePosition::VECIN, 1> expMaxQueue; GlobalTensor<T> src1Global,src2Global, dstGlobal; uint32_t elementNumPerBlk = 0; uint32_t width = 144; uint32_t height = 80; }; } // namespace AscendC extern "C" __global__ __aicore__ void softmax_flash_kernel_half(__gm__ uint8_t *src1Gm, __gm__ uint8_t *src2Gm, __gm__ uint8_t *dstGm) { AscendC::KernelSoftmaxFlash<half> op; op.Init(src1Gm,src2Gm, dstGm); op.Process(); }