文档
注册
评分
提单
论坛
小AI

check_dtype

函数功能

对tensor的数据类型进行校验。

检查dtype是否在允许的check_list中,若校验失败,则抛出RuntimeError:"only support %s while dtype is %s"。

函数原型

def check_dtype(dtype, check_list=ALL_DTYPE_LIST, param_name=PARAM_NAME)

参数说明

参数

说明

dtype

需要校验的dtype。

check_list

允许的dtype值,需要使用列表形式,例如:["float16","int32"]

默认值为ALL_DTYPE_LIST,ALL_DTYPE_LIST的定义如下:

ALL_DTYPE_LIST = ("int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64", "float16", "float32", "float64")

param_name

参数名称,用于打印信息时辅助提示,默认值为NULL。

返回值说明

无返回值,若校验失败,则抛出如下校验异常:

RuntimeError:"only support %s while dtype is %s"

约束说明

无。

调用示例

from tbe.common.utils import para_check
from tbe.common.utils.para_check import OpParamInfoKey
para_check.check_dtype(x.get(OpParamInfoKey.DTYPE))  # 使用默认规则校验
para_check.check_dtype(x.get(OpParamInfoKey.DTYPE), [“float16”, “float32”])   # 校验该输入的数据类型是否是“float16”或者“float32”, 否则将抛出异常。
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词