按元素做反量化计算,计算公式如下,其中PAR表示矢量计算单元一个迭代能够处理的元素个数,deq_scale_cnt表示矢量deqScale含有的元素个数 。
__aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const LocalTensor<uint8_t>& sharedTmpBuffer, const uint32_t calCount);
__aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const LocalTensor<uint8_t>& sharedTmpBuffer);
__aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const uint32_t calCount);
__aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale);
由于该接口的内部实现中涉及复杂的数学计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为sharedTmpBuffer申请空间。临时空间大小BufferSize的获取方式如下:通过AscendDequant Tiling中提供的GetAscendDequantMaxMinTmpSize接口获取需要预留空间的范围大小。
参数名 |
输入/输出 |
描述 |
---|---|---|
dstTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half |
srcTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:int32_t |
deqScale |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:uint64_t |
sharedTmpBuffer |
输入 |
临时缓存。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 临时空间大小BufferSize的获取方式请参考AscendQuant Tiling。 Atlas A2训练系列产品,支持的数据类型为:uint8_t |
calCount |
输入 |
实际计算数据元素个数,calCount∈[0, srcTensor.GetSize()], 且必须为deqScale元素个数的整数倍。 |
无
Atlas A2训练系列产品
constexpr uint32_t CYCLES = 4; class KernelAscendDequant { public: __aicore__ inline KernelAscendDequant() {} __aicore__ inline void Init(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t inputSize) { dataSize = inputSize; src_global.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(src_gm), dataSize * CYCLES); deq_scale_global.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t*>(deq_scale_gm), dataSize); dst_global.SetGlobalBuffer(reinterpret_cast<__gm__ half*>(dst_gm), dataSize * CYCLES); pipe.InitBuffer(inQueueX, 1, CYCLES * dataSize * sizeof(int32_t)); pipe.InitBuffer(inQueueDeqscale, 1, dataSize * sizeof(uint64_t)); pipe.InitBuffer(outQueue, 1, CYCLES * dataSize * sizeof(half)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { LocalTensor<int32_t> srcLocal = inQueueX.AllocTensor<int32_t>(); DataCopy(srcLocal, src_global, dataSize * CYCLES); LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.AllocTensor<uint64_t>(); DataCopy(deqscaleLocal, deq_scale_global, dataSize); inQueueX.EnQue(srcLocal); inQueueDeqscale.EnQue(deqscaleLocal); } __aicore__ inline void Compute() { LocalTensor<half> dstLocal = outQueue.AllocTensor<half>(); LocalTensor<int32_t> srcLocal = inQueueX.DeQue<int32_t>(); LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.DeQue<uint64_t>(); AscendDequant(dstLocal, srcLocal, deqscaleLocal); outQueue.EnQue<half>(dstLocal); inQueueX.FreeTensor(srcLocal); inQueueDeqscale.FreeTensor(deqscaleLocal); } __aicore__ inline void CopyOut() { LocalTensor<half> dstLocal = outQueue.DeQue<half>(); DataCopy(dst_global, dstLocal, dataSize * CYCLES); outQueue.FreeTensor(dstLocal); } private: GlobalTensor<int32_t> src_global; GlobalTensor<uint64_t> deq_scale_global; GlobalTensor<half> dst_global; TPipe pipe; TQue<QuePosition::VECIN, 1> inQueueX; TQue<QuePosition::VECIN, 1> inQueueDeqscale; TQue<QuePosition::VECOUT, 1> outQueue; uint32_t dataSize = 0; }; __aicore__ void kernel_ascend_dequant_operator(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t dataSize) { KernelAscendDequant op; op.Init(src_gm, dst_gm, deq_scale_gm, dataSize); op.Process(); }
输入数据(srcLocal): [-8 5 -5 -7 -3 -8 3 6 9 2 -5 0 0 -5 -7 0 -6 0 -2 3 -2 8 5 2 2 2 -4 5 -4 4 -8 3 1 1 -3 -6 0 -8 6 4 -9 -8 -9 9 1 -6 -2 7 0 9 -7 4 8 8 8 3 8 0 -1 -9 7 -6 4 9 7 6 1 -2 -6 -1 5 -1 -7 -6 9 -4 7 -6 8 -1 3 4 -1 8 -8 0 -5 -3 4 -2 2 9 -1 -8 -8 -5 -4 -2 6 -8 6 6 -3 3 9 9 3 7 -3 2 -6 -3 9 5 6 -7 -9 7 4 7 7 5 0 -8 -3 2 1 -4] 反量化参数deqScale: [ 4684667042262804261 13859763255858939167 13913477933643455555 13907151579958429368 13897949676949746767 4582719314807759850 13853283086508392697 13909797127326894857 13866799982091814874 4540299872743366446 4607979737595448655 13909334383254771263 13839508170748780740 13842303019803952926 13730250582932487241 13912302074676945322 13909810673658225786 4655783565419604228 4501315002210063326 13883028479528669462 4671684893717131608 13878358404816523577 13881140985283325552 4527412879213716782 4683436237743530696 4674205262125954062 13905975420345761780 4677333589610067412 13907265529723024752 13905295284355041002 13817041585540784494 13902864429505525097] 反量化参数deqScale(转换为float): [-8.868932 8.205042 -3.9339063 -3.371414 -8.866275 -9.41271 -7.4773827 -8.0079775 -6.7563853 -6.9823723 -3.752192 1.1960124 -3.0391219 -3.0116923 2.274111 -8.5954075 1.9514115 -3.7620308 -4.956931 0.5093182 1.1488436 1.89713 2.1107328 -8.492658 1.1943593 -2.247031 -5.8524313 -2.4021761 -6.3276715 -0.27275205 -8.789469 -9.151617 3.34207 -8.598415 1.8742986 4.895811 6.939925 0.24207361 -7.5020857 -5.3257866 -4.545086 6.661213 5.166165 -4.8073044 7.4275436 -5.116233 -8.993452 0.4152376 -0.44753098 7.965874 -3.1164584 6.94103 -8.359364 -7.873409 5.6862583 7.2883444 -2.7024803 -8.033279 7.172719 -7.797899 -7.6681433 -1.4999424 -9.45054 -7.5280194 ] 输出数据(dstLocal): [ 70.94 -19.67 44.34 52.34 20.27 30.02 -9.12 13.65 17.56 -9.914 -5.746 0. 0. 29.27 44.28 -0. -20.05 0. -13.88 -22.5 9.09 41.34 37.12 -17.98 -0.895 -6.234 33.44 28.44 10.81 28.69 61.34 -28.36 -8.87 -3.934 26.6 44.88 -0. 30.02 -18.23 9.09 -17.56 39.66 -10.336 19. 1.194 35.12 12.66 -61.53 0. 16.88 -48.6 -30.02 -36.38 41.34 59.4 -26.98 -3.58 -0. 8.36 -51.2 -18.92 -43.03 -30.67 -85.06 -62.1 -23.61 -8.87 14.95 40.53 3.752 -15.195 -2.273 -13.66 29.73 10.336 -8.445 8.36 35.12 -50.62 8.79 10.02 7.496 -6.94 -60.03 36.38 0. -37.12 26.98 -1.79 6.234 -16.72 51.2 2.703 -57.38 61.34 47.25 35.47 7.867 -53.2 59.8 -40.53 -22.52 9.12 6.824 17.56 -44.62 3.447 14.77 -3.584 -11.7 37.97 26.38 30.08 9.375 41.62 52.5 40.9 36.16 29.7 -62.97 -3.133 -15.586 -0. -45.5 8.11 14.34 -7.668 37.8 ]