功能描述
- 算子功能:对输入的张量进行量化处理。
- 计算公式:
- 如果div_mode为True:

- 如果div_mode为False:

接口原型
| torch_npu.npu_quantize(Tensor input, Tensor scales, Tensor? zero_points, ScalarType dtype, int axis=1, bool div_mode=True) -> Tensor
|
参数说明
- input:Tensor类型,需要进行量化的源数据张量,必选输入,数据格式支持ND,支持非连续的Tensor。div_mode为False且dtype为torch.quint4x2时,最后一维需要能被8整除。
- Atlas 推理系列产品:数据类型支持float、float16。
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float、float16、bfloat16。
输出说明
y:Tensor类型,公式中的输出,输出大小与input一致。如果参数dtype为torch.quint4x2,输出的dtype是torch.int32,shape的最后一维是输入shape最后一维的1/8,shape其他维度和输入一致。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- div_mode为False时:
- 支持Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件。
- 当dtype为torch.quint4x2或者axis为-2时,不支持Atlas 推理系列产品。
支持的型号
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
调用示例
- 单算子模式调用
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
| import torch
import torch_npu
x = torch.randn(1, 1, 12).bfloat16().npu()
scale = torch.tensor([0.1] * 12).bfloat16().npu()
out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False)
print(out)
|
- Atlas 推理系列产品
| import torch
import torch_npu
x = torch.randn((2, 3, 12), dtype=torch.float).npu()
scale = torch.tensor(([3] * 12),dtype=torch.float).npu()
out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False)
print(out)
|
- 图模式调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 | import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
x = torch.randn((2, 3, 12), dtype=torch.float16).npu()
scale = torch.tensor(([3] * 12),dtype=torch.float16).npu()
axis =1
div_mode = False
class Network(torch.nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, x, scale,zero_points, dst_type,div_mode):
return torch_npu.npu_quantize(x, scale,zero_points=zero_points,dtype=dst_type,div_mode=div_mode)
model = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
config.debug.graph_dump.type = 'pbtxt'
model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)
output_data = model(x, scale,None,dst_type=torch.qint8, div_mode=div_mode)
print(output_data)
|