check_dtype
Description
Verifies the data type of a tensor.
Checks whether dtype is contained in check_list. If it is not, a RuntimeError with the message "only support %s while dtype is %s" is raised.
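The contract above is a membership test followed by an exception. The sketch below reimplements it as a hypothetical stand-in, check_dtype_sketch, so the behavior can be run without the TBE package. It is not the library's code. The order in which the two %s slots are filled (allowed list first, then the offending dtype) is an assumption, and the real message may also mention param_name.

```python
# Hypothetical stand-in for para_check.check_dtype, written only to
# illustrate the documented contract; it is NOT the TBE implementation.
ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32",
                  "int64", "uint64", "float16", "float32", "float64")


def check_dtype_sketch(dtype, check_list=ALL_DTYPE_LIST, param_name="param_name"):
    """Return None if dtype is in check_list, otherwise raise RuntimeError."""
    if dtype not in check_list:
        # Assumption: first %s = allowed list, second %s = offending dtype.
        # The real message may also include param_name; this sketch ignores it.
        raise RuntimeError("only support %s while dtype is %s" % (list(check_list), dtype))
    return None


check_dtype_sketch("float16")  # in the default list: returns None

try:
    check_dtype_sketch("float16", ["int32"])
except RuntimeError as err:
    print(err)  # only support ['int32'] while dtype is float16
```

Returning None on success and signaling failure only through the exception matches the Returns section: callers do not inspect a result, they simply let the error propagate or catch it.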
Prototype
def check_dtype(dtype, check_list=ALL_DTYPE_LIST, param_name=PARAM_NAME)
Parameters
| Parameter | Description |
|---|---|
| dtype | The dtype to check, for example "float16". |
| check_list | List of allowed dtypes, for example ["float16", "int32"]. Defaults to ALL_DTYPE_LIST, which is defined as follows: ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64", "float16", "float32", "float64") |
| param_name | Parameter name, included as additional context in the error message. Defaults to NULL. |
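To show how the choice of check_list changes what passes, the sketch below applies a plain, case-sensitive membership test against the default list quoted in the table and against a narrowed list. The exact-match behavior is an assumption (the real function may normalize case), and the helper name accepted is hypothetical. Note that "bool" and "bfloat16" do not appear in the quoted default list, so they would fail the default check.

```python
# ALL_DTYPE_LIST as quoted in the check_list row of the parameter table.
ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32",
                  "int64", "uint64", "float16", "float32", "float64")


def accepted(candidates, check_list=ALL_DTYPE_LIST):
    """Hypothetical helper: keep the candidates an exact membership test allows."""
    return [d for d in candidates if d in check_list]


candidates = ["float16", "float32", "int8", "bool", "bfloat16"]
print(accepted(candidates))                          # ['float16', 'float32', 'int8']
print(accepted(candidates, ["float16", "float32"]))  # ['float16', 'float32']
```

Passing an explicit check_list is how an operator restricts itself to the dtypes it actually supports, rather than accepting every type in the default list.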
Returns
None. If the verification fails, the following exception is raised:
RuntimeError: "only support %s while dtype is %s"
Restrictions
None
Example
from tbe.common.utils import para_check
from tbe.common.utils.para_check import OpParamInfoKey

# x is the operator input (a dict describing the tensor); its dtype is read with OpParamInfoKey.DTYPE.
# Check dtype using the default verification rule (ALL_DTYPE_LIST).
para_check.check_dtype(x.get(OpParamInfoKey.DTYPE))
# Verify that the input dtype is float16 or float32; a RuntimeError is raised otherwise.
para_check.check_dtype(x.get(OpParamInfoKey.DTYPE), ["float16", "float32"])