This interface implements the quantization of the weight input and of the output in matrix-multiply computation, supporting per-tensor, per-channel, and per-group quantization scenarios (Atlas inference series products currently support per-channel only).
npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0) -> Tensor
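In the signature above, x is the activation input and weight the INT8 quantized weight; antiquant_scale and antiquant_offset are the weight dequantization (antiquant) parameters; quant_scale and quant_offset optionally quantize the output (when both are needed they are packed with torch_npu.npu_trans_quant_param, as in the single-operator example below); bias is added to the result; antiquant_group_size selects the group size for per-group quantization (0 means per-group quantization is not used). This summary follows the examples in this section; consult the operator constraints for the exact dtype and shape requirements.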
The output is a Tensor holding the computation result. When quant_scale is supplied, the output dtype is INT8; when quant_scale is absent, the output dtype matches that of x.
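A minimal sketch of this dtype rule, assuming an NPU device is available (shapes here are illustrative): without quant_scale, the result comes back in the dtype of x.
# Minimal per-channel call without quant_scale; output dtype follows x.
import torch
import torch_npu

x = torch.randn((4, 64), dtype=torch.float16).npu()
weight = torch.randint(low=-8, high=8, size=(64, 32), dtype=torch.int8).npu()
antiquant_scale = torch.randn((1, 32), dtype=torch.float16).npu()  # per-channel dequant scale

out = torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale)
assert out.dtype == torch.float16  # no quant_scale, so dtype matches x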
# Single-operator mode
import torch
import torch_npu

cpu_x = torch.randn((8192, 320), dtype=torch.float16)
cpu_weight = torch.randint(low=-8, high=8, size=(320, 256), dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256), dtype=torch.float16)
cpu_antiquantoffset = torch.randn((1, 256), dtype=torch.float16)
cpu_quantscale = torch.randn((1, 256), dtype=torch.float32)
cpu_quantoffset = torch.randn((1, 256), dtype=torch.float32)
# Pack the float32 output-quantization scale/offset into the combined quant parameter
quantscale = torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu())
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(),
                                                 cpu_antiquantscale.npu(),
                                                 cpu_antiquantoffset.npu(), quantscale)
# Graph mode
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
cpu_x = torch.randn((8192, 320), dtype=torch.bfloat16)
cpu_weight = torch.randint(low=-8, high=8, size=(320, 256), dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256), dtype=torch.bfloat16)
cpu_antiquantoffset = torch.randn((1, 256), dtype=torch.bfloat16)
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, antiquant_group_size):
        return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset,
                                                      quant_scale, quant_offset, bias, antiquant_group_size)
cpu_model = MyModel()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
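For per-group quantization, antiquant_group_size sets the group size along the k dimension. The following is a minimal sketch under an assumed parameter layout of one row per group, i.e. antiquant shapes (k // antiquant_group_size, n); the exact shape, alignment, and product-support constraints should be checked in the operator documentation:
# Per-group sketch: k is split into k // group_size groups (assumed layout).
import torch
import torch_npu

m, k, n, group_size = 16, 256, 128, 64
x = torch.randn((m, k), dtype=torch.float16).npu()
weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int8).npu()
antiquant_scale = torch.randn((k // group_size, n), dtype=torch.float16).npu()
antiquant_offset = torch.randn((k // group_size, n), dtype=torch.float16).npu()

out = torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset,
                                             antiquant_group_size=group_size)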
# Atlas inference series products, graph mode, with the weight input in FRACTAL_NZ format:
import torch_npu
import torch
from torchair.configs.compiler_config import CompilerConfig
import torchair as tng
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
class NPUQuantizedLinearA16W8(torch.nn.Module):
    def __init__(self,
                 weight,
                 antiquant_scale,
                 antiquant_offset,
                 quant_offset=None,
                 quant_scale=None,
                 bias=None,
                 transpose_x=False,
                 transpose_weight=True,
                 w4=False):
        super().__init__()
        self.dtype = torch.float16
        self.weight = weight.to(torch.int8).npu()
        self.transpose_weight = transpose_weight
        if self.transpose_weight:
            # weight is already (n, k); cast to FRACTAL_NZ (format id 29)
            self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29)
        else:
            # (k, n) -> (n, k), then cast to FRACTAL_NZ (format id 29)
            self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29)
        self.bias = None
        self.antiquant_scale = antiquant_scale
        self.antiquant_offset = antiquant_offset
        self.quant_offset = quant_offset
        self.quant_scale = quant_scale
        self.transpose_x = transpose_x

    def forward(self, x):
        x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x,
                                                   self.weight.transpose(0, 1),
                                                   self.antiquant_scale.transpose(0, 1),
                                                   self.antiquant_offset.transpose(0, 1),
                                                   self.quant_scale,
                                                   self.quant_offset,
                                                   self.bias)
        return x
m, k, n = 4, 1024, 4096
cpu_x = torch.randn((m, k), dtype=torch.float16)
cpu_weight = torch.randint(1, 10, (k, n), dtype=torch.int8)
cpu_weight = cpu_weight.transpose(-1, -2)
cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)
cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
model = torch.compile(model, backend=npu_backend, dynamic=True)
out = model(cpu_x.npu())
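To sanity-check the FRACTAL_NZ path, the output can be compared against a CPU reference. The dequantization formula used here, weight_fp = (weight + antiquant_offset) * antiquant_scale, is an assumption about the A16W8 scheme and should be verified against the operator documentation:
# CPU reference sketch; assumes dequantization (weight + offset) * scale.
w_fp = (cpu_weight.to(torch.float32) + cpu_antiquantoffset.to(torch.float32)) \
       * cpu_antiquantscale.to(torch.float32)           # broadcasts to (n, k)
ref = cpu_x.to(torch.float32) @ w_fp.transpose(-1, -2)  # (m, k) @ (k, n) -> (m, n)
print(torch.allclose(out.cpu().to(torch.float32), ref, rtol=1e-2, atol=1e-2))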