昇腾社区首页
EN
注册

torch_npu.npu_weight_quant_batchmatmul

功能描述

该接口用于实现矩阵乘计算中weight输入和输出的量化操作,支持per-tensor、per-channel、per-group多场景量化。

不同产品支持的量化算法不同,如表1所示。

表1 支持的量化场景

产品型号

量化方式

Atlas 推理系列产品

per-tensor、per-channel

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

per-tensor、per-channel、per-group

Atlas A3 训练系列产品/Atlas A3 推理系列产品

per-tensor、per-channel、per-group

接口原型

1
torch_npu.npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0, int inner_precise=0) -> Tensor

参数说明

  • x : Tensor类型,即矩阵乘中的x。数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(M, K) 。
    • Atlas 推理系列产品:数据类型支持float16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、bfloat16。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、bfloat16。
  • weight:Tensor类型,即矩阵乘中的weight。支持带transpose的非连续的Tensor,支持输入维度为两维(K, N),维度需与x保持一致。当数据格式为ND时,per-channel场景下为提高性能推荐使用transpose后的weight输入。
    • Atlas 推理系列产品:数据类型支持int8。数据格式支持ND、FRACTAL_NZ,其中FRACTAL_NZ格式只在“图模式”有效,需依赖接口torch_npu.npu_format_cast完成ND到FRACTAL_NZ的转换,可参考调用示例
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持int8、int32(通过int32承载int4的输入,可参考torch_npu.npu_convert_weight_to_int4pack调用示例)。数据格式支持ND、FRACTAL_NZ。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int8、int32(通过int32承载int4的输入,可参考torch_npu.npu_convert_weight_to_int4pack调用示例)。数据格式支持ND、FRACTAL_NZ。
  • antiquant_scale:Tensor类型,反量化的scale,用于weight矩阵反量化,数据格式支持ND。支持带transpose的非连续的Tensor。antiquant_scale支持的shape与量化方式相关:
    • per_tensor模式:输入shape为(1,)或(1, 1)。
    • per_channel模式:输入shape为(1, N)或(N,)。
    • per_group模式:输入shape为(ceil(K, antiquant_group_size), N)。
    antiquant_scale支持的dtype如下:
    • Atlas 推理系列产品:数据类型支持float16,其数据类型需与x保持一致。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、bfloat16、int64。
      • 若输入为float16、bfloat16, 其数据类型需与x保持一致。
      • 若输入为int64,x数据类型必须为float16且不带transpose输入,同时weight数据类型必须为int8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持per-channel场景,M范围为[1, 96],且K和N要求64对齐。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、bfloat16、int64。
      • 若输入为float16、bfloat16, 其数据类型需与x保持一致。
      • 若输入为int64,x数据类型必须为float16且不带transpose输入,同时weight数据类型必须为int8、数据格式为ND、带transpose输入,可参考调用示例。此时只支持per-channel场景,M范围为[1, 96],且K和N要求64对齐。
  • antiquant_offset:Tensor类型,反量化的offset,用于weight矩阵反量化。可选参数,默认值为None,数据格式支持ND,支持带transpose的非连续的Tensor,支持输入维度为两维(1, N)或一维(N, )、(1, )。
    • Atlas 推理系列产品:数据类型支持float16,其数据类型需与antiquant_scale保持一致。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、bfloat16、int32。per-group场景shape要求为(ceil_div(K, antiquant_group_size), N)。
      • 若输入为float16、bfloat16,其数据类型需与antiquant_scale保持一致。
      • 若输入为int32,antiquant_scale的数据类型必须为int64。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、bfloat16、int32。per-group场景shape要求为(ceil_div(K, antiquant_group_size), N)。
      • 若输入为float16、bfloat16,其数据类型需与antiquant_scale保持一致。
      • 若输入为int32,antiquant_scale的数据类型必须为int64。
  • quant_scale:Tensor类型,量化的scale,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持float32、int64,数据格式支持ND,支持输入维度为两维(1, N)或一维(N, )、(1, )。当antiquant_scale的数据类型为int64时,此参数必须为空。
    • Atlas 推理系列产品:暂不支持此参数。
  • quant_offset: Tensor类型,量化的offset,用于输出矩阵的量化,可选参数,默认值为None,仅在weight格式为ND时支持。数据类型支持float32,数据格式支持ND,支持输入维度为两维(1, N)或一维(N, )、(1, )。当antiquant_scale的数据类型为int64时,此参数必须为空。
    • Atlas 推理系列产品:暂不支持此参数。
  • bias:Tensor类型, 即矩阵乘中的bias,可选参数,默认值为None,数据格式支持ND, 不支持非连续的Tensor,支持输入维度为两维(1, N)或一维(N, )、(1, )。
    • Atlas 推理系列产品:数据类型支持float16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float32。当x数据类型为bfloat16,bias需为float32;当x数据类型为float16,bias需为float16。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、float32。当x数据类型为bfloat16,bias需为float32;当x数据类型为float16,bias需为float16。
  • antiquant_group_size:int类型, 用于控制per-group场景下group大小,其他量化场景不生效。可选参数。默认值为0,per-group场景下要求传入值的范围为[32, K-1]且必须是32的倍数。
    • Atlas 推理系列产品:暂不支持此参数。
  • inner_precise: int类型,计算模式选择, 默认为0。0表示高精度模式,1表示高性能模式,可能会影响精度。当weight以int32类型且以FRACTAL_NZ格式输入,M不大于16的per-group场景下可以设置为1,提升性能。其他场景不建议使用高性能模式。

输出说明

输出为Tensor类型,代表计算结果。当输入存在quant_scale时输出数据类型为int8,当输入不存在quant_scale时输出数据类型和输入x一致。

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。当输入weight为FRACTAL_NZ格式时暂不支持单算子调用,只支持图模式调用。
  • x和weight后两维必须为(M, K)和(K, N)格式,K、N的范围为[1, 65535];在x为非转置时,M的范围为[1, 2^31-1],在x为转置时,M的范围为[1, 65535]。
  • 不支持空Tensor输入。
  • antiquant_scale和antiquant_offset的输入shape要保持一致。
  • quant_scale和quant_offset的输入shape要保持一致,且quant_offset不能独立于quant_scale存在。
  • 如需传入int64数据类型的quant_scale,需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为float32的quant_scale和quant_offset转换为数据类型为int64的quant_scale输入,可参考调用示例
  • 当输入weight为FRACTAL_NZ格式且类型为int32时,per-channel场景需满足weight为转置输入;per-group场景需满足x为转置输入,weight为非转置输入,antiquant_group_size为64或128,K为antiquant_group_size对齐,N为64对齐。
  • 不支持输入weight shape为(1, 8)且类型为int4,同时weight带有transpose的场景,否则会报错x矩阵和weight矩阵K轴不匹配,该场景建议走非量化算子获取更高精度和性能。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • Atlas 推理系列产品

调用示例

  • 单算子模式调用
    • weight非transpose+quant_scale场景,仅支持如下产品:
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      import torch
      import torch_npu
      # 输入int8+ND 
      cpu_x = torch.randn((8192, 320),dtype=torch.float16)
      cpu_weight = torch.randint(low=-8, high=8, size=(320, 256),dtype=torch.int8)
      cpu_antiquantscale = torch.randn((1, 256),dtype=torch.float16)
      cpu_antiquantoffset = torch.randn((1, 256),dtype=torch.float16)
      cpu_quantscale = torch.randn((1, 256),dtype=torch.float32)
      cpu_quantoffset = torch.randn((1, 256),dtype=torch.float32)
      quantscale= torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu())
      npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),quantscale.npu())
      
    • weight transpose+antiquant_scale场景
      1
      2
      3
      4
      5
      6
      7
      8
      import torch
      import torch_npu
      # 输入int8+ND 
      cpu_x = torch.randn((96, 320),dtype=torch.float16)
      cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8)
      cpu_antiquantscale = torch.randn((256,1),dtype=torch.float16)
      cpu_antiquantoffset = torch.randint(-128, 127, (256,1), dtype=torch.float16)
      npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu().transpose(-1, -2), cpu_antiquantscale.npu().transpose(-1, -2), cpu_antiquantoffset.npu().transpose(-1, -2))
      
    • weight transpose+antiquant_scale场景 ,仅支持如下产品:
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品
      • Atlas 推理系列产品
      1
      2
      3
      4
      5
      6
      7
      8
      9
      import torch
      import torch_npu
      cpu_x = torch.randn((96, 320),dtype=torch.float16)
      cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8)
      cpu_antiquantscale = torch.randn((256),dtype=torch.float16)
      # 构建int64类型的scale参数
      antiquant_scale = torch_npu.npu_trans_quant_param(cpu_antiquantscale.to(torch.float32).npu()).reshape(256, 1)
      cpu_antiquantoffset = torch.randint(-128, 127, (256, 1), dtype=torch.int32)
      npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.transpose(-1,-2).npu(), antiquant_scale.transpose(-1,-2).npu(), cpu_antiquantoffset.transpose(-1,-2).npu())
      
  • 图模式调用
    • weight输入为ND格式
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      # 图模式
      import torch
      import torch_npu
      import  torchair as tng
      from torchair.configs.compiler_config import CompilerConfig
      config = CompilerConfig()
      config.debug.graph_dump.type = "pbtxt"
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
      cpu_weight = torch.randint(low=-8, high=8, size=(320, 256), dtype=torch.int8, device='npu')
      cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
      cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
      
      class MyModel(torch.nn.Module):
          def __init__(self):
              super().__init__()
      
          def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size):
              return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size)
      
      cpu_model = MyModel()
      model = cpu_model.npu()
      model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
      npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
      
    • Atlas 推理系列产品:weight输入为FRACTAL_NZ格式
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      import torch_npu
      import torch
      from torchair.configs.compiler_config import CompilerConfig
      import torchair as tng
      config = CompilerConfig()
      config.debug.graph_dump.type = "pbtxt"
      npu_backend = tng.get_npu_backend(compiler_config=config)
      class NPUQuantizedLinearA16W8(torch.nn.Module):
          def __init__(self,
                       weight,
                       antiquant_scale,
                       antiquant_offset,
                       quant_offset=None,
                       quant_scale=None,
                       bias=None,
                       transpose_x=False,
                       transpose_weight=True,
                       w4=False):
              super().__init__()
      
              self.dtype = torch.float16
              self.weight = weight.to(torch.int8).npu()
              self.transpose_weight = transpose_weight
      
              if self.transpose_weight:
                  self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29)
              else:
                  self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29) # n,k ->nz
      
              self.bias = None
              self.antiquant_scale = antiquant_scale
              self.antiquant_offset = antiquant_offset
              self.quant_offset = quant_offset
              self.quant_scale = quant_scale
              self.transpose_x = transpose_x
      
          def forward(self, x):
              x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x,
                                                         self.weight.transpose(0, 1),
                                                         self.antiquant_scale.transpose(0, 1),
                                                         self.antiquant_offset.transpose(0, 1),
                                                         self.quant_scale,
                                                         self.quant_offset,
                                                         self.bias)
              return x
      
      
      m, k, n = 4, 1024, 4096
      cpu_x = torch.randn((m, k),dtype=torch.float16)
      cpu_weight = torch.randint(1, 10, (k, n),dtype=torch.int8)
      cpu_weight = cpu_weight.transpose(-1, -2)
      
      cpu_antiquantscale = torch.randn((1, n),dtype=torch.float16)
      cpu_antiquantoffset = torch.randn((1, n),dtype=torch.float16)
      cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
      cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
      model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
      model = torch.compile(model, backend=npu_backend, dynamic=True)
      out = model(cpu_x.npu())