接口功能:RmsNorm算子是大模型常用的归一化操作,相比LayerNorm算子,其去掉了减去均值的部分。DynamicMxQuant算子则是在尾轴上按blocksize分组进行动态MX量化的算子。AddRmsNormDynamicMxQuant算子将RmsNorm前的Add算子和RmsNorm归一化输出给到的DynamicMxQuant算子融合起来,减少搬入搬出操作。 在输入尾轴axis上,根据每blocksize=32个数,计算出这组数对应的量化尺度mxscale,然后对这组数每一个除以mxscale,根据round_mode转换到对应的dst_type,得到量化结果y。在dst_type为FLOAT8_E4M3FN、FLOAT8_E5M2时,根据scale_alg的取值来指定计算mxscale的不同算法。
计算公式:
当scaleAlg为0时:
- 将RmsNorm输出y在尾轴维度上按k = 32个数分组,一组k个数 动态量化为
emax: 对应数据类型的最大正则数的指数位。
[object Object]undefined
当scaleAlg为1时,只涉及FP8类型:
将长向量按块分,每块长度为k,对每块单独计算一个块缩放因子,再把块内所有元素用同一个映射到目标低精度类型FP8。
找到该块中数值的最大绝对值:
将FP32映射到目标数据类型FP8可表示的范围内:
转换为FP8格式下可表示的缩放值
从块的浮点缩放因子中提取无偏指数和尾数
为保证量化时不溢出,对指数进行向上取整:
计算块缩放因子:
计算块转换因子:
应用到量化的最终步骤:
每个算子分为,必须先调用[object Object]接口获取入参并根据计算流程所需workspace大小,再调用[object Object]接口执行计算。
边界值场景说明
- 当输入是Inf时:1、输出yOut为0;2、输出xOut为Inf;3、输出mxscaleOut为255,偶数pad填充值为0;4、输出rstdOut为0。
- 当输入是NaN时:1、输出yOut为0;2、输出xOut为Nan;3、输出mxscaleOut为255,偶数pad填充值为0;4、输出rstdOut为NaN。
数据格式说明
所有输入输出Tensor的数据格式推荐使用ND格式,其他数据格式会由框架默认转换成ND格式进行处理。
各产品型号支持数据类型说明
Atlas 350 加速卡:
[object Object]undefined
mxscaleOut的shape约束说明如下:
- rank(mxscaleOut) = rank(x1) + 1。
- mxscaleOut.shape[-2] = (ceil(x1.shape[-1] / 32) + 2 - 1) / 2。
- mxscaleOut.shape[-1] = 2。
- 其他维度与输入x1一致。
x1的shape约束说明如下:
- 当输出yOut的数据类型为FLOAT4_E2M1或FLOAT4_E1M2,x1尾轴的值必须为偶数。
确定性计算:
- aclnnAddRmsNormDynamicMxQuant默认确定性实现。