torch_npu.scatter_update
产品支持情况
| 产品 | 是否支持 |
|---|---|
| [object Object]Atlas A3 训练系列产品[object Object] | √ |
| [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object] | √ |
| [object Object]Atlas 训练系列产品[object Object] | √ |
功能说明
将tensor updates中的值按指定的轴axis和索引indices更新tensor data中的值,并将结果保存到输出tensor,data本身的数据不变。
函数原型
[object Object]
参数说明
data (
Tensor):必选参数。代表更新前的原数据,data只支持2-8维,且维度大小需要与updates一致;支持非连续的tensor;数据格式支持;不支持空Tensor。- [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas A3 训练系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas 训练系列产品[object Object]:数据类型支持
int8、float16、float32、int32。
- [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持
indices (
Tensor):必选参数。代表索引,数据类型支持int32、int64;目前仅支持一维和二维;支持非连续的tensor;数据格式支持;不支持空Tensor。updates (
Tensor):必选参数。代表更新的数据,updates的维度大小需要与data一致;支持非连续的tensor;数据格式支持;不支持空Tensor。- [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas A3 训练系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas 训练系列产品[object Object]:数据类型支持
int8、float16、float32、int32。
- [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持
axis (
int):必选参数。代表轴,用来scatter的维度,数据类型为int64。
返回值说明
Tensor
计算输出,只支持2-8维,且维度大小需要与data一致;支持非连续的tensor;数据格式支持;不支持空Tensor。
- [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas A3 训练系列产品[object Object]:数据类型支持
int8、float16、float32、bfloat16、int32。 - [object Object]Atlas 训练系列产品[object Object]:数据类型支持
int8、float16、float32、int32。
约束说明
data与updates的秩一致。- 不支持索引越界,索引越界不校验。
调用示例
[object Object]