昇腾社区首页
中文
注册
开发者
下载

torch_npu.npu_quantize

产品支持情况

产品 是否支持
[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]
[object Object]Atlas 推理系列产品[object Object]

功能说明

  • API功能:对输入的张量进行量化处理。
  • 计算公式:
    • div_modeTrue

      result=(input/scales)+zero_pointsresult=(input/scales)+zero\_points
    • div_modeFalse

      result=(inputscales)+zero_pointsresult=(input*scales)+zero\_points

函数原型

[object Object]

参数说明

  • input (Tensor):必选参数,需要进行量化的源数据张量,数据格式支持NDNDNZNZ,支持非连续的Tensor。div_modeFalsedtypequint4x2时,最后一维需要能被8整除。

    • [object Object]Atlas 推理系列产品[object Object]:数据类型支持floatfloat16
    • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持floatfloat16bfloat16
  • scales (Tensor):必选参数,对input进行缩放的张量:

    • div_modeTrue时:

      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持float
      • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持floatbfloat16
    • div_modeFalse时,数据格式支持NDND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与input中第axis维相等或等于1;多维时,scales的shape需要与input的shape维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和input对应的维度相等或等于1)。

      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持floatfloat16
      • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持floatfloat16bfloat16
  • zero_points (Tensor):必选参数,允许为None,对input进行偏移的张量。

    • div_modeTrue

      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持int8uint8int32
      • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持int8uint8int32bfloat16
    • div_modeFalse时,数据格式支持NDND,支持非连续的Tensor。支持1维或多维(1维时,对应轴的大小需要与input中第axis维相等或等于1;多维时,scales的shape需要与input维度相等,除axis指定的维度,其他维度为1,axis指定的维度必须和input对应的维度相等)。zero_points的shape和dtype需要和scales一致。

      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持floatfloat16
      • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持floatfloat16bfloat16
  • dtype (int):必选参数,指定输出参数的类型。

    • div_modeTrue时,

      • [object Object]Atlas 推理系列产品[object Object]:类型支持qint8quint8int32
      • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:类型支持qint8quint8int32
    • div_modeFalse时,类型支持qint8quint4x2。如果dtypequint4x2时,输出tensor类型为int32,由8个int4拼接。

  • axis (int):可选参数,量化的element-wise轴,其他的轴做broadcast,默认值为1

    div_modeFalse时,axis取值范围是[-2, +∞)且指定的轴不能超过输入input的维度数。如果axis为-2,代表量化的element-wise轴是输入input的倒数第二根轴;如果axis大于-2,量化的element-wise轴是输入的最后一根轴。

  • div_mode (bool):可选参数,表示计算scales模式。当div_modeTrue时,表示用除法计算scalesdiv_modeFalse时,表示用乘法计算scales,默认值为True

返回值说明

Tensor

对应公式中的resultresult,输出大小与input一致。数据类型由参数dtype指定,如果参数dtypequint4x2,输出的dtypeint32,shape的最后一维是输入shape最后一维的1/8,shape其他维度和输入一致。

约束说明

  • 该接口支持推理场景下使用。
  • 该接口支持图模式。
  • input数据格式为NZNZ时,input输入shape支持3维,形如(e, k, n),scales输入shape支持1维,zero_points输入为None,dtypequint4x2
  • div_modeFalse时:
    • 支持[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]。
    • dtypequint4x2或者axis为-2时,不支持[object Object]Atlas 推理系列产品[object Object]。

调用示例

  • 单算子模式调用

    • [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]

      [object Object]
    • [object Object]Atlas 推理系列产品[object Object]

      [object Object]
  • 图模式调用

    [object Object]