功能描述
将updates中的值按指定的索引indices更新input中的值,并将结果保存到输出tensor,input本身的数据不变。
接口原型
| torch_npu.npu_scatter_nd_update(Tensor input, Tensor indices, Tensor updates) -> Tensor
|
参数说明
- input:Tensor类型,必选输入,源数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
- Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
- Atlas 训练系列产品:数据类型支持float32、float16、bool。
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
- indices:Tensor类型,必选输入,索引张量,数据类型支持int32、int64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
- updates:Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与input一致。
- Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
- Atlas 训练系列产品:数据类型支持float32、float16、bool。
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
输出说明
一个Tensor类型的输出,代表input被更新后的结果。
约束说明
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- indices至少是2维,其最后1维的大小不能超过input的维度大小。
- 假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上input除前a维外的shape。举例:input的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。
支持的型号
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
- Atlas 训练系列产品
- Atlas 推理系列产品
- Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例
- 单算子模式调用
1
2
3
4
5
6
7
8
9
10
11
12 | import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()
data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()
data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()
out = torch_npu.npu_scatter_nd_update(var, indices, updates)
|
- 图模式调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 | import os
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch.nn as nn
import torch
import numpy as np
import numpy
torch_npu.npu.set_compile_mode(jit_compile=True)
os.environ["ENABLE_ACLNN"] = "false"
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, var, indices, update):
# 调用目标接口
res = torch_npu.npu_scatter_nd_update(var, indices, update)
return res
npu_mode = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
dtype = np.float32
x = [33 ,5]
indices = [33,25,1]
update = [33,25,5]
data_x = np.random.uniform(0, 1, x).astype(dtype)
data_indices = np.random.uniform(0, 10, indices).astype(dtype)
data_update = np.random.uniform(0, 1, update).astype(dtype)
tensor_x = torch.from_numpy(data_x).to(torch.float16)
tensor_indices = torch.from_numpy(data_indices).to(torch.int32)
tensor_update = torch.from_numpy(data_update).to(torch.float16)
# 传参
print(npu_mode(tensor_x.npu(), tensor_indices.npu(), tensor_update.npu()))
|