gather_nd

Description

Gathers data slices on multiple dimensions based on indices. The gather API obtains data slices along a single dimension only. The gather_nd API instead reads each innermost row of indices as one index value per dimension of params, and gathers the element or sub-slice that those values address.

For details, see the Example section.
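The selection rule described above can be sketched in plain Python for batch_dims=0. This is a hypothetical reference helper, gather_nd_ref, that works on nested lists rather than tvm.tensor objects. It only illustrates which elements are selected and does not show how TBE computes the result.

```python
def gather_nd_ref(params, indices):
    """Hypothetical reference gather_nd (batch_dims=0) on nested Python lists.

    Each innermost row of `indices` is a coordinate prefix into `params`;
    the element or sub-slice it addresses is placed in the output.
    """
    # A list whose first item is an int is a single coordinate row.
    if indices and isinstance(indices[0], int):
        slice_ = params
        for coord in indices:  # consume one params dimension per index value
            slice_ = slice_[coord]
        return slice_
    # Otherwise keep this leading axis of indices and recurse into it.
    return [gather_nd_ref(params, sub) for sub in indices]
```

With a full-depth coordinate such as [1, 2], the helper returns a single element. With a shorter coordinate such as [1], it returns the whole row params[1].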

Prototype

gather_nd(params, indices, batch_dims=0)

Parameters

  • params: a tvm.tensor for the data to be sliced.

    The supported data types include float16, float32, int8, uint8, and int32.

  • indices: a tvm.tensor containing the indices that specify which data to gather. The supported data types are int32 and int64.

    Note that the size of the last axis of indices cannot be greater than the number of dimensions of params.
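The constraint on the last axis of indices also determines the shape of the result. The sketch below assumes that, for batch_dims=0, TBE follows the same shape rule as TensorFlow's tf.gather_nd: the output keeps the leading axes of indices and appends the params axes that a coordinate does not consume. The helper name gather_nd_output_shape is hypothetical.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Hypothetical helper: result shape of gather_nd with batch_dims=0.

    Assumes the TensorFlow-style rule; not taken from the TBE implementation.
    """
    index_depth = indices_shape[-1]  # number of index values per coordinate row
    if index_depth > len(params_shape):
        raise ValueError(
            f"last axis of indices ({index_depth}) exceeds the rank of params ({len(params_shape)})"
        )
    # Leading axes of indices are kept; each row consumes index_depth params axes.
    return tuple(indices_shape[:-1]) + tuple(params_shape[index_depth:])
```

For the shapes used in the Example section, params (3, 3) and indices (3, 2), this gives (3,). Indices of shape (2, 1) would instead select two full rows, giving (2, 3).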

  • batch_dims: (optional) an int specifying the number of batch dimensions. It has the same meaning as batch_dims of the gather API. The default value is 0. The value range is [-i, i - 1], where i is the number of dimensions of indices.
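One possible reading of batch_dims is sketched below in plain Python. The first batch_dims axes of params and indices are paired entry by entry, and an ordinary gather_nd runs inside each pair. The helpers normalize_batch_dims and gather_nd_batched are hypothetical. Wrapping a negative value by adding the rank of indices is also an assumption, because the description above states only the permitted range.

```python
def normalize_batch_dims(batch_dims, indices_rank):
    """Hypothetical helper: check batch_dims against [-i, i - 1] and make it non-negative.

    Assumption: a negative value counts back from indices_rank, like a Python index.
    """
    if not -indices_rank <= batch_dims <= indices_rank - 1:
        raise ValueError(
            f"batch_dims={batch_dims} is outside [{-indices_rank}, {indices_rank - 1}]"
        )
    return batch_dims + indices_rank if batch_dims < 0 else batch_dims


def gather_nd_batched(params, indices, batch_dims):
    """Hypothetical reference gather_nd on nested lists with a non-negative batch_dims."""
    if batch_dims > 0:
        # Pair batch entry b of params with batch entry b of indices, then recurse.
        return [gather_nd_batched(p, i, batch_dims - 1) for p, i in zip(params, indices)]
    if indices and isinstance(indices[0], int):
        # Innermost row: one index value per leading dimension of params.
        out = params
        for coord in indices:
            out = out[coord]
        return out
    # Remaining leading axes of indices are preserved in the output.
    return [gather_nd_batched(params, sub, 0) for sub in indices]
```

For example, with params = [[1, 2], [3, 4]], indices = [[1], [0]], and batch_dims = 1, row 0 of indices selects params[0][1] and row 1 selects params[1][0]. The result is [2, 3].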

Returns

wrapped_tensor: a tvm.tensor for the result tensor.

Restrictions

This API cannot be used in conjunction with other TBE DSL APIs.

Availability

Atlas Training Series Product

Example

Input tensor params:

params = [[1,2,3],
          [4,5,6],
          [7,8,9]]

Input tensor indices:

indices = [[1, 2],
           [2, 0],
           [0, 1]]

batch_dims=0

Call the gather_nd API to gather data slices from params based on the coordinates in indices. A code example is as follows:

from tbe import tvm
from tbe import dsl

# Placeholders matching the example shapes; int32 is a supported dtype for both inputs
params = tvm.placeholder((3, 3), dtype="int32", name="params")
indices = tvm.placeholder((3, 2), dtype="int32", name="indices")
# Each row k of indices selects params[indices[k][0], indices[k][1]]
gather_nd_tensor = dsl.gather_nd(params, indices, batch_dims=0)

The output is as follows:

gather_nd_tensor = [6, 7, 2]
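The output above can be checked by hand. The last axis of indices has size 2, which equals the number of dimensions of params, so each row of indices is a complete (row, column) coordinate and gather_nd performs one element lookup per row. The following plain-Python sketch, which has no TBE dependency, reproduces the result:

```python
params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
indices = [[1, 2], [2, 0], [0, 1]]

# Each index row is a full (row, column) coordinate into params.
gather_nd_tensor = [params[row][col] for row, col in indices]
print(gather_nd_tensor)  # [6, 7, 2]
```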