gather_nd
函数原型
gather_nd(params, indices, batch_dims=0)
参数说明
- params:待切片的数据。输入Tensor,tvm.tensor类型。
- indices:待取出数据的位置索引。输入Tensor,tvm.tensor类型。支持的数据类型:int32, int64。
- batch_dims:可选参数,代表批处理的维度数,含义同gather接口的batch_dims,取值范围:[-i, i -1],其中i是indices的维数,int类型。
返回值
wrapped_tensor:执行gather_nd之后的Tensor,tvm.tensor类型
约束说明
此接口暂不支持与其他TBE DSL计算接口混合使用。
支持的型号
Atlas 训练系列产品
Atlas 推理系列产品(Ascend 310P处理器)
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
输入Tensor params为:
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [[1, 2], [2, 0], [0, 1]]
batch_dims=0
调用gather_nd接口,根据indices的坐标值,获取params数据切片, 代码示例如下所示:
from tbe import tvm from tbe import dsl params = tvm.placeholder((3, 3), dtype=dtype, name="params") indices = tvm.placeholder((3, 2), dtype=dtype, name="indices") set_valued_tensor = dsl.gather_nd(params, indices, 0)
输出结果如下所示:
gather_nd_tensor = [6, 7, 2]
父主题: Tensor操作接口