昇腾社区首页
中文
注册

torch_npu.contrib.module.LinearWeightQuant

功能描述

LinearWeightQuant是对torch_npu接口torch_npu.npu_weight_quant_batchmatmul的封装类,完成矩阵乘计算中的weight输入和输出的量化操作,支持pertensor,perchannel,pergroup多场景量化。

接口原型

torch_npu.contrib.module.LinearWeightQuant(in_features, out_features, bias=True, device=None, dtype=None, antiquant_offset=False, quant_scale=False, quant_offset=False, antiquant_group_size=0)

参数说明

  • in_features:int类型,伪量化matmul计算中的k轴的值。
  • out_features:int类型,伪量化matmul计算中的n轴的值。
  • bias:bool类型,为可选参数默认为True,代表是否需要bias计算参数。如果设置成False, 则bias不会加入伪量化matmul的计算。
  • device:string类型,为可选参数,用于执行model的device名称,默认为None。
  • dtype:torch支持的dtype类型,为可选参数默认为None,伪量化matmul运算中的输入x的dtype。
  • antiquant_offset:bool类型,为可选参数默认为False,代表是否需要antiquant_offset计算参数。如果设置成False,则weight矩阵反量化时无需设置offset。
  • quant_scale:bool类型,为可选参数默认为False,代表是否需要quant_scale计算参数。如果设置成False,则伪量化输出不会进行量化计算。
  • quant_offset:bool类型,为可选参数默认为False,代表是否需要quant_offset计算参数。如果设置成False,则对伪量化输出进行量化计算时无需设置offset。
  • antiquant_group_size:Int类型,为可选参数,用于控制pergroup场景下的group大小, 当前默认为0,预留参数,暂未使用

输入说明

  • x:Device侧Tensor类型,即矩阵乘中的x。数据格式支持ND, 数据类型支持FLOAT16/ BFLOAT16,支持输入维度为两维(M,K) 。

变量说明

  • weight:Device侧Tensor类型,即矩阵乘中的weight。数据格式支持ND, 数据类型支持INT8, 支持非连续的Tensor,支持输入维度为两维(K,N)。
  • antiquant_scale:Device侧Tensor类型,反量化的scale,用于weight矩阵反量化 。数据格式支持ND, 数据类型支持FLOAT16/ BFLOAT16, 支持非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, )。
  • antiquant_offset:Device侧Tensor类型,反量化的offset,用于weight矩阵反量化 。数据格式支持ND, 数据类型支持FLOAT16/ BFLOAT16, 支持非连续的Tensor,支持输入维度为两维(1, N)或 一维(N, )、(1, )。
  • quant_scale:Device侧Tensor类型,量化的scale,用于输出矩阵的量化 。数据格式支持ND, 数据类型支持FLOAT32/ INT64,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
  • quant_offset:Device侧Tensor类型,量化的offset,用于输出矩阵的量化 。数据格式支持ND, 数据类型支持FLOAT32 ,支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
  • bias:Device侧Tensor类型, 即矩阵乘中的bias, 数据格式支持ND, 数据类型支持FLOAT16/ FLOAT32, 支持非连续的Tensor, 支持输入维度为两维(1, N) 或 一维(N, )、(1, )。
  • antiquant_group_size:Int类型, 用于控制pergroup场景下的group大小, 当前默认为0,预留参数,暂未使用

输出说明

输出为Device侧Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为INT8,当输入不存在quant_sclae时输出数据类型和输入x一致。

约束说明

  • LinearWeightQuant传入参数支持范围(参数说明中的参数)和torch_npu.npu_weight_quant_batchmatmul接口保持一致。
  • 使用模型进行计算前,quant_scale变量需要通过调用torchair.experimental.inference.use_internal_format_weight(model)进行量化参数预处理,不然会增加额外的量化参数转换耗时。调用方式可参考实例代码。预处理依赖PyTorch2.x版本。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

Atlas A2 训练系列产品

调用示例

单算子模式:
import torch
import torch_npu
import torchair as tng
from torch_npu.contrib.module import LinearWeightQuant

x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
quantscale = torch.randn((1, 256),device='npu',dtype=torch.float)
quantoffset = torch.randn((1, 256),device='npu',dtype=torch.float)

model = LinearWeightQuant(in_features=320,
​                 out_features=256,
​                 bias=False,
​                 dtype=torch.bfloat16,
​                 antiquant_offset=True,
​                 quant_scale=True,
​                 quant_offset=True,
​                 antiquant_group_size=0,
​                 device=torch.device(f'npu:0')
​                 )
model.npu()
model.weight.data = weight
model.antiquant_scale.data = antiquantscale
model.antiquant_offset.data = antiquantoffset
model.quant_scale.data = quantscale
model.quant_offset.data = quantoffset
tng.experimental.inference.use_internal_format_weight(model)
out = model.(x)

图模式:
import torch
import torch_npu
import torchair as tng
from torch_npu.contrib.module import LinearWeightQuant
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)

x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
quantscale = torch.randn((1, 256),device='npu',dtype=torch.float)
quantoffset = torch.randn((1, 256),device='npu',dtype=torch.float)

model = LinearWeightQuant(in_features=320,
​                 out_features=256,
​                 bias=False,
​                 dtype=torch.bfloat16,
​                 antiquant_offset=True,
​                 quant_scale=True,
​                 quant_offset=True,
​                 antiquant_group_size=0,
​                 device=torch.device(f'npu:0')
​                 )
model.npu()
model.weight.data = weight
model.antiquant_scale.data = antiquantscale
model.antiquant_offset.data = antiquantoffset
model.quant_scale.data = quantscale
model.quant_offset.data = quantoffset
tng.experimental.inference.use_internal_format_weight(model)
model = torch.compile(model, backend=npu_backend, dynamic=False)
out = model.(x)