gather_nd
Description
Gathers data slices from multiple dimensions based on indices. The gather API obtains data slices along a single dimension only. gather_nd instead treats each entry along the last axis of indices as a coordinate that supplies one index value per leading dimension of params.
For details, see Example.
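The following NumPy sketch contrasts gather and gather_nd (with batch_dims=0) on the host. The helper gather_nd_reference is hypothetical and is not part of TBE. It only mirrors the semantics described above, and using np.take as a stand-in for gather is also an illustrative assumption.

```python
import numpy as np

def gather_nd_reference(params, indices):
    """Hypothetical NumPy model of gather_nd with batch_dims=0 (not a TBE API)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    depth = indices.shape[-1]
    # The last axis of indices cannot exceed the number of dimensions of params.
    if depth > params.ndim:
        raise ValueError("indices.shape[-1] must not exceed params.ndim")
    # Turn the coordinate columns into a tuple of index arrays. NumPy advanced
    # indexing then picks one element (depth == ndim) or one slice (depth < ndim)
    # per coordinate.
    coords = tuple(np.moveaxis(indices, -1, 0))
    return params[coords]

params = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])

# gather-style selection along a single axis: whole rows 2 and 0.
rows = np.take(params, [2, 0], axis=0)                 # [[7, 8, 9], [1, 2, 3]]
# gather_nd with full coordinates: one element per (row, column) pair.
elems = gather_nd_reference(params, [[2, 0], [0, 1]])  # [7, 2]
# gather_nd with partial coordinates: one row slice per index.
slices = gather_nd_reference(params, [[2], [0]])       # [[7, 8, 9], [1, 2, 3]]
```

When the last axis of indices is shorter than the rank of params, each coordinate selects a whole slice instead of a single element. This is why gather_nd can reproduce a row-wise gather as a special case.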
Prototype
gather_nd(params, indices, batch_dims=0)
Parameters
- params: a tvm.tensor for the data to be sliced. The supported data types are float16, float32, int8, uint8, and int32.
- indices: a tvm.tensor of index coordinates specifying which data to gather. The supported data types are int32 and int64. The size of the last axis of indices cannot exceed the number of dimensions of params.
- batch_dims: (optional) an int specifying the number of batch dimensions. It has the same meaning as batch_dims of the gather API. The value range is [-i, i - 1], where i is the number of dimensions of indices.
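The parameter constraints above also determine the output shape. The NumPy sketch below models how params, indices, and batch_dims combine. It assumes that batch_dims follows the TensorFlow tf.gather_nd convention: the leading batch_dims axes of params and indices are paired batch axes that must match, and a negative value counts from the end of indices. This document only states that batch_dims has the same meaning as in the gather API, so treat this convention as an assumption. gather_nd_batched is a hypothetical helper, not a TBE API.

```python
import math
import numpy as np

def gather_nd_batched(params, indices, batch_dims=0):
    """Hypothetical NumPy model of gather_nd with batch_dims (not a TBE API).

    Assumes TensorFlow-style batch_dims semantics.
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    if batch_dims < 0:
        batch_dims += indices.ndim  # assumption: negative values count from the end
    if not 0 <= batch_dims < indices.ndim:
        raise ValueError("batch_dims out of range")
    batch_shape = indices.shape[:batch_dims]
    if params.shape[:batch_dims] != batch_shape:
        raise ValueError("leading batch dimensions of params and indices must match")
    depth = indices.shape[-1]
    if depth > params.ndim - batch_dims:
        raise ValueError("indices.shape[-1] is too large for params")

    # Flatten the batch axes so each batch entry can be gathered independently.
    n = math.prod(batch_shape)
    flat_params = params.reshape((n,) + params.shape[batch_dims:])
    flat_indices = indices.reshape((n,) + indices.shape[batch_dims:])
    per_batch = [
        flat_params[b][tuple(np.moveaxis(flat_indices[b], -1, 0))]
        for b in range(n)
    ]
    out = np.stack(per_batch)
    # Resulting shape: indices.shape[:-1] + params.shape[batch_dims + depth:]
    return out.reshape(batch_shape + out.shape[1:])

print(gather_nd_batched([[1, 2], [3, 4]], [[1], [0]], batch_dims=1))  # [2 3]
```

With batch_dims=1, batch entry 0 of indices addresses only row 0 of params, and batch entry 1 addresses only row 1. Under this convention the output shape is indices.shape[:-1] followed by the params axes that are not consumed by the batch axes or the coordinates.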
Returns
wrapped_tensor: a tvm.tensor for the result tensor.
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
Availability
Example
Input tensor params:
params = [[1,2,3],
[4,5,6],
[7,8,9]]
Input tensor indices:
indices = [[1, 2],
[2, 0],
[0, 1]]
batch_dims=0
Call the gather_nd API to gather data slices from params based on the coordinates of indices. A code example is as follows:
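from tbe import tvm
from tbe import dsl

# params accepts float16, float32, int8, uint8, or int32; indices must be int32 or int64.
params = tvm.placeholder((3, 3), dtype="int32", name="params")
indices = tvm.placeholder((3, 2), dtype="int32", name="indices")
# batch_dims=0: each row of indices is a full (row, column) coordinate into params.
gather_nd_tensor = dsl.gather_nd(params, indices, 0)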
The output is as follows:
gather_nd_tensor = [6, 7, 2]
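As a host-side cross-check, the same result can be reproduced with NumPy advanced indexing. This is only a verification sketch that assumes NumPy is available. It does not call the TBE API.

```python
import numpy as np

params = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
indices = np.array([[1, 2],
                    [2, 0],
                    [0, 1]])

# Each row of indices is a full (row, column) coordinate into params, so
# indexing with the two coordinate columns selects params[1, 2], params[2, 0],
# and params[0, 1].
gather_nd_tensor = params[indices[:, 0], indices[:, 1]]
print(gather_nd_tensor)  # [6 7 2]
```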