transpose

Description

Transposes an input tensor by permuting its dimensions in the order specified by axes. For details, see Example.

Prototype

transpose(tensor, axes)

Parameters

  • tensor: a tvm.tensor for the input tensor.

    The supported data types include int8, uint8, float16, int16, uint16, float32, int32, uint32, int64, and uint64.

  • axes: a list or tuple of ints specifying the permutation of the input tensor's dimensions. The length of axes must equal d, the number of dimensions of the input tensor. Each element must be in the range [0, d - 1], and all elements must be unique.
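The parameter constraints above can be expressed as a host-side argument check. The following is a minimal sketch only: `check_transpose_args` and `SUPPORTED_DTYPES` are hypothetical names used for illustration and are not part of the TBE API.

```python
# Hypothetical helper (not part of tbe.dsl). It mirrors the documented
# constraints on the data type of tensor and on axes.
SUPPORTED_DTYPES = {
    "int8", "uint8", "float16", "int16", "uint16",
    "float32", "int32", "uint32", "int64", "uint64",
}


def check_transpose_args(shape, dtype, axes):
    """Raise ValueError if (shape, dtype, axes) violate the documented rules."""
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError("unsupported dtype: %s" % dtype)
    d = len(shape)  # number of dimensions of the input tensor
    if len(axes) != d:
        raise ValueError("len(axes) must equal the tensor dimension %d" % d)
    for a in axes:
        # every element must be an int in [0, d - 1]
        if not isinstance(a, int) or not 0 <= a <= d - 1:
            raise ValueError("axis %r is outside [0, %d]" % (a, d - 1))
    if len(set(axes)) != d:
        raise ValueError("axes must not contain duplicates")


check_transpose_args((2, 3), "int32", [1, 0])  # valid: returns None
```

Invalid inputs such as `[1, 1]` (duplicate), `[0, 2]` (out of range), or `"in8"` (misspelled dtype) each raise `ValueError`.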

Returns

res_tensor: a tvm.tensor for the transposed tensor.
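The shape of res_tensor follows directly from axes. The sketch below assumes the NumPy convention, in which dimension i of the result is dimension axes[i] of the input; `transposed_shape` is a hypothetical helper for illustration, not a TBE API.

```python
# Hypothetical helper, assuming the NumPy convention:
# res.shape[i] == x.shape[axes[i]]
def transposed_shape(shape, axes):
    return tuple(shape[a] for a in axes)


print(transposed_shape((2, 3), [1, 0]))        # (3, 2)
print(transposed_shape((2, 3, 4), (2, 0, 1)))  # (4, 2, 3)
```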

Restrictions

This API cannot be used in conjunction with other TBE DSL APIs.

Availability

Atlas 200/300/500 Inference Product

Atlas Training Series Product

Example

Input tensor x:

x = [[1,2,3],
     [4,5,6]]

Input axes:

axes = [1,0]

Call the transpose API to swap axes 0 and 1 of x, as follows:

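from tbe import tvm
from tbe import dsl

shape = (2, 3)            # shape of the input tensor x
input_dtype = "int32"     # must be one of the supported data types
x = tvm.placeholder(shape, name="x", dtype=input_dtype)
axes = [1, 0]             # swap dimension 0 and dimension 1
res = dsl.transpose(x, axes)
# res.shape = (3, 2)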

The preceding code swaps axes 0 and 1 of the input tensor. The result tensor (res) is as follows:

res = [[1,4],
       [2,5],
       [3,6]]
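To check expected results on the host, the same permutation can be applied to plain Python nested lists. This is a minimal sketch, assuming NumPy-style semantics (element res[j0]...[j(d-1)] is the element of x whose index has j_k at position axes[k]); `reference_transpose` and its helpers are hypothetical and not part of tbe.dsl.

```python
from itertools import product


def _shape(nested):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3], [4, 5, 6]] -> (2, 3)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def _build(shape):
    """Nested list of the given shape, filled with None."""
    if not shape:
        return None
    return [_build(shape[1:]) for _ in range(shape[0])]


def _get(nested, index):
    for i in index:
        nested = nested[i]
    return nested


def _set(nested, index, value):
    for i in index[:-1]:
        nested = nested[i]
    nested[index[-1]] = value


def reference_transpose(x, axes):
    """Hypothetical host-side model of transpose on nested lists."""
    in_shape = _shape(x)
    out_shape = tuple(in_shape[a] for a in axes)
    out = _build(out_shape)
    for out_idx in product(*(range(n) for n in out_shape)):
        # output position k comes from input dimension axes[k]
        in_idx = [0] * len(axes)
        for k, a in enumerate(axes):
            in_idx[a] = out_idx[k]
        _set(out, out_idx, _get(x, in_idx))
    return out


x = [[1, 2, 3],
     [4, 5, 6]]
print(reference_transpose(x, [1, 0]))  # [[1, 4], [2, 5], [3, 6]]
```

The result matches res in the example above.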