# Graph mode invocation
import os
# Optional Dynamo debug switches; set as environment variables before importing torch so they take effect
os.environ.setdefault("TORCHDYNAMO_VERBOSE", "1")
os.environ.setdefault("TORCH_LOGS", "+dynamo")
import torch
import torch_npu
import math
import torchair as tng
import numpy as np
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
# Logging switch that also works in graph mode
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
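# Configure the torchair compiler backend; graph_dump asks for the captured graph to be dumped in pbtxt format for debugging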
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# Data generation
dtype_list2 =["fp16","int8","int32","fp32","fp16"]
dtype_list =[np.float16,np.int8,np.int32,np.float32]
updates_shape =[1,11,1,32]
var_shape =[1,11,1,32]
indices_shape =[1,2]
quant_scales_shape =[1,1,1,32]
quant_zero_points_shape =[1,1,1,32]
axis =-2
quant_axis=-1
reduce = "update"
updates_data = np.random.uniform(-1,1,size=updates_shape)
var_data = np.random.uniform(1,2,size=var_shape).astype(dtype_list[1])
quant_scales_data = np.random.uniform(1,2,size=quant_scales_shape)
indices_data = np.random.uniform(0,1,size=indices_shape).astype(dtype_list[2])
quant_zero_points_data = np.random.uniform(0,1,size=quant_zero_points_shape)
updates_npu = torch.from_numpy(updates_data).to(torch.bfloat16).npu()
var_npu = torch.from_numpy(var_data).npu()
quant_scales_npu = torch.from_numpy(quant_scales_data).to(torch.bfloat16).npu()
quant_zero_points_npu = torch.from_numpy(quant_zero_points_data).to(torch.bfloat16).npu()
indices_npu = torch.from_numpy(indices_data).npu()
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        return torch_npu.npu_quant_scatter(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)

def MetaInfershape():
    with torch.no_grad():
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
    single_op = torch_npu.npu_quant_scatter(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
    print("single op output with mask:", single_op[0], single_op[0].shape)
    print("graph output with mask:", graph_output[0], graph_output[0].shape)

if __name__ == "__main__":
    MetaInfershape()
# The output of running the above code looks similar to the following
single op output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
[[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
[[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1,
0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
[[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0,
1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
[[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
[[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,
0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
[[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0,
1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
[[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
[[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
[[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1,
0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
[[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1,
1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
graph output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
[[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
[[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1,
0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
[[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0,
1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
[[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
[[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,
0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
[[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0,
1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
[[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
[[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
[[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1,
0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
[[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1,
1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
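For reference, below is a minimal CPU-only sketch of the quantize-then-scatter idea behind the op. It is an illustration under assumed semantics (quantization as round(updates / scales + zero_points) followed by a positional "update" along one axis), not the exact formula or indexing behavior of npu_quant_scatter; the helper quant_scatter_sketch is hypothetical.

import torch

def quant_scatter_sketch(var, pos, updates, scales, zero_points):
    # Assumed quantization rule (illustrative only): scale, shift, round, cast to var's integer dtype.
    q = torch.round(updates / scales + zero_points).to(var.dtype)
    out = var.clone()
    # Write the quantized slice into `out` at row `pos` along dim 0 ("update" semantics).
    out.index_copy_(0, torch.tensor([pos]), q)
    return out

var = torch.zeros(4, 1, 32, dtype=torch.int8)
updates = torch.rand(1, 1, 32)
scales = torch.rand(1, 1, 32) + 1.0
zero_points = torch.rand(1, 1, 32)
print(quant_scatter_sketch(var, 2, updates, scales, zero_points)[2])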