torch_npu.npu_weight_quant_batchmatmul
接口原型
1 | torch_npu.npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0, int inner_precise=0) -> Tensor |
参数说明
- x : Tensor类型,即矩阵乘中的x。数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(M, K) 。
Atlas 推理系列产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16。
- weight:Tensor类型,即矩阵乘中的weight。支持带transpose的非连续的Tensor,支持输入维度为两维(K, N),维度需与x保持一致。当数据格式为ND时,per-channel场景下为提高性能推荐使用transpose后的weight输入。
Atlas 推理系列产品 :数据类型支持int8。数据格式支持ND、FRACTAL_NZ,其中FRACTAL_NZ格式只在“图模式”有效,需依赖接口torch_npu.npu_format_cast完成ND到FRACTAL_NZ的转换,可参考调用示例。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持int8、int32(通过int32承载int4的输入,可参考torch_npu.npu_convert_weight_to_int4pack调用示例)。数据格式支持ND、FRACTAL_NZ。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持int8、int32(通过int32承载int4的输入,可参考torch_npu.npu_convert_weight_to_int4pack调用示例)。数据格式支持ND、FRACTAL_NZ。
- antiquant_scale:Tensor类型,反量化的scale,用于weight矩阵反量化,数据格式支持ND。支持带transpose的非连续的Tensor。antiquant_scale支持的shape与量化方式相关:
- per_tensor模式:输入shape为(1,)或(1, 1)。
- per_channel模式:输入shape为(1, N)或(N,)。
- per_group模式:输入shape为(ceil(K, antiquant_group_size), N)。
antiquant_scale支持的dtype如下:Atlas 推理系列产品 :数据类型支持float16,其数据类型需与x保持一致。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int64。- 若输入为float16、bfloat16, 其数据类型需与x保持一致。
- 若输入为int64,x数据类型必须为float16且不带transpose输入,同时weight数据类型必须为int8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持per-channel场景,M范围为[1, 96],且K和N要求64对齐。
Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int64。- 若输入为float16、bfloat16, 其数据类型需与x保持一致。
- 若输入为int64,x数据类型必须为float16且不带transpose输入,同时weight数据类型必须为int8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持per-channel场景,M范围为[1, 96],且K和N要求64对齐。
- antiquant_offset:Tensor类型,反量化的offset,用于weight矩阵反量化。可选参数,默认值为None,数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或一维(N, )、(1, )。
Atlas 推理系列产品 :数据类型支持float16,其数据类型需与antiquant_scale保持一致。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int32。per-group场景shape要求为(ceil_div(K, antiquant_group_size), N)。- 若输入为float16、bfloat16,其数据类型需与antiquant_scale保持一致。
- 若输入为int32,antiquant_scale的数据类型必须为int64。
Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int32。per-group场景shape要求为(ceil_div(K, antiquant_group_size), N)。- 若输入为float16、bfloat16,其数据类型需与antiquant_scale保持一致。
- 若输入为int32,antiquant_scale的数据类型必须为int64。
- quant_scale:Tensor类型,量化的scale,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持float32、int64,数据格式支持ND,支持输入维度为两维(1, N)或一维(N, )、(1, )。当antiquant_scale的数据类型为int64时,此参数必须为空。
Atlas 推理系列产品 :暂不支持此参数。
- quant_offset: Tensor类型,量化的offset,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持float32,数据格式支持ND,支持输入维度为两维(1, N)或一维(N, )、(1, )。当antiquant_scale的数据类型为int64时,此参数必须为空。
Atlas 推理系列产品 :暂不支持此参数。
- bias:Tensor类型, 即矩阵乘中的bias,可选参数,默认值为None,数据格式支持ND, 不支持非连续的Tensor,支持输入维度为两维(1, N)或一维(N, )、(1, )。
Atlas 推理系列产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、float32。当x数据类型为bfloat16,bias需为float32;当x数据类型为float16,bias需为float16。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、float32。当x数据类型为bfloat16,bias需为float32;当x数据类型为float16,bias需为float16。
- antiquant_group_size:int类型, 用于控制per-group场景下group大小,其他量化场景不生效。可选参数。默认值为0,per-group场景下要求传入值的范围为[32, K-1]且必须是32的倍数。
Atlas 推理系列产品 :暂不支持此参数。
- inner_precise: int类型,计算模式选择, 默认为0。0表示高精度模式,1表示高性能模式,可能会影响精度。当weight以int32类型且以FRACTAL_NZ格式输入,M不大于16的per-group场景下可以设置为1,提升性能。其他场景不建议使用高性能模式。
输出说明
输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为int8,当输入不存在quant_scale时输出数据类型和输入x一致。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。当输入weight为FRACTAL_NZ格式时暂不支持单算子调用,只支持图模式调用。
- x和weight后两维必须为(M, K)和(K, N)格式,K、N的范围为[1, 65535];在x为非转置时,M的范围为[1, 2^31-1],在x为转置时,M的范围为[1, 65535]。
- 不支持空Tensor输入。
- antiquant_scale和antiquant_offset的输入shape要保持一致。
- quant_scale和quant_offset的输入shape要保持一致,且quant_offset不能独立于quant_scale存在。
- 如需传入int64数据类型的quant_scale,需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为float32的quant_scale和quant_offset转换为数据类型为int64的quant_scale输入,可参考调用示例。
- 当输入weight为FRACTAL_NZ格式且类型为int32时,per-channel场景需满足weight为转置输入;per-group场景需满足x为转置输入,weight为非转置输入,antiquant_group_size为64或128,K为antiquant_group_size对齐,N为64对齐。
- 不支持输入weight shape为(1, 8)且类型为int4,同时weight带有transpose的场景,否则会报错x矩阵和weight矩阵K轴不匹配,该场景建议走非量化算子获取更高精度和性能。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 Atlas A3 训练系列产品/Atlas A3 推理系列产品 Atlas 推理系列产品
调用示例
- 单算子模式调用
- weight非transpose+quant_scale场景,仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 Atlas A3 训练系列产品/Atlas A3 推理系列产品
1 2 3 4 5 6 7 8 9 10 11
import torch import torch_npu # 输入int8+ND cpu_x = torch.randn((8192, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(320, 256),dtype=torch.int8) cpu_antiquantscale = torch.randn((1, 256),dtype=torch.float16) cpu_antiquantoffset = torch.randn((1, 256),dtype=torch.float16) cpu_quantscale = torch.randn((1, 256),dtype=torch.float32) cpu_quantoffset = torch.randn((1, 256),dtype=torch.float32) quantscale= torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu()) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),quantscale.npu())
- weight transpose+antiquant_scale场景
1 2 3 4 5 6 7 8
import torch import torch_npu # 输入int8+ND cpu_x = torch.randn((96, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8) cpu_antiquantscale = torch.randn((256,1),dtype=torch.float16) cpu_antiquantoffset = torch.randint(-128, 127, (256,1), dtype=torch.float16) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu().transpose(-1, -2), cpu_antiquantscale.npu().transpose(-1, -2), cpu_antiquantoffset.npu().transpose(-1, -2))
- weight transpose+antiquant_scale场景 ,仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 Atlas A3 训练系列产品/Atlas A3 推理系列产品 Atlas 推理系列产品
1 2 3 4 5 6 7 8 9
import torch import torch_npu cpu_x = torch.randn((96, 320),dtype=torch.float16) cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8) cpu_antiquantscale = torch.randn((256),dtype=torch.float16) # 构建int64类型的scale参数 antiquant_scale = torch_npu.npu_trans_quant_param(cpu_antiquantscale.to(torch.float32).npu()).reshape(256, 1) cpu_antiquantoffset = torch.randint(-128, 127, (256, 1), dtype=torch.int32) npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.transpose(-1,-2).npu(), antiquant_scale.transpose(-1,-2).npu(), cpu_antiquantoffset.transpose(-1,-2).npu())
- weight非transpose+quant_scale场景,仅支持如下产品:
- 图模式调用
- weight输入为ND格式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# 图模式 import torch import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16) cpu_weight = torch.randint(low=-8, high=8, size=(320, 256), dtype=torch.int8, device='npu') cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16) class MyModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size): return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size) cpu_model = MyModel() model = cpu_model.npu() model = torch.compile(cpu_model, backend=npu_backend, dynamic=True) npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
Atlas 推理系列产品 :weight输入为FRACTAL_NZ格式1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
import torch_npu import torch from torchair.configs.compiler_config import CompilerConfig import torchair as tng config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) class NPUQuantizedLinearA16W8(torch.nn.Module): def __init__(self, weight, antiquant_scale, antiquant_offset, quant_offset=None, quant_scale=None, bias=None, transpose_x=False, transpose_weight=True, w4=False): super().__init__() self.dtype = torch.float16 self.weight = weight.to(torch.int8).npu() self.transpose_weight = transpose_weight if self.transpose_weight: self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29) else: self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29) # n,k ->nz self.bias = None self.antiquant_scale = antiquant_scale self.antiquant_offset = antiquant_offset self.quant_offset = quant_offset self.quant_scale = quant_scale self.transpose_x = transpose_x def forward(self, x): x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x, self.weight.transpose(0, 1), self.antiquant_scale.transpose(0, 1), self.antiquant_offset.transpose(0, 1), self.quant_scale, self.quant_offset, self.bias) return x m, k, n = 4, 1024, 4096 cpu_x = torch.randn((m, k),dtype=torch.float16) cpu_weight = torch.randint(1, 10, (k, n),dtype=torch.int8) cpu_weight = cpu_weight.transpose(-1, -2) cpu_antiquantscale = torch.randn((1, n),dtype=torch.float16) cpu_antiquantoffset = torch.randn((1, n),dtype=torch.float16) cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2) cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2) model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu()) model = torch.compile(model, backend=npu_backend, dynamic=True) out = model(cpu_x.npu())
- weight输入为ND格式
父主题: torch_npu