LinearA8W8Quant是对torch_npu接口torch_npu.npu_quant_matmul的封装类，用于完成A8W8量化算子的矩阵乘计算。
torch_npu.contrib.module.LinearA8W8Quant(in_features, out_features, *, bias=True, offset=False, pertoken_scale=False, output_dtype=None)
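返回一个Tensor类型的输出，表示量化matmul的计算结果：
- 如果output_dtype为torch.float16，输出的数据类型为float16。
- 如果output_dtype为torch.int8或None，输出的数据类型为int8。
- 如果output_dtype为torch.bfloat16，输出的数据类型为bfloat16。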
如果output_dtype不是以上数据类型之一，返回错误码。
单算子模式下不支持使能高带宽的x2数据排布，因此不能调用use_internal_format_weight；如需追求极致性能，请使用图模式。
# Eager (single-operator) mode example for LinearA8W8Quant: an A8W8 quantized
# matmul with int8 activations and weights, a float32 scale, a float32 offset
# of length out_features, an int32 bias, and an int8 output.
import torch
import torch_npu
import logging
import os
from torch_npu.contrib.module import LinearA8W8Quant

in_features = 512
out_features = 128
output_dtype = torch.int8

# Shapes are derived from in_features / out_features so the tensors always
# match the layer. torch.randint's upper bound is exclusive, so the integer
# tensors hold values in {-1, 0}.
x1 = torch.randint(-1, 1, (1, in_features), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (out_features, in_features), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(out_features, dtype=torch.float32).npu()
bias = torch.randint(-1, 1, (out_features,), dtype=torch.int32).npu()

model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
model = model.npu()

# Replace the module's parameters with the prepared quantized tensors.
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset
model.bias.data = bias

# The module calls npu_trans_quant_param internally.
output = model(x1)
# Eager-mode example: int8 x int8 quantized matmul through LinearA8W8Quant,
# producing an int8 result.
import torch
import torch_npu
import logging
import os
from torch_npu.contrib.module import LinearA8W8Quant

in_features, out_features = 512, 128
output_dtype = torch.int8

# Random inputs sized from the layer dimensions. randint excludes its upper
# bound, so the integer tensors contain only -1 and 0.
x1 = torch.randint(-1, 1, (1, in_features), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (out_features, in_features), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(out_features, dtype=torch.float32).npu()
bias = torch.randint(-1, 1, (out_features,), dtype=torch.int32).npu()

model = LinearA8W8Quant(
    in_features,
    out_features,
    bias=True,
    offset=True,
    output_dtype=output_dtype,
).npu()

# Load the prepared tensors into the module's parameters.
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset
model.bias.data = bias

output = model(x1)
# Graph-mode example: compile LinearA8W8Quant with the torchair NPU backend
# and run an A8W8 quantized matmul with an int8 output.
import logging
import os

import numpy as np
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
from torch_npu.contrib.module import LinearA8W8Quant
from torchair.core.utils import logger

# Enable DEBUG-level torchair logging.
logger.setLevel(logging.DEBUG)
# NOTE(review): set after the imports as in the original example; confirm
# whether the NPU stack requires ENABLE_ACLNN to be set before importing.
os.environ["ENABLE_ACLNN"] = "true"

config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

in_features = 512
out_features = 128
output_dtype = torch.int8

# Shapes are derived from in_features / out_features so the tensors always
# match the layer. torch.randint's upper bound is exclusive, so the integer
# tensors hold values in {-1, 0}.
x1 = torch.randint(-1, 1, (1, in_features), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (out_features, in_features), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(out_features, dtype=torch.float32).npu()
bias = torch.randint(-1, 1, (out_features,), dtype=torch.int32).npu()

model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
model = model.npu()
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset

if output_dtype != torch.bfloat16:
    # Performs the npu_trans_quant_param conversion; on Atlas inference-series
    # products it also enables the high-bandwidth internal data layout for x2.
    # NOTE(review): skipped for bfloat16 output, presumably because that path
    # is unsupported for bf16 — confirm against the torchair documentation.
    tng.experimental.inference.use_internal_format_weight(model)

# Bias is assigned after the weight conversion, matching the original example.
model.bias.data = bias

model = torch.compile(model, backend=npu_backend, dynamic=False)
output = model(x1)