torch_npu.contrib.module.LinearQuant
功能描述
LinearQuant是对torch_npu接口torch_npu.npu_quant_matmul的封装类,完成A8W8、A4W4量化算子的矩阵乘计算。
接口原型
torch_npu.contrib.module.LinearQuant(in_features, out_features, *, bias=True, offset=False, pertoken_scale=False, output_dtype=None)
参数说明
- in_features(计算参数):int类型,matmul计算中k轴的值。
- out_features(计算参数):int类型,matmul计算中n轴的值。
- bias(计算参数):bool类型,代表是否需要bias计算参数。如果设置成False,则bias不会加入量化matmul的计算。
- offset(计算参数):bool类型,代表是否需要offset计算参数。如果设置成False,则offset不会加入量化matmul的计算。
- pertoken_scale(计算参数):bool类型,可选参数代表是否需要pertoken_scale计算参数。如果设置成False,则pertoken_scale不会加入量化matmul的计算。Atlas 推理系列加速卡产品当前不支持pertoken_scale。当x1和x2的输入类型为INT32时,当前不支持pertoken_scale。
- output_dtype( 计算参数):ScalarType类型,表示输出Tensor的数据类型,支持输入torch.int8、torch.float16、torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列加速卡产品只支持output_dtype为torch.int8(含None,下同)和torch.float16。当x1,weight输入数据类型为INT32时,需要输入output_dtype,只支持output_dtype为torch.float16。
输入说明
x1(计算输入):Device侧的Tensor类型,数据类型支持INT8和INT32,其中INT32类型表示使用本接口进行INT4类型矩阵乘计算,INT32类型承载的是INT4数据,每个INT32数据存放8个INT4数据。数据格式支持ND,shape最少是2维,最多是6维。
变量说明
- weight(变量):Device侧的Tensor类型,数据类型支持INT8和INT32(同x1,表示INT4的数据计算),与x1的数据类型须保持一致。数据格式支持ND,shape最少是2维,最多是6维,数据类型为INT32时,shape为2维。
- scale(变量):Device侧的Tensor类型,量化计算的scale。数据类型支持FLOAT32、INT64、BFLOAT16。数据格式支持ND,shape是1维(t,),t=1或n,其中n与weight的n一致。如需传入INT64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param接口来获取INT64数据类型的scale。
- offset(变量):Device侧的Tensor类型,量化计算的offset。可选参数。数据类型支持FLOAT32,数据格式支持ND,shape是1维(t,),t=1或n,其中n与weight的n一致。
- pertoken_scale(变量):Device侧的Tensor类型,可选参数,量化计算的pertoken。数据类型支持FLOAT32,数据格式支持ND,shape是1维(m,),其中m与x1的m一致。Atlas 推理系列加速卡产品当前不支持pertoken[g2] _scale。
- bias(变量):Device侧的Tensor类型,可选参数。矩阵乘中的bias。数据类型支持INT32,BFLOAT16。数据格式支持ND,shape支持1维(n,)或3维(batch,1,n),n与weight的n一致。bias 3维(batch,1,n)只出现在out为3维的场景下,同时batch值需要等于x1,weight boardcast后推导出的batch值。
- output_dtype(变量):Device侧的ScalarType类型,可选参数。表示输出Tensor的数据类型,支持输入torch.int8、torch.float16、torch.bfloat16。默认值为None,代表输出Tensor数据类型为INT8。Atlas 推理系列加速卡产品只支持输出类型为torch.int8(含None,下同)和torch.float16。x1与weight输入数据类型为INT32时,需要输入output_dtype,只支持output_dtype为torch.float16。
输出说明
一个Tensor类型的输出,代表量化matmul的计算结果:
- 如果output_dtype为torch.float16,输出的数据类型为FLOAT16。
- 如果output_dtype为torch.bfloat16,输出的数据类型为BFLOAT16。
- 如果output_dtype为torch.int8或者None,输出的数据类型为INT8。
如果output_dtype非以上数据类型,返回错误码。
约束说明
- 该融合算子仅在推理场景使用。
- x1、weight、scale不能是空。
- x1、weight、bias、scale、offset、pertoken_scale的数据类型和数据格式需要在支持的范围之内。
- x1、weight的shape需要在2-6维范围,weight数据类型为INT32时,shape只能为2维。
- scale、offset的shape需要为1维(t,),t = 1或n,n与weight的n一致。
- pertoken_scale的shape需要为1维(m, ),m与x1的m一致。
- bias的shape支持1维(n,)或3维(batch,1,n),n与weight的n一致,batch值需要等于x1,weight boardcast后推导出的batch值。shape在out是2,4,5,6维情况下需要为1维,在out是3维情况下可以为1维或3维。
- output_dtype与scale的关系:
- output_dtype为torch.bfloat16时,scale需要为BFLOAT16或FLOAT32数据类型的Tensor。
- output_dtype为torch.float16或torch.int8,并且在pertoken_scale为空时,scale可为FLOAT32或INT64数据类型的Tensor。
- output_dtype为torch.float16且pertoken_scale不为空时,scale必须为FLOAT32。
- bias为BFLOAT16数据类型时,output_dtype需要为torch.bfloat16。
- pertoken_scale仅支持FLOAT32,目前仅在输出FLOAT16和BFLOAT16场景下可不为空。且Atlas 推理系列产品当前不支持pertoken_scale。
- offset不为空时,output_dtype仅支持torch.int8。
- 在pertoken为空的场景下使用模型进行计算前,scale变量在非Atlas 推理系列加速卡产品单算子模式下需要通过调用torch air仓里面的tng.experimental.inference.use_internal_format_weight(model)或者npu_trans_quant_param进行量化参数预处理,不然会增加额外的量化参数转换耗时。调用方式可参考示例代码。
- x1与x2最后一维的shape大小不能超过65535。
- Atlas 推理系列加速卡产品下需要调用tng.experimental.inference.use_internal_format_weight(model)或npu_format_cast可以完成输入x2(batch, n, k)高性能数据排布功能。
- Atlas A2训练系列产品/Atlas 800I A2推理产品需要调用npu_format_cast可以完成输入x2(batch,n,k)高性能数据排布功能,但不推荐使用该module方式,推荐npu_quant_matmul。
- INT4类型计算的额外约束:
- x1,weight的数据类型均为INT32,每个INT32类型的数据存放8个INT4数据。输入shape需要将数据原本INT4类型时的最后一维shape缩小8倍。INT4数据的最后一维shape应为8的倍数,例如:
进行(m,k)乘(k,n)的INT4类型矩阵乘计算时,需要输入INT32类型,shape为(m,k//8)(k,n//8)的数据,其中k与n都应是8的倍数。x1只能接受shape为(m,k//8)且数据排布连续的数据,weight只能接受shape为(n,k//8)且数据排布连续的数据(数据排布连续指数组中所有相邻的数,包括换行时内存地址连续;使用Tensor.is_contiguous返回值为true则表明tensor数据排布连续)。
- 进行INT4类型计算时输出数据类型只能为FLOAT16,可选输入output_dtype需要输入torch.float16。
- INT4类型计算当前不支持pertoken_scale。
- x1,weight的数据类型均为INT32,每个INT32类型的数据存放8个INT4数据。输入shape需要将数据原本INT4类型时的最后一维shape缩小8倍。INT4数据的最后一维shape应为8的倍数,例如:
支持的型号
- Atlas 推理系列加速卡产品
- Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
- 单算子调用
- Atlas 推理系列加速卡产品
在单算子模式下不支持使能高带宽的x2数据排布,因此不能调用use_internal_format_weight,如果想追求极致性能,请使用图模式
import torch import torch_npu import logging import os from torch_npu.contrib.module import LinearQuant x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu() x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu() scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(128, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.int8 model = LinearQuant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.offset.data = offset model.bias.data = bias // 接口内部调用npu_trans_quant_param功能 output = model(x1)
- Atlas A2训练系列产品/Atlas 800I A2推理产品
# int8输入场景 import torch import torch_npu import logging import os from torch_npu.contrib.module import LinearQuant x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu() x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu() scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(128, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.int8 model = LinearQuant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.offset.data = offset model.bias.data = bias output = model(x1) # int4输入场景 import torch import torch_npu import logging import os from torch_npu.contrib.module import LinearQuant # 用int32类型承载int4数据,实际int4 shape为x1:(1, 512) x2: (128, 512) x1 = torch.randint(-1, 1, (1, 64), dtype=torch.int32).npu() x2 = torch.randint(-1, 1, (128, 64), dtype=torch.int32).npu() scale = torch.randn(1, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.float16 model = LinearQuant(in_features, out_features, bias=True, offset=False, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.bias.data = bias output = model(x1)
- Atlas 推理系列加速卡产品
- 图模式调用(图模式目前仅支持PyTorch 2.1版本)
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig from torch_npu.contrib.module import LinearQuant import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu() x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu() scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(128, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.int8 model = LinearQuant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.offset.data = offset if output_dtype != torch.bfloat16: #使能高带宽x2的数据排布功能 tng.experimental.inference.use_internal_format_weight(model) model.bias.data = bias model = torch.compile(model, backend=npu_backend, dynamic=False) output = model(x1)
父主题: torch_npu.contrib