昇腾社区首页
中文
注册

torch_npu.contrib.module.LinearA8W8Quant

须知:[object Object] 该接口计划废弃,可以使用torch_npu.contrib.module.LinearQuant接口进行替换。

功能说明

LinearA8W8Quant是对torch_npu.npu_quant_matmul接口的封装类,完成A8W8量化算子的矩阵乘计算。

函数原型

[object Object]

参数说明

  • in_features(计算参数):int类型,matmul计算中k轴的值。
  • out_features(计算参数):int类型,matmul计算中n轴的值。
  • bias(计算参数):bool类型,代表是否需要bias计算参数。如果设置成False,则bias不会加入量化matmul的计算。
  • offset(计算参数):bool类型,代表是否需要offset计算参数。如果设置成False,则offset不会加入量化matmul的计算。
  • pertoken_scale(计算参数):bool类型,代表是否需要pertoken_scale计算参数。如果设置成False,则pertoken_scale不会加入量化matmul的计算。[object Object]Atlas 推理系列产品[object Object]当前不支持pertoken_scale。
  • output_dtype(计算参数):ScalarType类型,表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8。
    • [object Object]Atlas 推理系列产品[object Object]:支持输入torch.int8、torch.float16。
    • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]:支持输入torch.int8、torch.float16、torch.bfloat16。
    • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:支持输入torch.int8、torch.float16、torch.bfloat16。

输入说明

x1(计算输入):Tensor类型,数据类型支持int8。数据格式支持ND,shape需要在2-6维范围。

变量说明

  • weight(变量):Tensor类型,矩阵乘中的weight。数据格式支持int8。数据格式支持ND,shape为(batch, n, k),shape需要在2-6维范围。

    • [object Object]Atlas 推理系列产品[object Object]:需要调用torchair.experimental.inference.use_internal_format_weight或torch_npu.npu_format_cast完成weight(batch, n, k)高性能数据排布功能。
    • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]:需要调用torch_npu.npu_format_cast完成weight(batch,n,k)高性能数据排布功能,但不推荐使用该module方式,推荐torch_npu.npu_quant_matmul。
    • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:需要调用npu_format_cast完成weight(batch, n, k)高性能数据排布功能,但不推荐使用该module方式,推荐torch_npu.npu_quant_matmul。
  • scale(变量):Tensor类型,量化计算的scale。数据格式支持ND,shape是1维(t,),t=1或n,其中n与weight的n一致。如需传入int64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param接口来获取int64数据类型的scale。

    • [object Object]Atlas 推理系列产品[object Object]:数据类型支持float32、int64。
    • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]:数据类型支持float32、int64、bfloat16。
    • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持float32、int64、bfloat16。
  • offset(变量):Tensor类型,量化计算的offset。可选参数。数据类型支持float32,数据格式支持ND,shape是1维(t,),t=1或n,其中n与weight的n一致。

  • pertoken_scale(变量):Tensor类型,可选参数。量化计算的pertoken。数据类型支持float32,数据格式支持ND,shape是1维(m,),其中m与x1的m一致。目前仅在输出为float16和bfloat16场景下可不为空。[object Object]Atlas 推理系列产品[object Object]当前不支持pertoken_scale。

  • bias(变量):Tensor类型,可选参数。矩阵乘中的bias。数据格式支持ND,shape支持1维(n,)或3维(batch, 1, n),n与weight的n一致,同时batch值需要等于x1,weight broadcast后推导出的batch值。当输出为2、4、5、6维情况下,bias shape为1维;当输出为3维情况下,bias shape为1维或3维。

    • [object Object]Atlas 推理系列产品[object Object]:数据类型支持int32。
    • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]:数据类型支持int32、bfloat16、float16、float32。
    • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持int32、bfloat16、float16、float32。
  • output_dtype(变量):ScalarType类型,可选参数。表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8。

    • [object Object]Atlas 推理系列产品[object Object]:支持输入torch.int8、torch.float16。
    • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]:支持输入torch.int8、torch.float16、torch.bfloat16。
    • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:支持输入torch.int8、torch.float16、torch.bfloat16。

输出说明

一个Tensor类型的输出,代表量化matmul的计算结果:

  • 如果output_dtype为torch.float16,输出的数据类型为float16。
  • 如果output_dtype为torch.int8或者None,输出的数据类型为int8。
  • 如果output_dtype为torch.bfloat16,输出的数据类型为bfloat16。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式(PyTorch 2.1版本)。

  • x1、weight、scale不能是空。

  • x1与x2最后一维的shape大小不能超过65535。

  • 输入参数或变量间支持的数据类型组合情况如下:

    表1 [object Object]Atlas 推理系列产品[object Object]

    [object Object][object Object]

    [object Object]

    表2 [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]

    [object Object][object Object]

    [object Object]

支持的型号

  • [object Object][object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object][object Object]
  • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
  • [object Object]Atlas 推理系列产品[object Object]

调用示例

  • 单算子模式调用

    [object Object]
  • 图模式调用

    [object Object]