昇腾社区首页
中文
注册

AscendDequant

功能说明

按元素做反量化计算,计算公式如下,其中PAR表示矢量计算单元一个迭代能够处理的元素个数,deq_scale_cnt表示矢量deqScale含有的元素个数 。

  • 输入deqScale的数据类型是uint64_t,数值低32位是参与计算的数据,数据类型是float,数值高32位是一些控制参数,本接口不使用。
  • 要求输入数据srcTensor参与计算的元素个数是deqScale元素个数的整数倍。

定义原型

  • 通过sharedTmpBuffer入参传入临时空间
    • 源操作数Tensor全部/部分参与计算

      __aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const LocalTensor<uint8_t>& sharedTmpBuffer, const uint32_t calCount)

    • 源操作数Tensor全部参与计算

      __aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const LocalTensor<uint8_t>& sharedTmpBuffer)

  • 接口框架申请临时空间
    • 源操作数Tensor全部/部分参与计算

      __aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale, const uint32_t calCount)

    • 源操作数Tensor全部参与计算

      __aicore__ inline void AscendDequant(const LocalTensor<half>& dstTensor, const LocalTensor<int32_t>& srcTensor, const LocalTensor<uint64_t>& deqScale)

由于该接口的内部实现中涉及复杂的数学计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。

  • 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。
  • 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。

接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为sharedTmpBuffer申请空间。临时空间大小BufferSize的获取方式如下:通过AscendDequant Tiling中提供的GetAscendDequantMaxMinTmpSize接口获取需要预留空间的范围大小。

参数说明

表1 接口参数说明

参数名

输入/输出

描述

dstTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:half

srcTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:int32_t

deqScale

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

Atlas A2训练系列产品,支持的数据类型为:uint64_t

sharedTmpBuffer

输入

临时缓存。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

临时空间大小BufferSize的获取方式请参考AscendQuant Tiling

Atlas A2训练系列产品,支持的数据类型为:uint8_t

calCount

输入

实际计算数据元素个数,calCount∈[0, srcTensor.GetSize()], 且必须为deqScale元素个数的整数倍。

返回值

支持的型号

Atlas A2训练系列产品

约束说明

  • 源操作数与目的操作数允许同时使用(即地址重叠)。
  • 操作数地址偏移对齐要求请参见通用约束
  • 输入输出操作数参与计算的数据长度要求32B对齐。

调用示例

constexpr uint32_t CYCLES = 4;
class KernelAscendDequant {
public:
    __aicore__ inline KernelAscendDequant() {}
    __aicore__ inline void Init(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t inputSize)
    {
        dataSize = inputSize;
        src_global.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(src_gm), dataSize * CYCLES);
        deq_scale_global.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t*>(deq_scale_gm), dataSize);
        dst_global.SetGlobalBuffer(reinterpret_cast<__gm__ half*>(dst_gm), dataSize * CYCLES);
        pipe.InitBuffer(inQueueX, 1, CYCLES * dataSize * sizeof(int32_t));
        pipe.InitBuffer(inQueueDeqscale, 1, dataSize * sizeof(uint64_t));
        pipe.InitBuffer(outQueue, 1, CYCLES * dataSize * sizeof(half));
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }
private:
    __aicore__ inline void CopyIn()
    {
        LocalTensor<int32_t> srcLocal = inQueueX.AllocTensor<int32_t>();
        DataCopy(srcLocal, src_global, dataSize * CYCLES);
        LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.AllocTensor<uint64_t>();
        DataCopy(deqscaleLocal, deq_scale_global, dataSize);
        inQueueX.EnQue(srcLocal);
        inQueueDeqscale.EnQue(deqscaleLocal);
    }
    __aicore__ inline void Compute()
    {
        LocalTensor<half> dstLocal = outQueue.AllocTensor<half>();
        LocalTensor<int32_t> srcLocal = inQueueX.DeQue<int32_t>();
        LocalTensor<uint64_t> deqscaleLocal = inQueueDeqscale.DeQue<uint64_t>();
        AscendDequant(dstLocal, srcLocal, deqscaleLocal);
        outQueue.EnQue<half>(dstLocal);
        inQueueX.FreeTensor(srcLocal);
        inQueueDeqscale.FreeTensor(deqscaleLocal);
    }
    __aicore__ inline void CopyOut()
    {
        LocalTensor<half> dstLocal = outQueue.DeQue<half>();
        DataCopy(dst_global, dstLocal, dataSize * CYCLES);
        outQueue.FreeTensor(dstLocal);
    }
private:
    GlobalTensor<int32_t> src_global;
    GlobalTensor<uint64_t> deq_scale_global;
    GlobalTensor<half> dst_global;
    TPipe pipe;
    TQue<QuePosition::VECIN, 1> inQueueX;
    TQue<QuePosition::VECIN, 1> inQueueDeqscale;
    TQue<QuePosition::VECOUT, 1> outQueue;
    uint32_t dataSize = 0;
};
__aicore__ void kernel_ascend_dequant_operator(GM_ADDR src_gm, GM_ADDR dst_gm, GM_ADDR deq_scale_gm, uint32_t dataSize)
{
    KernelAscendDequant op;
    op.Init(src_gm, dst_gm, deq_scale_gm, dataSize);
    op.Process();
}
结果示例如下:
输入数据(srcLocal):  [-8  5 -5 -7 -3 -8  3  6  9  2 -5  0  0 -5 -7  0 -6  0 -2  3 -2  8  5  2
  2  2 -4  5 -4  4 -8  3  1  1 -3 -6  0 -8  6  4 -9 -8 -9  9  1 -6 -2  7
  0  9 -7  4  8  8  8  3  8  0 -1 -9  7 -6  4  9  7  6  1 -2 -6 -1  5 -1
 -7 -6  9 -4  7 -6  8 -1  3  4 -1  8 -8  0 -5 -3  4 -2  2  9 -1 -8 -8 -5
 -4 -2  6 -8  6  6 -3  3  9  9  3  7 -3  2 -6 -3  9  5  6 -7 -9  7  4  7
  7  5  0 -8 -3  2  1 -4]
反量化参数deqScale:  [ 4684667042262804261 13859763255858939167 13913477933643455555
 13907151579958429368 13897949676949746767  4582719314807759850
 13853283086508392697 13909797127326894857 13866799982091814874
  4540299872743366446  4607979737595448655 13909334383254771263
 13839508170748780740 13842303019803952926 13730250582932487241
 13912302074676945322 13909810673658225786  4655783565419604228
  4501315002210063326 13883028479528669462  4671684893717131608
 13878358404816523577 13881140985283325552  4527412879213716782
  4683436237743530696  4674205262125954062 13905975420345761780
  4677333589610067412 13907265529723024752 13905295284355041002
 13817041585540784494 13902864429505525097]
反量化参数deqScale(转换为float):  [-8.868932    8.205042   -3.9339063  -3.371414   -8.866275   -9.41271
 -7.4773827  -8.0079775  -6.7563853  -6.9823723  -3.752192    1.1960124
 -3.0391219  -3.0116923   2.274111   -8.5954075   1.9514115  -3.7620308
 -4.956931    0.5093182   1.1488436   1.89713     2.1107328  -8.492658
  1.1943593  -2.247031   -5.8524313  -2.4021761  -6.3276715  -0.27275205
 -8.789469   -9.151617    3.34207    -8.598415    1.8742986   4.895811
  6.939925    0.24207361 -7.5020857  -5.3257866  -4.545086    6.661213
  5.166165   -4.8073044   7.4275436  -5.116233   -8.993452    0.4152376
 -0.44753098  7.965874   -3.1164584   6.94103    -8.359364   -7.873409
  5.6862583   7.2883444  -2.7024803  -8.033279    7.172719   -7.797899
 -7.6681433  -1.4999424  -9.45054    -7.5280194 ]
输出数据(dstLocal):  [ 70.94  -19.67   44.34   52.34   20.27   30.02   -9.12   13.65   17.56
  -9.914  -5.746   0.      0.     29.27   44.28   -0.    -20.05    0.
 -13.88  -22.5     9.09   41.34   37.12  -17.98   -0.895  -6.234  33.44
  28.44   10.81   28.69   61.34  -28.36   -8.87   -3.934  26.6    44.88
  -0.     30.02  -18.23    9.09  -17.56   39.66  -10.336  19.      1.194
  35.12   12.66  -61.53    0.     16.88  -48.6   -30.02  -36.38   41.34
  59.4   -26.98   -3.58   -0.      8.36  -51.2   -18.92  -43.03  -30.67
 -85.06  -62.1   -23.61   -8.87   14.95   40.53    3.752 -15.195  -2.273
 -13.66   29.73   10.336  -8.445   8.36   35.12  -50.62    8.79   10.02
   7.496  -6.94  -60.03   36.38    0.    -37.12   26.98   -1.79    6.234
 -16.72   51.2     2.703 -57.38   61.34   47.25   35.47    7.867 -53.2
  59.8   -40.53  -22.52    9.12    6.824  17.56  -44.62    3.447  14.77
  -3.584 -11.7    37.97   26.38   30.08    9.375  41.62   52.5    40.9
  36.16   29.7   -62.97   -3.133 -15.586  -0.    -45.5     8.11   14.34
  -7.668  37.8  ]