cast

Description

Casts a tensor to a specified data type.

Prototype

def cast(x, dtype)

Parameters

Parameter   Description
x           Input tensor.
dtype       Target data type, specified as a string (for example, "float32").

Returns

A tensor of the target data type, with the same shape as x.

Restrictions

Only the conversions listed in Table 1 are supported.

Table 1 Supported data type conversions

Source Data Type   Target Data Type   Rounding Mode
float32            float16            N/A
float32            int8               Round down
float32            uint8              Round down
float16            float32            N/A
float16            int8               Round down
float16            uint8              Round down
float16            int32              Round down
int8               float16            N/A
int8               uint8              Round down
int32              float16            N/A
int32              int8               Round down
int32              uint8              Round down
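As a sketch of how a caller might check a conversion against Table 1 before calling cast, the pure-Python snippet below encodes the table as a dictionary. SUPPORTED_CASTS and check_cast are hypothetical names introduced only for illustration; they are not part of the TBE API.

```python
# Hypothetical lookup mirroring Table 1; not part of the TBE API.
SUPPORTED_CASTS = {
    ("float32", "float16"): "N/A",
    ("float32", "int8"): "Round down",
    ("float32", "uint8"): "Round down",
    ("float16", "float32"): "N/A",
    ("float16", "int8"): "Round down",
    ("float16", "uint8"): "Round down",
    ("float16", "int32"): "Round down",
    ("int8", "float16"): "N/A",
    ("int8", "uint8"): "Round down",
    ("int32", "float16"): "N/A",
    ("int32", "int8"): "Round down",
    ("int32", "uint8"): "Round down",
}


def check_cast(src_dtype: str, dst_dtype: str) -> str:
    """Return the rounding mode Table 1 lists for src_dtype -> dst_dtype.

    Raises ValueError for any pair the table does not list.
    """
    try:
        return SUPPORTED_CASTS[(src_dtype, dst_dtype)]
    except KeyError:
        raise ValueError(
            f"cast from {src_dtype} to {dst_dtype} is not listed in Table 1"
        ) from None


print(check_cast("float16", "int32"))  # Round down
```

A lookup of this kind fails early on pairs such as float32 to int32, which Table 1 does not list.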

Example

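from tbe import tvm
from tbe.common.utils import shape_util
# Declare a float16 placeholder tensor of shape (32, 64, 64, 3).
tensor_x = tvm.placeholder((32, 64, 64, 3), name="tensor_x", dtype="float16")
# Cast to float32 (listed in Table 1; no rounding mode applies).
tensor_y = shape_util.cast(tensor_x, "float32")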

This casts tensor_x from float16 to float32.
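The "Round down" entries in Table 1 can be illustrated with a scalar model of a float-to-int8 conversion. This is a sketch under stated assumptions, not the device behavior: it assumes "Round down" means rounding toward negative infinity (floor), and because this reference does not say whether out-of-range values saturate or wrap, the hypothetical round_down_to_int8 helper simply rejects them. If "Round down" instead means truncation toward zero, replace math.floor with math.trunc; the two differ only for negative inputs.

```python
import math

INT8_MIN, INT8_MAX = -128, 127


def round_down_to_int8(value: float) -> int:
    """Floor a float to an int8 value (assumes "Round down" means floor).

    Raises OverflowError when the floored result is outside the int8 range,
    since the overflow behavior of cast is not documented here.
    """
    result = math.floor(value)
    if not INT8_MIN <= result <= INT8_MAX:
        raise OverflowError(f"{value} is outside the int8 range after flooring")
    return result


print([round_down_to_int8(v) for v in (2.7, -2.7, 0.5)])  # [2, -3, 0]
```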