昇腾社区首页
中文
注册

torch_npu.contrib.module.LinearA8W8Quant

功能描述

LinearA8W8Quant是对torch_npu接口torch_npu.npu_quant_matmul的封装类,完成A8W8量化算子的矩阵乘计算。

接口原型

torch_npu.contrib.module.LinearA8W8Quant(in_features, out_features, *, bias=True, offset=False, pertoken_scale=False, output_dtype=None)

参数说明

  • in_features(计算参数):int类型,matmul计算中k轴的值。
  • out_features(计算参数):int类型,matmul计算中n轴的值。
  • bias(计算参数):bool类型,代表是否需要bias计算参数。如果设置成False, 则bias不会加入量化matmul的计算。
  • offset(计算参数):bool类型,代表是否需要offset计算参数。如果设置成False, 则offset不会加入量化matmul的计算。
  • pertoken_scale(计算参数):bool类型,代表是否需要pertoken_scale计算参数。如果设置成False, 则pertoken_scale不会加入量化matmul的计算。

    目前仅为接口预留参数,pertoken功能当前不可用,应设置为False。

  • output_dtype(计算参数):ScalarType类型,表示输出Tensor的数据类型,支持输入torch.int8,torch.float16, torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列产品只支持output_dtype为torch.int8(含None,下同)和torch.float16。

输入说明

  • x1(计算输入):Device侧的Tensor类型,数据类型支持INT8。数据格式支持ND,shape最少是2维,最多是6维。

变量说明

  • weight(变量):Device侧的Tensor类型,数据格式支持INT8。数据格式支持ND,shape为(n, k)。
  • scale(变量):Device侧的Tensor类型,数据类型支持FLOAT32, INT64, BFLOAT16。数据格式支持ND,shape是1维(t,),t = 1或n,其中n与weight的n一致。如需传入INT64数据类型的scale, 需要提前调用torch_npu.npu_trans_quant_param接口来获取INT64数据类型的scale。
  • offset(变量):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,),t = 1或n,其中n与weight的n一致。
  • pertoken_scale(变量):Device侧的Tensor类型,可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(m,),其中m与x1的m一致。目前仅为接口预留参数,pertoken功能当前不可用。
  • bias(变量):Device侧的Tensor类型,可选参数。数据类型支持INT32,数据格式支持ND,shape支持1维(n,)或3维(batch,1,n),n与weight的n一致。bias 3维(batch,1,n)只出现在out为3维的场景下,同时batch值需要等于x1, weight boardcast后推导出的batch值。
  • output_dtype(变量):Device侧的ScalarType类型,可选参数。表示输出Tensor的数据类型,支持输入torch.int8,torch.float16, torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列产品只支持输出类型为torch.int8(含None,下同)和torch.float16。

输出说明

一个Tensor类型的输出,代表量化matmul的计算结果。如果output_dtype为torch.float16,输出的数据类型为FLOAT16;如果output_dtype为torch.bfloat16,输出的数据类型为BFLOAT16;如果output_dtype为torch.int8或者None,输出的数据类型为INT8;如果output_dtype非以上数据类型,返回错误码。

约束说明

  • x1、weight、scale不能是空。
  • x1、weight、bias、scale、offset、pertoken_scale的数据类型和数据格式需要在支持的范围之内。
  • x1、weight的shape需要在2-6维范围。
  • scale, offset的shape需要为1维(t,),t = 1或n,n与x2的n一致。
  • pertoken_scale的shape需要为1维(m, ),m与x1的m一致。
  • bias的shape支持1维(n,)或3维(batch,1,n),n与weight的n一致, batch值需要等于x1, weight boardcast后推导出的batch值。
  • bias的shape在out 是2,4,5,6维情况下需要为1维,在out 是3维情况下可以为1维或3维。
  • output_dtype为torch.bfloat16时,scale需要为BFLOAT16数据类型的Tensor。output_dtype为torch.float16或torch.int8,并且在pertoken_scale为空时,scale可为FLOAT32或INT64数据类型的Tensor。output_dtype为torch.float16且pertoken_scale不为空时,scale必须为FLOAT32。
  • pertoken_scale仅支持FLOAT32,目前仅在输出FLOAT16和BFLOAT16场景下可不为空。目前仅为接口预留参数,pertoken功能当前不可用。
  • offset不为空时,output_dtype仅支持torch.int8。
  • 在pertoken为空的场景下使用模型进行计算前,scale变量Atlas A2 训练系列产品单算子模式下需要通过调用tng.experimental.inference.use_internal_format_weight(model)进行量化参数预处理,不然会增加额外的量化参数转换耗时。调用方式可参考示例代码。
  • x1与x2最后一维的shape大小不能超过65535。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列产品

调用示例

1.单算子模式:
1.1 Atlas 推理系列产品:在单算子模式下不支持使能高带宽的x2数据排布,因此不能调用use_internal_format_weight,如果想追求极致性能,请使用图模式
import torch
import torch_npu
import logging
import os
import torchair as tng
from torch_npu.contrib.module import LinearA8W8Quant
x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(128, dtype=torch.float32).npu()
bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
in_features = 512
out_features = 128
output_dtype = torch.int8
model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
model = model.npu()
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset
model.bias.data = bias
// 接口内部调用npu_trans_quant_param功能
output = model(x1)

1.2 Atlas A2 训练系列产品单算子模式调用示例
import torch
import torch_npu
import logging
import os
import torchair as tng
from torch_npu.contrib.module import LinearA8W8Quant
x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(128, dtype=torch.float32).npu()
bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
in_features = 512
out_features = 128
output_dtype = torch.int8
model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
model = model.npu()
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset
if output_dtype != torch.bfloat16:
    # Atlas A2 训练系列产品只包含npu_trans_quant_param功能
    tng.experimental.inference.use_internal_format_weight(model)
model.bias.data = bias
output = model(x1)

2.图模式
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
from torch_npu.contrib.module import LinearA8W8Quant
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(128, dtype=torch.float32).npu()
bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
in_features = 512
out_features = 128
output_dtype = torch.int8
model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
model = model.npu()
model.weight.data = x2
model.scale.data = scale
model.offset.data = offset
if output_dtype != torch.bfloat16:
    # 包含npu_trans_quant_param功能,Atlas 推理系列产品还包含使能高带宽的x2数据排布功能
    tng.experimental.inference.use_internal_format_weight(model)
model.bias.data = bias
model = torch.compile(model, backend=npu_backend, dynamic=False)
output = model(x1)