split

Description

Splits the input tensor data along the split_dim dimension into sub-tensors. The size of each sub-tensor along that dimension is given by the corresponding entry of size_splits.
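To make these semantics concrete, here is a minimal pure-Python sketch that applies the same operation to nested lists. The helper reference_split is hypothetical and is not part of the TBE API. It assumes a non-negative split_dim and assumes that the entries of size_splits add up to the length of that axis, as they do in the example at the end of this page ([512, 512] for an axis of length 1024).

```python
# Hypothetical reference for the split semantics on nested Python lists.
# Not part of the TBE API; assumes split_dim >= 0 and that size_splits
# sums to the length of the split axis.
def reference_split(data, split_dim, size_splits):
    if split_dim == 0:
        # Assumption (not a documented TBE restriction): sizes cover the axis exactly.
        assert sum(size_splits) == len(data), "size_splits must cover the axis"
        pieces, start = [], 0
        for size in size_splits:
            pieces.append(data[start:start + size])  # consecutive slice of `size` items
            start += size
        return pieces
    # Split every sub-list one level deeper, then regroup the pieces so that
    # output i collects the i-th piece from every row.
    per_row = [reference_split(row, split_dim - 1, size_splits) for row in data]
    return [[row_pieces[i] for row_pieces in per_row] for i in range(len(size_splits))]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
print(reference_split(x, 1, [1, 3]))
# [[[1], [5]], [[2, 3, 4], [6, 7, 8]]]
```

Splitting axis 1 of a 2 x 4 input with sizes [1, 3] yields a 2 x 1 tensor and a 2 x 3 tensor.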

Prototype

split(data, split_dim, size_splits)

Parameters

  • data: a tvm.tensor for the input tensor.

    Atlas Training Series Product: supports float16 and float32.

  • split_dim: an int specifying the axis along which data is split.
  • size_splits: a list of ints giving the size of each output tensor along split_dim.

Returns

  • output_shape_list: a list of the shapes of the result tensors.
  • output_tensor_list: a list of tvm.tensor objects holding the result tensors. Their data type is the same as that of the input tensor.
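The shapes in output_shape_list follow from the input shape: each result keeps every dimension of the input except split_dim, which takes the corresponding size from size_splits. The sketch below computes this with a hypothetical helper, split_output_shapes, which is not part of the TBE API. Its range check and its requirement that the sizes add up to shape[split_dim] are assumptions for illustration, and the actual container types returned by dsl.split may differ.

```python
# Hypothetical helper (not part of the TBE API): compute the shapes expected in
# output_shape_list for a given input shape, split_dim, and size_splits.
def split_output_shapes(shape, split_dim, size_splits):
    # Assumptions: split_dim is a valid non-negative axis and the sizes
    # cover the whole axis.
    if not 0 <= split_dim < len(shape):
        raise ValueError("split_dim out of range")
    if sum(size_splits) != shape[split_dim]:
        raise ValueError("size_splits must add up to shape[split_dim]")
    shapes = []
    for size in size_splits:
        out = list(shape)      # copy every input dimension
        out[split_dim] = size  # replace only the split axis
        shapes.append(out)
    return shapes


print(split_output_shapes((1024, 1024, 256), 1, [512, 512]))
# [[1024, 512, 256], [1024, 512, 256]]
```

With the input used in the example on this page, both result tensors have shape (1024, 512, 256).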

Restrictions

This API cannot be used in conjunction with other TBE DSL APIs.

Availability

Atlas Training Series Product

Example

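from tbe import tvm
from tbe import dsl

# Input: a float16 placeholder of shape (1024, 1024, 256).
shape = (1024, 1024, 256)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
# Split axis 1 (length 1024) into two sub-tensors of length 512 each.
res = dsl.split(data, 1, [512, 512])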