split
Description
Splits the input tensor data along dimension split_dim into sub-tensors. The size of each sub-tensor along split_dim is given by the corresponding entry of size_splits.
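To make the splitting semantics concrete, here is a minimal pure-Python sketch on nested lists. It does not use TBE and is not how dsl.split is implemented. The function name reference_split is hypothetical. The sketch also assumes that the entries of size_splits must add up to the length of data along split_dim, and that split_dim is non-negative.

```python
def reference_split(data, split_dim, size_splits):
    """Hypothetical reference: split nested lists along split_dim into pieces of size_splits sizes."""
    if split_dim == 0:
        # Assumption: the sizes must cover the whole axis exactly.
        if sum(size_splits) != len(data):
            raise ValueError("size_splits must sum to the length of split_dim")
        pieces, start = [], 0
        for size in size_splits:
            pieces.append(data[start:start + size])
            start += size
        return pieces
    # Split every sub-list one level down, then regroup so that
    # piece i collects the i-th part of every row.
    per_row = [reference_split(row, split_dim - 1, size_splits) for row in data]
    return [[row_parts[i] for row_parts in per_row] for i in range(len(size_splits))]


# Split a 2 x 4 matrix along dimension 1 into widths 1 and 3.
print(reference_split([[1, 2, 3, 4], [5, 6, 7, 8]], 1, [1, 3]))
# [[[1], [5]], [[2, 3, 4], [6, 7, 8]]]
```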
Prototype
split(data, split_dim, size_splits)
Parameters
- data: a tvm.tensor for the input tensor.
  Atlas Training Series Product: supports float16 and float32.
- split_dim: an int specifying the axis along which to split.
- size_splits: a list of the sizes of the sub-tensors along split_dim.
Returns
- output_shape_list: a list of the shapes of the result tensors.
- output_tensor_list: a list of the result tvm.tensor objects, whose data type is the same as that of the input tensor.
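Each result shape equals the input shape with the split_dim entry replaced by the corresponding size from size_splits. The sketch below computes the shapes under that reading, without TBE. The function name split_output_shapes is hypothetical, and the check that the sizes sum to the length of split_dim is an assumption of this sketch. The exact container types TBE uses for output_shape_list are not shown here.

```python
def split_output_shapes(shape, split_dim, size_splits):
    """Hypothetical helper: shapes of the sub-tensors produced by splitting `shape`."""
    if not 0 <= split_dim < len(shape):
        raise ValueError("split_dim is out of range for this shape")
    # Assumption: the sizes must cover the whole axis exactly.
    if sum(size_splits) != shape[split_dim]:
        raise ValueError("size_splits must sum to shape[split_dim]")
    shapes = []
    for size in size_splits:
        out = list(shape)
        out[split_dim] = size  # only the split axis changes
        shapes.append(out)
    return shapes


print(split_output_shapes((1024, 1024, 256), 1, [512, 512]))
# [[1024, 512, 256], [1024, 512, 256]]
```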
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
Availability
Example
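from tbe import tvm
from tbe import dsl

shape = (1024, 1024, 256)
input_dtype = "float16"
# Placeholder for the input tensor
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
# Split dimension 1 (length 1024) into two sub-tensors of length 512 each
res = dsl.split(data, 1, [512, 512])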