(beta)torch_npu.npu_scatter
接口原型
torch_npu.npu_scatter(self, indices, updates, dim) -> Tensor
功能描述
使用dim对scatter结果进行计数。类似于torch.scatter,优化NPU设备实现。
参数说明
- self (Tensor) - 输入张量。
- indices (Tensor) - 待scatter的元素index,可以为空,也可以与src有相同的维数。当为空时,操作返回“self unchanged”。
- updates (Tensor) - 待scatter的源元素。
- dim (Int) - 要进行index的轴。
调用示例
>>> input = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu()
>>> input
tensor([[1.6279, 0.1226],
[0.9041, 1.0980]], device='npu:0')
>>> indices = torch.tensor([0, 1],dtype=torch.int32).npu()
>>> indices
tensor([0, 1], device='npu:0', dtype=torch.int32)
>>> updates = torch.tensor([-1.1993, -1.5247]).npu()
>>> updates
tensor([-1.1993, -1.5247], device='npu:0')
>>> dim = 0
>>> output = torch_npu.npu_scatter(input, indices, updates, dim)
>>> output
tensor([[-1.1993, 0.1226],
[ 0.9041, -1.5247]], device='npu:0')
父主题: torch_npu