check_dtype

Description

Verifies the data type of a tensor.

Checks whether dtype is contained in check_list. If it is not, a RuntimeError with the message "only support %s while dtype is %s" is raised.
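The check is a membership test followed by a formatted error. Below is a minimal pure-Python sketch of that behavior. `check_dtype_sketch` is a hypothetical stand-in written for illustration, not the tbe source. It assumes dtype is a string compared exactly against the entries of check_list, and it ignores param_name because this page does not say how the library places it in the message.

```python
# Hypothetical stand-in for para_check.check_dtype, for illustration only.
ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32",
                  "int64", "uint64", "float16", "float32", "float64")


def check_dtype_sketch(dtype, check_list=ALL_DTYPE_LIST):
    """Return None if dtype is in check_list; otherwise raise RuntimeError."""
    if dtype not in check_list:
        # Fill the documented template with the allowed list and the bad dtype.
        raise RuntimeError("only support %s while dtype is %s" % (check_list, dtype))


check_dtype_sketch("float16")                      # default list: passes
check_dtype_sketch("int32", ["float16", "int32"])  # custom list: passes
try:
    check_dtype_sketch("float64", ["float16", "float32"])
except RuntimeError as err:
    print(err)  # only support ['float16', 'float32'] while dtype is float64
```

How the real library renders the allowed list inside the message may differ from Python's default str() of a list shown here.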

Prototype

def check_dtype(dtype, check_list=ALL_DTYPE_LIST, param_name=PARAM_NAME)

Parameters

Parameter    Description
-----------  ----------------------------------------------------------------------
dtype        Data type to check, given as a string such as "float16".
check_list   List of allowed data types, for example ["float16", "int32"].
             Defaults to ALL_DTYPE_LIST, which is defined as follows:
             ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32",
                               "int64", "uint64", "float16", "float32", "float64")
param_name   Name of the parameter being checked, included in the printed error
             message as additional context. Defaults to NULL.
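Only the eleven strings in ALL_DTYPE_LIST pass the default check, so any other type name, such as "bool" or "bfloat16", is rejected unless an explicit check_list includes it. The sketch below tests a few names against a local copy of the tuple; it assumes the comparison is an exact string match.

```python
# Local copy of ALL_DTYPE_LIST as documented for check_list.
ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32",
                  "int64", "uint64", "float16", "float32", "float64")

# Assumption: the default check is an exact-match membership test.
for name in ("float16", "uint64", "bool", "bfloat16"):
    print(name, name in ALL_DTYPE_LIST)
# float16 True
# uint64 True
# bool False
# bfloat16 False
```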

Returns

None. If the check fails, the following exception is raised:

RuntimeError: "only support %s while dtype is %s"
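Because failure is signaled only through the exception, a caller that wants a yes/no answer, for example to choose between two code paths, can catch RuntimeError. The sketch below wraps a hypothetical local stand-in, `check_dtype_stub`, so it runs without the CANN toolkit installed; `is_supported` is likewise a hypothetical helper name. In real operator code, para_check.check_dtype would take the stub's place.

```python
def check_dtype_stub(dtype, check_list):
    """Hypothetical stand-in: raise RuntimeError if dtype is not allowed."""
    if dtype not in check_list:
        raise RuntimeError("only support %s while dtype is %s" % (check_list, dtype))


def is_supported(dtype, check_list):
    """Convert the raise-on-failure check into a boolean result."""
    try:
        check_dtype_stub(dtype, check_list)
    except RuntimeError:
        return False
    return True


print(is_supported("float16", ("float16", "float32")))  # True
print(is_supported("int8", ("float16", "float32")))     # False
```

Catching only RuntimeError keeps unrelated failures visible; for example, passing None as check_list makes the stub raise TypeError, which is not swallowed.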

Restrictions

None

Example

from tbe.common.utils import para_check
from tbe.common.utils.para_check import OpParamInfoKey
# x is the operator's input tensor description dict, supplied by the caller.
dtype = x.get(OpParamInfoKey.DTYPE)
para_check.check_dtype(dtype)  # Check against the default list, ALL_DTYPE_LIST.
# Accept only float16 or float32; a RuntimeError is raised if the check fails.
para_check.check_dtype(dtype, ["float16", "float32"])