gather_nd(params, indices, batch_dims=0)
wrapped_tensor:执行gather_nd之后的Tensor,tvm.tensor类型
此接口暂不支持与其他TBE DSL计算接口混合使用。
Atlas 训练系列产品
Atlas 推理系列产品
Atlas 200I/500 A2推理产品
Atlas A2训练系列产品
输入Tensor params为:
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [[1, 2], [2, 0], [0, 1]]
batch_dims=0
from tbe import tvm from tbe import dsl params = tvm.placeholder((3, 3), dtype=dtype, name="params") indices = tvm.placeholder((3, 2), dtype=dtype, name="indices") set_valued_tensor = dsl.gather_nd(params, indices, 0)
输出结果如下所示:
gather_nd_tensor = [6, 7, 2]