文档
注册
评分
提单
论坛
小AI

gather_nd

功能说明

gather接口只能获取某一维度上的索引数据,而gather_nd接口可以获取多个维度上的索引数据,且针对同一维度的每组数据,可分别取不同的索引值。

详细功能介绍可参见调用示例

函数原型

gather_nd(params, indices, batch_dims=0)

参数说明

  • params:待切片的数据。输入Tensor,tvm.tensor类型。

    支持的数据类型:float16, float32, int8, uint8, int32。

  • indices:待取出数据的位置索引。输入Tensor,tvm.tensor类型。支持的数据类型:int32, int64。

    需要注意:indices的最后一个轴的大小不能大于params的维度数。

  • batch_dims:可选参数,代表批处理的维度数,含义同gather接口的batch_dims,取值范围:[-i, i -1],其中i是indices的维数,int类型。

返回值

wrapped_tensor:执行gather_nd之后的Tensor,tvm.tensor类型

约束说明

此接口暂不支持与其他TBE DSL计算接口混合使用。

支持的型号

Atlas 训练系列产品

Atlas 推理系列产品(Ascend 310P处理器)

Atlas 200/500 A2推理产品

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

输入Tensor params为:

params = [[1,2,3],
          [4,5,6],
          [7,8,9]]

输入Tensor indices为:

indices = [[1, 2],
           [2, 0],
           [0, 1]]

batch_dims=0

调用gather_nd接口,根据indices的坐标值,获取params数据切片, 代码示例如下所示:
from tbe import tvm
from tbe import dsl
params = tvm.placeholder((3, 3), dtype=dtype, name="params")
indices = tvm.placeholder((3, 2), dtype=dtype, name="indices")
set_valued_tensor = dsl.gather_nd(params, indices, 0)

输出结果如下所示:

gather_nd_tensor = [6, 7, 2]
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词