接口功能:在-1轴和-2轴上同时进行目的数据类型为FLOAT4类、FLOAT8类的MX量化。在给定的-1轴和-2轴上,每32个数,计算出这两组数对应的量化尺度mxscale1、mxscale2作为输出mxscale1Out、mxscale2Out的对应部分,然后分别对两组数所有元素除以对应的mxscale1或mxscale2,根据round_mode转换到对应的dstType,得到量化结果y1和y2分别作为输出y1Out和y2Out的对应部分。在dstType为FLOAT8_E4M3FN、FLOAT8_E5M2时,根据scaleAlg的取值来指定计算mxscale的不同算法。
合轴说明:算子实现时,会对-2轴(不包含)之前的所有轴进行合轴处理。即对于输入shape为的张量,-2轴之前的维度会被合并为一个维度,等效于将输入reshape为后再进行量化计算。
计算公式:
场景1,当scaleAlg为0时,即OCP Microscaling Formats (Mx) Specification实现:
将输入x在-1轴上按照32个数进行分组,一组32个数 量化为
同时,将输入x在-2轴上按照32个数进行分组,一组32个数 量化为
-1轴量化后的 按对应的 的位置组成输出y1Out,mxscale1按对应的-1轴维度上的分组组成输出mxscale1Out。-2轴量化后的 按对应的 的位置组成输出y2Out,mxscale2按对应的-2轴维度上的分组组成输出mxscale2Out。
emax: 对应数据类型的最大正则数的指数位。
[object Object]undefined
场景2,当scaleAlg为1时,只涉及FP8类型(CuBALS Scale计算算法):
- -1轴量化:将输入x在-1轴上按照32个数进行分组,每组长度为32,对每组单独计算一个块缩放因子,再把组内所有元素用同一个映射到目标低精度类型FP8。如果最后一组不足32个元素,把缺失值视为0,按照完整组处理。
- 找到该组中数值的最大绝对值:
- 将FP32映射到目标数据类型FP8可表示的范围内,其中是目标精度能表示的最大值
- 将块缩放因子转换为FP8格式下可表示的缩放值
- 从块的浮点缩放因子中提取无偏指数和尾数
- 为保证量化时不溢出,对指数进行向上取整,且在FP8可表示的范围内:
- 计算块缩放因子:
- 计算块转换因子:
- 应用到量化的最终步骤,对于每个组内元素,,最终-1轴输出的量化结果是,其中代表块的缩放因子,即,代表组内量化后的数据。
- -2轴量化:同时,将输入x在-2轴上按照32个数进行分组,采用与-1轴相同的CuBALS Scale计算算法,对每组独立计算块缩放因子并量化。-2轴输出的量化结果是。
- -1轴量化结果组成输出y1Out,对应的块缩放因子组成输出mxscale1Out。-2轴量化结果组成输出y2Out,对应的块缩放因子组成输出mxscale2Out。
- -1轴量化:将输入x在-1轴上按照32个数进行分组,每组长度为32,对每组单独计算一个块缩放因子,再把组内所有元素用同一个映射到目标低精度类型FP8。如果最后一组不足32个元素,把缺失值视为0,按照完整组处理。
每个算子分为,必须先调用“aclnnDynamicMxQuantWithDualAxisGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnDynamicMxQuantWithDualAxis”接口执行计算。
- 关于x、mxscale1Out、mxscale2Out的shape约束说明如下:
- rank(mxscale1Out) = rank(x) + 1。
- rank(mxscale2Out) = rank(x) + 1。
- mxscale1Out.shape[-2] = (ceil(x.shape[-1] / 32) + 2 - 1) / 2。
- mxscale2Out.shape[-3] = (ceil(x.shape[-2] / 32) + 2 - 1) / 2。
- mxscale1Out.shape[-1] = 2。
- mxscale2Out.shape[-1] = 2。
- 其他维度与输入x一致。
- 举例:输入x的shape为[B, M, N],目的数据类型为FP8类时,对应的y1和y2的shape为[B, M, N],mxscale1的shape为[B, M, (ceil(N/32)+2-1)/2, 2],mxscale2的shape为[B, (ceil(M/32)+2-1)/2, N, 2]。