check_dtype_rule
函数功能
对tensor的数据类型进行校验。
检查需要校验的数据类型是否在允许的check_list中,若校验失败,则抛出RuntimeError。
函数原型
def check_dtype_rule(dtype, check_list, param_name="default")
参数说明
参数 |
说明 |
---|---|
dtype |
需要校验的数据类型。 字符串类型,不允许为空。 |
check_list |
支持的数据类型列表。 字符串列表。 |
param_name |
参数名称,用于打印信息时辅助提示。 |
返回值说明
无返回值。
如果校验失败,则抛RuntimeError。
约束说明
无
调用示例
from tbe.common.utils import para_check para_check.check_dtype_rule("float64", ["float16","float32"], "a")
检查参数a的数据类型是否支持,抛RuntimeError。
父主题: 算子参数校验