昇腾社区首页
EN
注册

torch_npu.npu_quant_matmul

功能描述

完成量化的矩阵乘计算,最小支持输入维度为2维,最大支持输入维度为6维。

接口原型

1
torch_npu.npu_quant_matmul(Tensor x1, Tensor x2, Tensor scale, *, Tensor? offset=None, Tensor? pertoken_scale=None, Tensor? bias=None, ScalarType? output_dtype=None) -> Tensor

参数说明

  • x1:Tensor类型,数据格式支持ND,shape需要在2-6维范围。
    • Atlas 推理系列加速卡产品:数据类型支持int8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持int8和int32。其中int32表示int4类型矩阵乘计算,每个int32数据存放8个int4数据。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int8和int32。其中int32表示int4类型矩阵乘计算,每个int32数据存放8个int4数据。
  • x2:Tensor类型(weight),其与x1的数据类型须保持一致。数据格式支持ND,shape需要在2-6维范围。
    • Atlas 推理系列加速卡产品:数据类型支持int8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持int8和int32(int32含义同x1,表示int4类型计算)。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int8和int32(int32含义同x1,表示int4类型计算)。
  • scale:Tensor类型,数据格式支持ND,shape需要是1维(t, ),t=1或n,其中n与x2的n一致。如需传入int64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的scale。
    • Atlas 推理系列加速卡产品:数据类型支持float32、int64。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、int64、bfloat16。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、int64、bfloat16。
  • offset:Tensor类型,可选参数。数据类型支持float32,数据格式支持ND,shape需要是1维(t,),t=1或n,其中n与x2的n一致。
  • pertoken_scale:Tensor类型,可选参数。数据类型支持float32,数据格式支持ND,shape需要是1维(m,),其中m与x1的m一致。Atlas 推理系列加速卡产品当前不支持pertoken_scale。
  • bias:Tensor类型,可选参数,数据格式支持ND,shape支持1维(n,)或3维(batch, 1, n),n与x2的n一致,同时batch值需要等于x1和x2 boardcast后推导出的batch值。当输出是2、4、5、6维情况下,bias的shape必须为1维。当输出是3维情况下,bias的shape可以为1维或3维。
    • Atlas 推理系列加速卡产品:数据类型支持int32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持int32、bfloat16、float16、float32。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int32、bfloat16、float16、float32。
  • output_dtype:ScalarType类型,可选参数。表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8。
    • Atlas 推理系列加速卡产品:支持输入torch.int8、torch.float16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:支持输入torch.int8、torch.float16、torch.bfloat16、torch.int32。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持输入torch.int8、torch.float16、torch.bfloat16、torch.int32。

输出说明

result:Tensor类型,代表量化matmul的计算结果。

  • 如果output_dtype为torch.float16,输出的数据类型为float16。
  • 如果output_dtype为torch.int8或者None,输出的数据类型为int8。
  • 如果output_dtype为torch.bfloat16,输出的数据类型为bfloat16。
  • 如果output_dtype为torch.int32,输出的数据类型为int32。

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • 传入的x1、x2、scale不能是空。
  • x1、x2、bias、scale、offset、pertoken_scale、output_dtype的数据类型和数据格式需要在支持的范围之内。
  • x1与x2最后一维的shape大小不能超过65535。
  • 目前输出int8或float16且无pertoken_scale情况下,图模式不支持scale直接传入float32数据类型。
  • 如果在PyTorch图模式中使用本接口,且环境变量ENABLE_ACLNN=false,则在调用接口前需要对shape为(n, k//8)的x2数据进行转置,转置过程应写在图中。
  • 支持将x2转为昇腾亲和的数据排布以提高搬运效率。需要调用torch_npu.npu_format_cast完成输入x2(weight)为昇腾亲和的数据排布功能。
    • Atlas 推理系列加速卡产品:必须先将x2转置后再转亲和format。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:推荐x2不转置直接转亲和format。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:推荐x2不转置直接转亲和format。
  • int4类型计算的额外约束:

    当x1、x2的数据类型均为int32,每个int32类型的数据存放8个int4数据。输入的int32 shape需要将数据原本int4类型时shape的最后一维缩小8倍。int4数据的shape最后一维应为8的倍数,例如:进行(m, k)乘(k, n)的int4类型矩阵乘计算时,需要输入int32类型、shape为(m, k//8)、(k, n//8)的数据,其中k与n都应是8的倍数。x1只能接受shape为(m, k//8)且数据排布连续的数据,x2可以接受(k, n[g1] //8)且数据排布连续的数据或shape为(k//8, n)且是由数据连续排布的(n, k//8)转置而来的数据。

    数据排布连续是指数组中所有相邻的数,包括换行时内存地址连续,使用Tensor.is_contiguous返回值为true则表明tensor数据排布连续。

  • 输入参数间支持的数据类型组合情况如下:
    表1 Atlas 推理系列产品

    x1

    x2

    scale

    offset

    bias

    pertoken_scale

    output_dtype

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    表2 Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件)(Atlas A3 训练系列产品/Atlas A3 推理系列产品

    x1

    x2

    scale

    offset

    bias

    pertoken_scale

    output_dtype

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    int8

    int8

    float32/bfloat16

    None

    int32/bfloat16/float32/None

    float32/None

    bfloat16

    int8

    int8

    float32

    None

    int32/float16/float32/None

    float32

    float16

    int32

    int32

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    float32/bfloat16

    None

    int32/None

    None

    int32

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas 推理系列加速卡产品
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

调用示例

  • 单算子调用
    • int8类型输入场景:
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      import torch
      import torch_npu
      import logging
      import os
      
      cpu_x1 = torch.randint(-5, 5, (1, 256, 768), dtype=torch.int8)
      cpu_x2 = torch.randint(-5, 5, (31, 768, 16), dtype=torch.int8)
      scale = torch.randn(16, dtype=torch.float32)
      offset = torch.randn(16, dtype=torch.float32)
      bias = torch.randint(-5, 5, (31, 1, 16), dtype=torch.int32)
      # Method 1:You can directly call npu_quant_matmul
      npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset=offset.npu(), bias=bias.npu())
      
      # Method 2: You can first call npu_trans_quant_param to convert scale and offset from float32 to int64 when output dtype is not torch.bfloat16 and pertoken_scale is none
      scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu())
      npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale_1,  bias=bias.npu())
      
  • 图模式调用(ND数据格式)
    • 输出float16
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      import torch
      import torch_npu
      import torchair as tng
      from torchair.ge_concrete_graph import ge_apis as ge
      from torchair.configs.compiler_config import CompilerConfig
      import logging
      from torchair.core.utils import logger
      logger.setLevel(logging.DEBUG)
      import os
      import numpy as np
      # "ENABLE_ACLNN"是否使能走aclnn, true: 回调走aclnn, false: 在线编译
      os.environ["ENABLE_ACLNN"] = "true"
      config = CompilerConfig()
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      class MyModel(torch.nn.Module):
          def __init__(self):
              super().__init__()
          def forward(self, x1, x2, scale, offset, bias):
              return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, output_dtype=torch.float16)
      cpu_model = MyModel()
      model = cpu_model.npu()
      cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8)
      cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8)
      scale = torch.randn(1, dtype=torch.float32)  
      # pertoken_scale为空时,输出fp16必须先调用npu_trans_quant_param,将scale(offset)从float转为int64.
      scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), None)
      bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32)
      # dynamic=True: 动态图模式, dynamic=False: 静态图模式
      model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
      npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale_1, None, bias.npu())
      
    • 输出bfloat16,示例代码如下,仅支持如下产品:
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      import torch
      import torch_npu
      import torchair as tng
      from torchair.ge_concrete_graph import ge_apis as ge
      from torchair.configs.compiler_config import CompilerConfig
      import logging
      from torchair.core.utils import logger
      logger.setLevel(logging.DEBUG)
      import os
      import numpy as np
      os.environ["ENABLE_ACLNN"] = "true"
      config = CompilerConfig()
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      class MyModel(torch.nn.Module):
          def __init__(self):
              super().__init__()
          def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
              return torch_npu.npu_quant_matmul(x1, x2.t(), scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16)
      cpu_model = MyModel()
      model = cpu_model.npu()
      m = 15
      k = 11264
      n = 6912
      bias_flag = True
      cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
      cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
      scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
      pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
      
      bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
      model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
      if bias_flag:
          npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, bias.npu(), pertoken_scale.npu())
      else:
          npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, None, pertoken_scale.npu())
      
  • 图模式调用(高性能数据排布方式)
    • 将x2转置(batch, n, k)后转format,示例代码如下,仅支持Atlas 推理系列加速卡产品
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      import torch
      import torch_npu
      import torchair as tng
      from torchair.ge_concrete_graph import ge_apis as ge
      from torchair.configs.compiler_config import CompilerConfig
      import logging
      from torchair.core.utils import logger
      logger.setLevel(logging.DEBUG)
      import os
      import numpy as np
      os.environ["ENABLE_ACLNN"] = "true"
      config = CompilerConfig()
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      class MyModel(torch.nn.Module):
          def __init__(self):
              super().__init__()
          def forward(self, x1, x2, scale, offset, bias):
              return torch_npu.npu_quant_matmul(x1, x2.transpose(2,1), scale, offset=offset, bias=bias)
      cpu_model = MyModel()
      model = cpu_model.npu()
      cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8).npu()
      cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8).npu()
      # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,n,k) layout
      cpu_x2_t_29 = torch_npu.npu_format_cast(cpu_x2.transpose(2,1).contiguous(), 29)
      scale = torch.randn(1, dtype=torch.float32).npu()
      offset = torch.randn(1, dtype=torch.float32).npu()
      bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
      # Process scale from float32 to int64 offline to improve performance
      scale_1 = torch_npu.npu_trans_quant_param(scale, offset)
      model = torch.compile(cpu_model, backend=npu_backend, dynamic=False)
      npu_out = model(cpu_x1, cpu_x2_t_29, scale_1, offset, bias)
      
    • 将x2非转置(batch, k, n)后转format,示例代码如下,仅支持如下产品:
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品
       1
       2
       3
       4
       5
       6
       7
       8
       9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      import torch
      import torch_npu
      import torchair as tng
      from torchair.ge_concrete_graph import ge_apis as ge
      from torchair.configs.compiler_config import CompilerConfig
      import logging
      from torchair.core.utils import logger
      logger.setLevel(logging.DEBUG)
      import os
      import numpy as np
      config = CompilerConfig()
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      class MyModel(torch.nn.Module):
          def __init__(self):
              super().__init__()
          def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
              return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16)
      cpu_model = MyModel()
      model = cpu_model.npu()
      m = 15
      k = 11264
      n = 6912
      bias_flag = True
      cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
      cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
      # Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,k,n) layout
      x2_notranspose_29 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29)
      scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
      pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
      
      bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
      model = torch.compile(cpu_model, backend=npu_backend, dynamic=True)
      if bias_flag:
          npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
      else:
          npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu())