check_dtype
函数功能
对tensor的数据类型进行校验。
检查dtype是否在允许的check_list中,若校验失败,则抛出RuntimeError:"only support %s while dtype is %s"。
函数原型
def check_dtype(dtype, check_list=ALL_DTYPE_LIST, param_name=PARAM_NAME)
参数说明
参数 |
说明 |
---|---|
dtype |
需要校验的dtype。 |
check_list |
允许的dtype值,需要使用列表形式,例如:["float16","int32"] 默认值为ALL_DTYPE_LIST,ALL_DTYPE_LIST的定义如下: ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64", "float16", "float32", "float64") |
param_name |
参数名称,用于打印信息时辅助提示,默认值为NULL。 |
返回值说明
无返回值,若校验失败,则抛出如下校验异常:
RuntimeError:"only support %s while dtype is %s"
约束说明
无。
调用示例
from tbe.common.utils import para_check from tbe.common.utils.para_check import OpParamInfoKey para_check.check_dtype(x.get(OpParamInfoKey.DTYPE)) # 使用默认规则校验 para_check.check_dtype(x.get(OpParamInfoKey.DTYPE), [“float16”, “float32”]) # 校验该输入的数据类型是否是“float16”或者“float32”, 否则将抛出异常。
父主题: 算子参数校验