AscendDequant

功能说明

按元素做反量化计算,计算公式如下,其中PAR表示矢量计算单元一个迭代能够处理的元素个数,deq_scale_cnt表示矢量deqScale含有的元素个数 。

定义原型

由于该接口的内部实现中涉及复杂的数学计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。

接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为sharedTmpBuffer申请空间。临时空间大小BufferSize的获取方式如下:通过AscendDequant Tiling中提供的GetAscendDequantMaxMinTmpSize接口获取需要预留空间的范围大小。

参数说明

表1 接口参数说明

参数名

输入/输出

描述

dstTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half

srcTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:int32_t

deqScale

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:uint64_t

sharedTmpBuffer

输入

临时缓存。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

临时空间大小BufferSize的获取方式请参考AscendQuant Tiling

Atlas A2训练系列产品,支持的数据类型为:uint8_t

calCount

输入

实际计算数据元素个数,calCount∈[0, srcTensor.GetSize()], 且必须为deqScale元素个数的整数倍。

返回值

支持的型号

Atlas A2训练系列产品

约束说明

调用示例

constexpr uint32_t CYCLES = 4;
class KernelAscendDequant {
public:
    __aicore__ inline KernelAscendDequant() {}
    __aicore__ inline void Init(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t inputSize)
    {
        dataSize = inputSize;
        src_global.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(src_gm), dataSize * CYCLES);
        deq_scale_global.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t*>(deq_scale_gm), dataSize);
        dst_global.SetGlobalBuffer(reinterpret_cast<__gm__ half*>(dst_gm), dataSize * CYCLES);
        pipe.InitBuffer(inQueueX, 1, CYCLES * dataSize * sizeof(int32_t));
        pipe.InitBuffer(inQueueDeqscale, 1, dataSize * sizeof(uint64_t));
        pipe.InitBuffer(outQueue, 1, CYCLES * dataSize * sizeof(half));
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }
private:
    __aicore__ inline void CopyIn()
    {
        LocalTensor<int32_t> srcLocal = inQueueX.AllocTensor<int32_t>();
        DataCopy(srcLocal, src_global, dataSize * CYCLES);
        LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.AllocTensor<uint64_t>();
        DataCopy(deqscaleLocal, deq_scale_global, dataSize);
        inQueueX.EnQue(srcLocal);
        inQueueDeqscale.EnQue(deqscaleLocal);
    }
    __aicore__ inline void Compute()
    {
        LocalTensor<half> dstLocal = outQueue.AllocTensor<half>();
        LocalTensor<int32_t> srcLocal = inQueueX.DeQue<int32_t>();
        LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.DeQue<uint64_t>();
        AscendDequant(dstLocal, srcLocal, deqscaleLocal);
        outQueue.EnQue<half>(dstLocal);
        inQueueX.FreeTensor(srcLocal);
        inQueueDeqscale.FreeTensor(deqscaleLocal);
    }
    __aicore__ inline void CopyOut()
    {
        LocalTensor<half> dstLocal = outQueue.DeQue<half>();
        DataCopy(dst_global, dstLocal, dataSize * CYCLES);
        outQueue.FreeTensor(dstLocal);
    }
private:
    GlobalTensor<int32_t> src_global;
    GlobalTensor<uint64_t> deq_scale_global;
    GlobalTensor<half> dst_global;
    TPipe pipe;
    TQue<QuePosition::VECIN, 1> inQueueX;
    TQue<QuePosition::VECIN, 1> inQueueDeqscale;
    TQue<QuePosition::VECOUT, 1> outQueue;
    uint32_t dataSize = 0;
};
__aicore__ void kernel_ascend_dequant_operator(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t dataSize)
{
    KernelAscendDequant op;
    op.Init(src_gm, dst_gm, deq_scale_gm, dataSize);
    op.Process();
}
结果示例如下:
输入数据(srcLocal):  [-8  5 -5 -7 -3 -8  3  6  9  2 -5  0  0 -5 -7  0 -6  0 -2  3 -2  8  5  2
  2  2 -4  5 -4  4 -8  3  1  1 -3 -6  0 -8  6  4 -9 -8 -9  9  1 -6 -2  7
  0  9 -7  4  8  8  8  3  8  0 -1 -9  7 -6  4  9  7  6  1 -2 -6 -1  5 -1
 -7 -6  9 -4  7 -6  8 -1  3  4 -1  8 -8  0 -5 -3  4 -2  2  9 -1 -8 -8 -5
 -4 -2  6 -8  6  6 -3  3  9  9  3  7 -3  2 -6 -3  9  5  6 -7 -9  7  4  7
  7  5  0 -8 -3  2  1 -4]
反量化参数deqScale:  [ 4684667042262804261 13859763255858939167 13913477933643455555
 13907151579958429368 13897949676949746767  4582719314807759850
 13853283086508392697 13909797127326894857 13866799982091814874
  4540299872743366446  4607979737595448655 13909334383254771263
 13839508170748780740 13842303019803952926 13730250582932487241
 13912302074676945322 13909810673658225786  4655783565419604228
  4501315002210063326 13883028479528669462  4671684893717131608
 13878358404816523577 13881140985283325552  4527412879213716782
  4683436237743530696  4674205262125954062 13905975420345761780
  4677333589610067412 13907265529723024752 13905295284355041002
 13817041585540784494 13902864429505525097]
反量化参数deqScale(转换为float):  [-8.868932    8.205042   -3.9339063  -3.371414   -8.866275   -9.41271
 -7.4773827  -8.0079775  -6.7563853  -6.9823723  -3.752192    1.1960124
 -3.0391219  -3.0116923   2.274111   -8.5954075   1.9514115  -3.7620308
 -4.956931    0.5093182   1.1488436   1.89713     2.1107328  -8.492658
  1.1943593  -2.247031   -5.8524313  -2.4021761  -6.3276715  -0.27275205
 -8.789469   -9.151617    3.34207    -8.598415    1.8742986   4.895811
  6.939925    0.24207361 -7.5020857  -5.3257866  -4.545086    6.661213
  5.166165   -4.8073044   7.4275436  -5.116233   -8.993452    0.4152376
 -0.44753098  7.965874   -3.1164584   6.94103    -8.359364   -7.873409
  5.6862583   7.2883444  -2.7024803  -8.033279    7.172719   -7.797899
 -7.6681433  -1.4999424  -9.45054    -7.5280194 ]
输出数据(dstLocal):  [ 70.94  -19.67   44.34   52.34   20.27   30.02   -9.12   13.65   17.56
  -9.914  -5.746   0.      0.     29.27   44.28   -0.    -20.05    0.
 -13.88  -22.5     9.09   41.34   37.12  -17.98   -0.895  -6.234  33.44
  28.44   10.81   28.69   61.34  -28.36   -8.87   -3.934  26.6    44.88
  -0.     30.02  -18.23    9.09  -17.56   39.66  -10.336  19.      1.194
  35.12   12.66  -61.53    0.     16.88  -48.6   -30.02  -36.38   41.34
  59.4   -26.98   -3.58   -0.      8.36  -51.2   -18.92  -43.03  -30.67
 -85.06  -62.1   -23.61   -8.87   14.95   40.53    3.752 -15.195  -2.273
 -13.66   29.73   10.336  -8.445   8.36   35.12  -50.62    8.79   10.02
   7.496  -6.94  -60.03   36.38    0.    -37.12   26.98   -1.79    6.234
 -16.72   51.2     2.703 -57.38   61.34   47.25   35.47    7.867 -53.2
  59.8   -40.53  -22.52    9.12    6.824  17.56  -44.62    3.447  14.77
  -3.584 -11.7    37.97   26.38   30.08    9.375  41.62   52.5    40.9
  36.16   29.7   -62.97   -3.133 -15.586  -0.    -45.5     8.11   14.34
  -7.668  37.8  ]