torch_npu.contrib.module.LinearWeightQuant

功能描述

LinearWeightQuant是对torch_npu.npu_weight_quant_batchmatmul接口的封装类,完成矩阵乘计算中的weight输入和输出的量化操作,支持per-tensor、per-channel、per-group多场景量化。

当前Atlas 推理系列产品仅支持per-channel量化。

接口原型

1
torch_npu.contrib.module.LinearWeightQuant(in_features, out_features, bias=True, device=None, dtype=None, antiquant_offset=False, quant_scale=False, quant_offset=False, antiquant_group_size=0, inner_precise=0)

参数说明

输入说明

x:Tensor类型,即矩阵乘中的x。数据格式支持ND,支持输入维度为两维(M, K) 。
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、bfloat16。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、bfloat16。
  • Atlas 推理系列产品:数据类型仅支持float16。

变量说明

输出说明

输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为int8,当输入不存在quant_scale时输出数据类型和输入x一致。

约束说明

支持的型号

调用示例