gather

Description

Gathers slices from an input tensor along a specified axis according to the given indices.

Example:

  • Input tensor params: params = [[1,2,3],[4,5,6],[7,8,9]]
  • Input tensor indices: indices = [2,0]
  • Input axis = 0

Calling this API with these inputs gathers the slices at indices 2 and 0, in that order, along axis 0. The result is as follows:

gather_tensor = [[7,8,9], [1,2,3]]
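The semantics of the example above can be modeled on the host with plain Python lists. The following is a minimal sketch, assuming batch_dims = 0 and a non-negative axis. gather_ref is a hypothetical helper for illustration only and is not part of the tbe.dsl API.

```python
def gather_ref(params, indices, axis=0):
    """Host-side model of gather with batch_dims = 0 on nested lists.

    Hypothetical helper: it mirrors the documented result, not the
    tbe.dsl implementation. axis is assumed to be non-negative.
    """
    if axis > 0:
        # Not at the gather axis yet: apply the gather to every sub-list.
        return [gather_ref(row, indices, axis - 1) for row in params]
    if isinstance(indices, list):
        # indices may itself be nested; its shape replaces the gathered axis.
        return [gather_ref(params, i, 0) for i in indices]
    # A scalar index selects one slice along the gather axis.
    return params[indices]


params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(gather_ref(params, [2, 0], axis=0))  # [[7, 8, 9], [1, 2, 3]]
```

Passing axis=1 instead selects elements 2 and 0 from each row.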

Prototype

gather(params, indices, axis=None, batch_dims=0, impl_mode="support_out_of_bound_index")

Parameters

  • params: a tvm.tensor for the data to be sliced.

    The supported data types include float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, and bool.

  • indices: a tvm.tensor holding the indices of the slices to gather from params. Supported data types are int32 and int64.
  • axis: an int specifying the axis of params along which slices are gathered. The value range is [-p, p - 1], where p is the number of dimensions of params. The value must be greater than or equal to batch_dims.
  • batch_dims: (optional) number of leading batch dimensions. This parameter lets you gather slices separately for each element of a batch. For example, batch_dims = 1 means there is an outer loop over the first axis (axis 0) of params and indices. For details, see the example with batch_dims = 1.

    The value range of batch_dims is [-i, i], where i is the number of dimensions of indices.

    It is an int and must meet the following condition: batch_dims <= axis.

  • impl_mode: (optional) operator processing mode. The default value is "support_out_of_bound_index"; the parameter can also be set to None. In "support_out_of_bound_index" mode, any index outside the range [-params.shape[axis], params.shape[axis]) sets the corresponding output to 0. In other modes, an out-of-range index causes an AI Core error.
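The impl_mode behavior can be sketched on a 1-D list as follows. This is a host-side model, not the device implementation: gather_1d is a hypothetical helper, and it assumes that in-range negative indices count from the end of the axis, as the valid range [-n, n) suggests. Raising IndexError stands in for the AI Core error reported on the device.

```python
def gather_1d(values, indices, impl_mode="support_out_of_bound_index"):
    """Model gather along a single axis, including impl_mode handling.

    Hypothetical helper. Assumption: an in-range negative index counts
    from the end of the axis, as Python list indexing does.
    """
    n = len(values)
    out = []
    for idx in indices:
        if -n <= idx < n:
            # Valid index: Python list indexing already wraps negatives.
            out.append(values[idx])
        elif impl_mode == "support_out_of_bound_index":
            # Out-of-range index: the output element is set to 0.
            out.append(0)
        else:
            # Stand-in for the AI Core error raised on the device.
            raise IndexError(f"index {idx} outside [{-n}, {n})")
    return out


print(gather_1d([10, 20, 30], [2, -1, 5]))  # [30, 30, 0]
```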

Returns

gather_tensor: a tvm.tensor for the result tensor.
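The shape of gather_tensor is not stated explicitly here. Assuming the TensorFlow-style rule for gather, which is consistent with the examples on this page, the output shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]. A minimal sketch of that computation, using a hypothetical helper name:

```python
def gather_output_shape(params_shape, indices_shape, axis, batch_dims=0):
    """Compute the gather output shape (assumed TensorFlow-style rule).

    Hypothetical helper; a negative axis or batch_dims is normalized
    against the rank of params or indices, respectively.
    """
    if axis < 0:
        axis += len(params_shape)
    if batch_dims < 0:
        batch_dims += len(indices_shape)
    return (tuple(params_shape[:axis])
            + tuple(indices_shape[batch_dims:])
            + tuple(params_shape[axis + 1:]))


print(gather_output_shape((3, 3), (2,), axis=1))                  # (3, 2)
print(gather_output_shape((3, 3), (3, 1), axis=1, batch_dims=1))  # (3, 1)
```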

Restrictions

  • If batch_dims is nonzero, the batch axes of params and indices must have the same sizes. For example, batch_dims = 1 means batch processing is performed on the first axis (axis 0), so axis 0 of params must have the same size as axis 0 of indices.
  • This API cannot be used in conjunction with other TBE DSL APIs.
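The batch-axis restriction amounts to a shape check on the leading batch_dims dimensions. A minimal host-side sketch, with a hypothetical helper name, assuming a negative batch_dims counts from the rank of indices:

```python
def check_batch_dims(params_shape, indices_shape, batch_dims):
    """Raise ValueError if the batch axes of params and indices differ.

    Hypothetical helper; a negative batch_dims is assumed to count from
    the rank of indices.
    """
    if batch_dims < 0:
        batch_dims += len(indices_shape)
    if batch_dims == 0:
        return  # No batch axes, nothing to compare.
    params_batch = tuple(params_shape[:batch_dims])
    indices_batch = tuple(indices_shape[:batch_dims])
    if params_batch != indices_batch:
        raise ValueError(
            f"batch axes differ: params {params_batch} vs indices {indices_batch}")


check_batch_dims((3, 3), (3, 1), batch_dims=1)  # passes: axis 0 has size 3 in both
```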

Availability

Atlas Training Series Product

Example

  • Example: batch_dims = 0

    Input tensor params:

    params = [[1,2,3],
              [4,5,6],
              [7,8,9]]

    Input tensor indices:

    indices = [2,0]

    Inputs: axis = 1, batch_dims = 0

    Call the gather API to gather data slices from params along axis according to the values in the indices tensor. A code example is as follows:
    from tbe import tvm
    from tbe import dsl
    # Example dtypes: params may use any supported type; indices must be int32 or int64.
    params = tvm.placeholder((3, 3), dtype="int32", name="params")
    indices = tvm.placeholder((2,), dtype="int32", name="indices")
    # Gather along axis 1 with batch_dims = 0.
    gather_tensor = dsl.gather(params, indices, 1, 0)

    The preceding code gathers the data slices at indices 2 and 0 along axis 1 of params. The output is as follows:

    gather_tensor = [[3, 1],
                     [6, 4],
                     [9, 7]]
  • Example: batch_dims = 1

    Input tensor params:

    params = [[1,2,3],
              [4,5,6],
              [7,8,9]]

    Input tensor indices:

    indices = [[1],
               [2],
               [0]]

    Inputs: axis = 1, batch_dims = 1

    A code example is as follows:
    from tbe import tvm
    from tbe import dsl
    # Example dtypes: params may use any supported type; indices must be int32 or int64.
    params = tvm.placeholder((3, 3), dtype="int32", name="params")
    indices = tvm.placeholder((3, 1), dtype="int32", name="indices")
    # Gather along axis 1 with one batch dimension (axis 0).
    gather_tensor = dsl.gather(params, indices, 1, 1)

    batch_dims = 1 means batch processing is performed on the first axis (axis 0). For each coordinate i on the batch axis, params[i] is paired with indices[i], and slices are gathered from params[i] along the gather axis according to the values in indices[i].

    The output is as follows:

    gather_tensor = [[2],
                     [6],
                     [7]]
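Both examples can be reproduced on the host with a small reference model. This sketch extends the batch_dims = 0 behavior with an outer loop per batch axis that pairs params[i] with indices[i]. The helper names are hypothetical, and non-negative axis and batch_dims are assumed.

```python
def gather_plain(params, indices, axis):
    """batch_dims = 0 case: gather along axis of nested lists."""
    if axis > 0:
        # Descend into each sub-list until the gather axis is reached.
        return [gather_plain(row, indices, axis - 1) for row in params]
    if isinstance(indices, list):
        return [gather_plain(params, i, 0) for i in indices]
    return params[indices]


def gather_batched(params, indices, axis, batch_dims=0):
    """Peel off one batch axis at a time, then gather within each batch.

    Assumes 0 <= batch_dims <= axis, matching the documented constraint.
    """
    if batch_dims == 0:
        return gather_plain(params, indices, axis)
    # Outer loop over the batch axis: params[i] is paired with indices[i].
    return [gather_batched(p, idx, axis - 1, batch_dims - 1)
            for p, idx in zip(params, indices)]


params = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(gather_batched(params, [2, 0], axis=1))                         # [[3, 1], [6, 4], [9, 7]]
print(gather_batched(params, [[1], [2], [0]], axis=1, batch_dims=1))  # [[2], [6], [7]]
```

Requiring batch_dims <= axis guarantees that axis stays non-negative each time a batch axis is peeled off.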