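Updates the values of input at the positions given by indices with the values in updates, and stores the result in the output tensor. The data of input itself is left unchanged.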
torch_npu.npu_scatter_nd_update(Tensor input, Tensor indices, Tensor updates) -> Tensor
Returns a Tensor holding the result of applying the updates to input.
import torch
import torch_npu
import numpy as np

# Tensor to be updated: 24 rows of 128 float16 values.
data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()

# 12 row indices in [0, 12), shape [12, 1]; random draws may repeat.
data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

# One 128-element update row per index.
data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()

# Out-of-place update: var keeps its original data.
out = torch_npu.npu_scatter_nd_update(var, indices, updates)
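To make the indexing behaviour described above easier to check without an NPU, the following is a minimal NumPy sketch of the assumed semantics. It is not the torch_npu implementation. The helper name scatter_nd_update_reference is hypothetical, and the sketch assumes the common scatter-nd convention in which the last dimension of indices addresses the leading dimensions of input, while the remaining dimensions of updates are written as whole slices.

```python
import numpy as np


def scatter_nd_update_reference(input_arr, indices, updates):
    """Hypothetical CPU reference for an out-of-place scatter-nd update.

    Assumption: indices has shape [..., K], where K indexes the first K
    dimensions of input_arr, and updates supplies one slice per index.
    """
    out = input_arr.copy()  # the input array itself is never modified
    depth = indices.shape[-1]
    flat_idx = indices.reshape(-1, depth).astype(np.int64)
    flat_upd = updates.reshape((-1,) + input_arr.shape[depth:])
    # One index array per addressed dimension, as NumPy advanced indexing expects.
    out[tuple(flat_idx.T)] = flat_upd
    return out


var = np.zeros((4, 3), dtype=np.float32)
indices = np.array([[2], [0]], dtype=np.int32)
updates = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32)

out = scatter_nd_update_reference(var, indices, updates)
print(out)
```

In this example row 2 receives the first update row and row 0 the second, while rows 1 and 3 and the original var stay at zero. The indices here are unique. How the NPU operator resolves repeated indices is not stated in this entry, so the sketch should not be read as defining that case.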