gather
Description
Gathers slices from an input tensor according to specific indices.
Example:
- Input tensor params: params = [[1,2,3],[4,5,6],[7,8,9]]
- Input tensor indices: indices = [2,0]
- Input axis = 0
When this API is called, the slices at indices 2 and 0 along axis 0 are gathered, in that order. The result is as follows:
gather_tensor = [[7,8,9], [1,2,3]]
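The selection rule in this example can be modeled in a few lines of plain Python. This is a minimal sketch for illustration only, not the TBE DSL API: gather_axis0 is a hypothetical helper, and nested lists stand in for tensors. It shows that each index picks one row of params and that the output follows the order of indices, not the order of the rows.

```python
# Pure-Python model of gather along axis 0 (illustration only; not the TBE DSL API).
def gather_axis0(params, indices):
    """Return the rows of params selected by indices, in index order."""
    return [list(params[i]) for i in indices]


params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
indices = [2, 0]
gather_tensor = gather_axis0(params, indices)
print(gather_tensor)  # [[7, 8, 9], [1, 2, 3]]
```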
Prototype
gather(params, indices, axis=None, batch_dims=0, impl_mode="support_out_of_bound_index")
Parameters
- params: a tvm.tensor for the data to be sliced.
The supported data types include float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, and bool.
- indices: a tvm.tensor holding the indices of the slices to gather. Supported data types are int32 and int64.
- axis: an int specifying the axis of params along which data is sliced. The value range is [–p, p – 1], where p is the number of dimensions of params. axis must be greater than or equal to batch_dims.
- batch_dims: (optional) an int specifying the number of batch dimensions. It allows you to gather data slices from each element of a batch separately. For example, batch_dims = 1 means that there is an outer loop over the first axis of params and indices. For details, see the example with batch_dims = 1.
The value range of batch_dims is [–i, i], where i is the number of dimensions of indices.
batch_dims must also satisfy the condition batch_dims <= axis.
- impl_mode: (optional) operator processing mode. The default value is "support_out_of_bound_index"; the parameter can also be set to None. In support_out_of_bound_index mode, if an index value falls outside the range [–params.shape[axis], params.shape[axis]), the corresponding output element is set to 0. In other modes, an out-of-range index value causes an AI Core error.
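The out-of-range behavior described for impl_mode can be modeled on a 1-D list. This is a hedged sketch, not the operator's implementation: gather_1d and its boolean support_out_of_bound_index flag are hypothetical stand-ins for the impl_mode string, and the sketch assumes that an in-range negative index i refers to params[i + n], as in Python indexing. The document only states that such indices are in range.

```python
def gather_1d(params, indices, support_out_of_bound_index=True):
    """Model gather on a 1-D list, mimicking the documented impl_mode behavior.

    Hypothetical helper. Assumption: an in-range negative index i refers to
    params[i + n], as in Python indexing.
    """
    n = len(params)
    out = []
    for i in indices:
        if -n <= i < n:
            out.append(params[i % n])  # i % n maps [-n, n) onto [0, n)
        elif support_out_of_bound_index:
            out.append(0)  # documented: an out-of-range index yields 0
        else:
            # On the device this case reports an AI Core error; model it as an exception.
            raise IndexError(f"index {i} out of range [-{n}, {n})")
    return out


print(gather_1d([10, 20, 30], [0, -1, 3, -4]))  # [10, 30, 0, 0]
```

Index 3 and index -4 both fall outside [-3, 3), so in the default mode they produce 0 instead of an error.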
Returns
gather_tensor: a tvm.tensor for the result tensor.
Restrictions
- If batch_dims is configured (batch_dims ≠ 0), the batch dimensions of params and indices must have the same sizes. For example, batch_dims = 1 indicates that batch processing is performed on the first axis; in this case, dimension 0 of params must have the same size as dimension 0 of indices.
- This API cannot be used in conjunction with other TBE DSL APIs.
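The batch-dimension restriction, together with the axis and batch_dims ranges listed under Parameters, can be checked on shapes before calling the API. The following pure-Python sketch does this and also computes the expected output shape. gather_output_shape is a hypothetical name, and the output-shape rule params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:] is an assumption borrowed from TensorFlow-style gather. It agrees with every example in this section but is not stated by this document.

```python
def gather_output_shape(params_shape, indices_shape, axis, batch_dims=0):
    """Validate gather arguments and return the expected output shape.

    Hypothetical helper. Assumptions (TensorFlow-style gather, consistent with
    the examples in this section): negative axis/batch_dims count from the end,
    and the result shape is
    params_shape[:axis] + indices_shape[batch_dims:] + params_shape[axis + 1:].
    """
    p, i = len(params_shape), len(indices_shape)
    if not -p <= axis <= p - 1:
        raise ValueError(f"axis {axis} outside [-{p}, {p - 1}]")
    if not -i <= batch_dims <= i:
        raise ValueError(f"batch_dims {batch_dims} outside [-{i}, {i}]")
    axis %= p  # normalize a negative axis
    if batch_dims < 0:
        batch_dims += i  # normalize a negative batch_dims
    if batch_dims > axis:
        raise ValueError("batch_dims must be <= axis")
    # Restriction: the leading batch_dims dimensions must have the same sizes.
    if tuple(params_shape[:batch_dims]) != tuple(indices_shape[:batch_dims]):
        raise ValueError("batch dimensions of params and indices differ")
    return (tuple(params_shape[:axis])
            + tuple(indices_shape[batch_dims:])
            + tuple(params_shape[axis + 1:]))


print(gather_output_shape((3, 3), (2,), axis=1))                  # (3, 2)
print(gather_output_shape((3, 3), (3, 1), axis=1, batch_dims=1))  # (3, 1)
```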
Availability
Example
- Example: batch_dims = 0
Input tensor params:
params = [[1,2,3], [4,5,6], [7,8,9]]
Input tensor indices:
indices = [2,0]
Inputs: axis = 1, batch_dims = 0
Call the gather API to gather data slices from params along axis according to the values of the indices tensor. A code example is as follows:
from tbe import tvm
from tbe import dsl
# params holds the data; indices must be int32 or int64
params = tvm.placeholder((3, 3), dtype="float16", name="params")
indices = tvm.placeholder((2,), dtype="int32", name="indices")
# Gather along axis 1 with batch_dims = 0
gather_tensor = dsl.gather(params, indices, 1, 0)
The preceding code gathers the data slices at indices 2 and 0 along axis 1 of params, that is, elements 2 and 0 of each row. The output is as follows:
gather_tensor = [[3, 1], [6, 4], [9, 7]]
- Example: batch_dims = 1
Input tensor params:
params = [[1,2,3], [4,5,6], [7,8,9]]
Input tensor indices:
indices = [[1], [2], [0]]
Inputs: axis = 1, batch_dims = 1
The code example of the gather call is as follows:
from tbe import tvm
from tbe import dsl
# params holds the data; indices must be int32 or int64
params = tvm.placeholder((3, 3), dtype="float16", name="params")
indices = tvm.placeholder((3, 1), dtype="int32", name="indices")
# Gather along axis 1 with batch_dims = 1
gather_tensor = dsl.gather(params, indices, 1, 1)
batch_dims = 1 indicates that batch processing is performed on the first axis (axis 0). For each coordinate i on the batch axis, params[i] is paired with indices[i], and data slices are gathered from params[i] according to the values in indices[i]. For example, indices[0] = [1] selects element 1 of params[0], which is 2.
The output is as follows:
gather_tensor = [[2], [6], [7]]
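Both examples can be reproduced with a small pure-Python reference model, which may be useful for generating expected results when testing the operator. This is a sketch under stated assumptions, not the TBE implementation: gather_ref is a hypothetical name, nested lists stand in for tensors, axis and batch_dims are assumed to be non-negative already, and out-of-range indices are not handled.

```python
def gather_ref(params, indices, axis, batch_dims=0):
    """Reference model of gather on nested lists (hypothetical helper).

    Assumes axis and batch_dims are already non-negative and batch_dims <= axis.
    """
    if batch_dims > 0:
        # Batch axis: pair params[i] with indices[i] (outer loop over the batch).
        if len(params) != len(indices):
            raise ValueError("batch dimension sizes of params and indices differ")
        return [gather_ref(p, ix, axis - 1, batch_dims - 1)
                for p, ix in zip(params, indices)]
    if axis > 0:
        # Non-batch axis in front of `axis`: reuse the same indices for every sub-tensor.
        return [gather_ref(p, indices, axis - 1) for p in params]
    if isinstance(indices, (list, tuple)):
        # Walk the index tensor; each scalar index selects one slice of params.
        return [gather_ref(params, ix, 0) for ix in indices]
    return params[indices]


params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(gather_ref(params, [2, 0], axis=1))                         # [[3, 1], [6, 4], [9, 7]]
print(gather_ref(params, [[1], [2], [0]], axis=1, batch_dims=1))  # [[2], [6], [7]]
```

The recursion mirrors the description above: batch axes are consumed in lockstep from params and indices, while the axes of params before axis are looped over with the same indices reused each time.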