transpose
Description
Transposes an input tensor by reordering its dimensions according to axes. For details, see Example.
Prototype
transpose(tensor, axes)
Parameters
- tensor: a tvm.tensor for the input tensor.
The supported data types include int8, uint8, float16, int16, uint16, float32, int32, uint32, int64, and uint64.
- axes: a list or tuple of ints specifying the axis permutation applied to the input tensor. The length of axes must equal d, the number of dimensions of the input tensor. Each element must be in the range [0, d - 1], and all elements must be unique.
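The constraints on axes can be checked before the API is called. The following is a minimal sketch in plain Python, not part of TBE DSL: the helper names check_axes and transposed_shape are hypothetical, and the output-shape rule (output dimension i takes the size of input dimension axes[i]) is an assumption inferred from the Example section, where shape (2,3) with axes [1,0] yields (3,2).

```python
def check_axes(axes, ndim):
    """Hypothetical helper: validate axes against the documented constraints."""
    if not isinstance(axes, (list, tuple)):
        raise TypeError("axes must be a list or tuple of ints")
    if any(isinstance(a, bool) or not isinstance(a, int) for a in axes):
        raise TypeError("every element of axes must be an int")
    if len(axes) != ndim:
        raise ValueError("len(axes) must equal the input dimension d")
    # Unique values in [0, d - 1] means axes is a permutation of 0..d-1.
    if sorted(axes) != list(range(ndim)):
        raise ValueError("axes must contain each of 0..d-1 exactly once")


def transposed_shape(shape, axes):
    """Hypothetical helper: expected output shape.

    Assumes output dimension i has the size of input dimension axes[i].
    """
    check_axes(axes, len(shape))
    return tuple(shape[a] for a in axes)


print(transposed_shape((2, 3), [1, 0]))  # (3, 2)
```

Running such a check on the host side can surface an invalid axes value with a clearer message than a failure during operator compilation might give.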
Returns
res_tensor: a tvm.tensor for the transposed tensor.
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
Availability
Example
Input tensor x:
x = [[1,2,3],
[4,5,6]]
Input axes:
axes = [1,0]
Call the transpose API to transpose x according to axes. The following is an example.
from tbe import tvm
from tbe import dsl

shape = (2,3)
input_dtype = "int32"
x = tvm.placeholder(shape, name="x", dtype=input_dtype)
axes = [1,0]
res = dsl.transpose(x, axes)
# res.shape = (3,2)
The preceding code swaps axes 0 and 1 of the input tensor. For the values of x shown above, the result tensor (res) is as follows:
res = [[1,4],
[2,5],
[3,6]]
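The result above can be reproduced without TBE by a small reference routine on nested lists. The following is a sketch only: transpose_nested is a hypothetical helper, and the index rule it uses (output axis k iterates over input axis axes[k]) is an assumption that matches this two-dimensional example; it has not been checked against the TBE implementation for higher dimensions.

```python
def transpose_nested(data, shape, axes):
    """Hypothetical reference transpose for a nested list.

    Assumes output axis k iterates over input axis axes[k].
    """
    out_shape = [shape[a] for a in axes]

    def get(nested, index):
        # Follow a multi-dimensional index into the nested list.
        for i in index:
            nested = nested[i]
        return nested

    def build(prefix):
        depth = len(prefix)
        if depth == len(out_shape):
            # Map the output index back to the corresponding input index.
            src = [0] * len(shape)
            for k, a in enumerate(axes):
                src[a] = prefix[k]
            return get(data, src)
        return [build(prefix + [i]) for i in range(out_shape[depth])]

    return build([])


x = [[1, 2, 3],
     [4, 5, 6]]
res = transpose_nested(x, (2, 3), [1, 0])
print(res)  # [[1, 4], [2, 5], [3, 6]]
```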