接口功能:目的数据类型为FLOAT4类、FLOAT8类的MX量化。在给定的轴axis上,根据每blocksize个数,计算出这组数对应的量化尺度mxscale作为输出mxscaleOut的对应部分,然后对这组数每一个除以mxscale,根据round_mode转换到对应的dstType,得到量化结果y作为输出yOut的对应部分。在dstType为FLOAT8_E4M3FN、FLOAT8_E5M2时,根据scaleAlg的取值来指定计算mxscale的不同算法。
计算公式:
- 场景1,当scaleAlg为0时:
- 将输入x在axis维度上按k = blocksize个数分组,一组k个数 动态量化为 , k = blocksize
量化后的 按对应的 的位置组成输出yOut,mxscale按对应的axis维度上的分组组成输出mxscaleOut。
emax: 对应数据类型的最大正则数的指数位。
[object Object]undefined
- 场景2,当scaleAlg为1时,只涉及FP8类型:
- 将长向量按块分,每块长度为k,对每块单独计算一个块缩放因子,再把块内所有元素用同一个映射到目标低精度类型FP8。如果最后一块不足k个元素,把缺失值视为0,按照完整块处理。
- 找到该块中数值的最大绝对值:
- 将FP32映射到目标数据类型FP8可表示的范围内,其中是目标精度能表示的最大值
- 将块缩放因子转换为FP8格式下可表示的缩放值
- 从块的浮点缩放因子中提取无偏指数和尾数
- 为保证量化时不溢出,对指数进行向上取整,且在FP8可表示的范围内:
- 计算块缩放因子:
- 计算块转换因子:
- 应用到量化的最终步骤,对于每个块内元素,,最终输出的量化结果是,其中代表块的缩放因子,这里指,代表块内量化后的数据。
- 场景1,当scaleAlg为0时:
每个算子分为,必须先调用“aclnnDynamicMxQuantGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnDynamicMxQuant”接口执行计算。
[object Object]
[object Object]
- 确定性计算:
- aclnnDynamicMxQuant默认确定性实现。
- 关于x、mxscaleOut的shape约束说明如下:
- rank(mxscaleOut) = rank(x) + 1。
- axis_change = axis if axis >= 0 else axis + rank(x)。
- mxscaleOut.shape[axis_change] = (ceil(x.shape[axis] / blocksize) + 2 - 1) / 2。
- mxscaleOut.shape[-1] = 2。
- 其他维度与输入x一致。
[object Object]