(Optional) Format Inference and Parameter Verification
Implementing op_select_format
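Developers can implement the op_select_format function in the operator implementation file to infer the dtypes and formats supported by the operator's inputs and outputs.

In that case, the operator information library definition only needs dynamicFormat.flag set to true. It does not need fixed dtype and format lists. During operator fusion, the op_select_format function in the implementation file is called automatically to set the dtypes and formats.

If the implementation file does not implement this function, the operator information library definition must list the dtypes and formats supported by each input and output. For how to configure the operator information library, see TBE Operator Information Library.

The op_select_format function is declared as follows: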
def op_select_format(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
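The input parameters of op_select_format are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name). The return value is a string listing the formats and data types supported by the operator's inputs and outputs.

All dtype and format lists must have the same number of entries. The entries at the same position across all inputs and outputs form one supported combination. The string has the following format: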
{
    "input0": {
        "name": "x",
        "dtype": "float16,float16,int8,int8",
        "format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"
    },
    "input1": {
        "name": "y",
        "dtype": "float16,float16,int8,int8",
        "format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"
    },
    "output0": {
        "name": "z",
        "dtype": "float16,float16,int32,int32",
        "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
    }
}
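Because the lists are matched by position, it can help to build the string from parallel Python lists and check their lengths before returning. The sketch below is hypothetical: build_select_format_result and combinations are illustrative helpers, not part of any TBE API. It rebuilds the example above and splits it back into one combination per position.

```python
import json


def build_select_format_result(entries):
    """Build an op_select_format-style return string.

    entries: list of (key, name, dtypes, formats) tuples, where dtypes and
    formats are lists of equal length. Hypothetical helper, not a TBE API.
    """
    res = {}
    expected = None
    for key, name, dtypes, formats in entries:
        # Each dtype must be paired with exactly one format
        if len(dtypes) != len(formats):
            raise ValueError(f"{key}: dtype and format lists differ in length")
        # Every input/output must describe the same number of combinations
        if expected is None:
            expected = len(dtypes)
        elif len(dtypes) != expected:
            raise ValueError(f"{key}: expected {expected} combinations, got {len(dtypes)}")
        res[key] = {"name": name, "dtype": ",".join(dtypes), "format": ",".join(formats)}
    return json.dumps(res, indent=4)


def combinations(result_str):
    """Split the string into one {key: (dtype, format)} dict per position."""
    res = json.loads(result_str)
    columns = {
        key: list(zip(entry["dtype"].split(","), entry["format"].split(",")))
        for key, entry in res.items()
    }
    count = len(next(iter(columns.values())))
    return [{key: col[i] for key, col in columns.items()} for i in range(count)]


# Rebuild the example above and read back its combinations
result = build_select_format_result([
    ("input0", "x", ["float16", "float16", "int8", "int8"],
     ["NC1HWC0_C04", "NC1HWC0", "NC1HWC0_C04", "NC1HWC0"]),
    ("input1", "y", ["float16", "float16", "int8", "int8"],
     ["FRACTAL_Z_C04", "FRACTAL_Z", "FRACTAL_Z_C04", "FRACTAL_Z"]),
    ("output0", "z", ["float16", "float16", "int32", "int32"],
     ["NC1HWC0"] * 4),
])
combos = combinations(result)
for combo in combos:
    print(combo)
```

Each printed dict is one row that the operator claims to support. For example, the third row pairs an int8 NC1HWC0_C04 x with an int8 FRACTAL_Z_C04 y and an int32 NC1HWC0 z.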
For example, the op_select_format function of the conv2d operator is implemented as follows:
import json


def op_select_format(inputs, weights, bias, offset_w, outputs, strides,
                     pads, dilations, groups=1, data_format='NHWC',
                     offset_x=0, kernel_name="conv2d"):
    # Original shapes and formats of the feature map and the filter
    shape_x = inputs.get("ori_shape")
    format_x = inputs.get("ori_format")
    shape_y = weights.get("ori_shape")
    format_y = weights.get("ori_format")
    # Map each axis letter to its size, e.g. "NCHW" -> {"N": ..., "C": ..., "H": ..., "W": ...}
    x_dict = dict(zip(list(format_x), shape_x))
    y_dict = dict(zip(list(format_y), shape_y))
    # Offer the C04 formats only when the input has at most 4 channels
    # and the kernel is not 1x1
    use_c04 = False
    if x_dict["C"] <= 4 and (y_dict["W"] != 1 or y_dict["H"] != 1):
        use_c04 = True
    res = {}
    if use_c04:
        # Four combinations: float16 and int8, each with the regular and the C04 format
        res["input0"] = {
            "name": "x",
            "dtype": "float16,float16,int8,int8",
            "format": "NC1HWC0,NC1HWC0_C04,NC1HWC0,NC1HWC0_C04"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16,float16,int8,int8",
            "format": "FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16,float16,int32,int32",
            "format": "ND,ND,ND,ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8,int8,int8,int8",
            "format": "ND,ND,ND,ND"
        }
        # int8 convolution produces an int32 output
        res["output0"] = {
            "name": "y",
            "dtype": "float16,float16,int32,int32",
            "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
        }
    else:
        # Two combinations: float16 and int8, regular formats only
        res["input0"] = {
            "name": "x",
            "dtype": "float16,int8",
            "format": "NC1HWC0,NC1HWC0"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16,int8",
            "format": "FRACTAL_Z,FRACTAL_Z"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16,int32",
            "format": "ND,ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8,int8",
            "format": "ND,ND"
        }
        res["output0"] = {
            "name": "y",
            "dtype": "float16,int32",
            "format": "NC1HWC0,NC1HWC0"
        }
    return json.dumps(res, indent=4)
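The only shape-dependent step in the conv2d example is the choice of whether to offer the C04 formats. The hypothetical helper uses_c04 below isolates that condition so it can be tried on sample shapes. The shapes used here are illustrative only. Because the axis sizes are looked up by letter, the check works for both NCHW and NHWC inputs and for filters stored as NCHW or HWCN.

```python
def uses_c04(x_ori_shape, x_ori_format, w_ori_shape, w_ori_format):
    """Return True when the conv2d example above would offer the C04 formats.

    Hypothetical helper that isolates the example's decision logic.
    """
    # Map each axis letter of the original format to its size,
    # e.g. ("NCHW", (1, 3, 224, 224)) -> {"N": 1, "C": 3, "H": 224, "W": 224}
    x_dict = dict(zip(x_ori_format, x_ori_shape))
    w_dict = dict(zip(w_ori_format, w_ori_shape))
    # At most 4 input channels and a kernel that is not 1x1
    return x_dict["C"] <= 4 and (w_dict["W"] != 1 or w_dict["H"] != 1)


# 3-channel input with a 7x7 kernel: C04 formats are offered
print(uses_c04((1, 3, 224, 224), "NCHW", (64, 3, 7, 7), "NCHW"))  # True
# Same input with a 1x1 kernel: only the regular formats
print(uses_c04((1, 3, 224, 224), "NCHW", (64, 3, 1, 1), "NCHW"))  # False
# NHWC input with an HWCN filter gives the same answer as the first case
print(uses_c04((1, 224, 224, 3), "NHWC", (7, 7, 3, 64), "HWCN"))  # True
```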
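Implementing check_supported
To verify operator parameters during operator fusion, implement the check_supported function in the operator implementation file. Then set the flag of the needCheckSupport item to true in the operator information library definition file. For how to configure the operator information library, see TBE Operator Information Library.
If check_supported passes, AI Core supports the operator with these parameters. The corresponding AI Core operator (that is, the TBE operator) is then selected for execution. If the check fails, the AI CPU operator is selected instead.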
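The selection rule above can be modeled as a small function. This is a conceptual sketch, not framework code. select_engine and the toy check_supported are hypothetical. need_check_support stands in for the needCheckSupport flag. The sketch leaves out everything else the framework considers, such as matching dtypes and formats against the information library.

```python
def select_engine(need_check_support, check_supported, *args, **kwargs):
    """Conceptual model of the AI Core / AI CPU choice; not framework code.

    need_check_support mirrors the needCheckSupport flag in the operator
    information library; the remaining arguments are passed to check_supported.
    """
    # check_supported is only consulted when the flag is true
    if need_check_support and not check_supported(*args, **kwargs):
        return "AI CPU"
    # Otherwise assume the TBE operator is eligible (other checks omitted)
    return "AI Core"


def check_supported(x, y, kernel_name="toy_op"):
    # Hypothetical rule: only float16 inputs are supported on AI Core
    return x.get("dtype").lower() == "float16"


print(select_engine(True, check_supported, {"dtype": "float16"}, {}))  # AI Core
print(select_engine(True, check_supported, {"dtype": "float32"}, {}))  # AI CPU
```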
In check_supported, you can write custom checks on the dtypes and shapes of the inputs and outputs. The function is declared as follows:
def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
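The input parameters of check_supported are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name).
It returns True if verification succeeds and False if it fails.
For example, the check_supported function of the InTopK operator is implemented as follows. It checks the data types of the inputs: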
def check_supported(predictions, targets, precision, k, kernel_name="in_top_k"):
    # The AI Core implementation only supports float32 predictions and int32 targets
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    if prediction_dtype != "float32":
        return False
    if target_dtype != "int32":
        return False
    return True
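The check_supported function of the InplaceUpdate operator is implemented as follows. It checks both the data types and the shapes of the inputs:
def check_supported(x, indices, v, y, kernel_name="inplace_update"):
    shape_indices = indices.get("shape")
    shape_v = v.get("shape")
    dtype_v = v.get("dtype").lower()
    # Number of elements in one row of v (product of all dimensions except the first)
    reg_v_len = 1
    for i in range(1, len(shape_v)):
        reg_v_len = reg_v_len * shape_v[i]
    # Element size in bytes: 4 for float32/int32, otherwise treated as 2 (float16)
    if dtype_v in ("float32", "int32"):
        dtype_size = 4
    else:
        dtype_size = 2
    reg_v_size = reg_v_len * dtype_size
    # indices must be 1-D, and each row of v must be a multiple of 32 bytes
    if len(shape_indices) != 1 or reg_v_size % 32 != 0:
        return False
    return True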
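The shape rule in the InplaceUpdate example is plain arithmetic. Multiply the sizes of all dimensions of v except the first, then multiply by the element size; the result must be a multiple of 32 bytes. The sketch below reproduces that calculation for a few sample shapes. The row_bytes helper and the DTYPE_SIZE table are hypothetical. The table covers only the sizes the example distinguishes. The sketch uses math.prod, which requires Python 3.8 or later.

```python
from math import prod

# Element sizes as distinguished by the InplaceUpdate example:
# 4 bytes for float32/int32, 2 bytes otherwise (float16 here)
DTYPE_SIZE = {"float32": 4, "int32": 4, "float16": 2}


def row_bytes(shape_v, dtype):
    """Bytes in one row of v: product of all dims except the first, times element size."""
    # prod of an empty slice is 1, matching the loop's initial reg_v_len = 1
    return prod(shape_v[1:]) * DTYPE_SIZE[dtype]


for shape, dtype in [((10, 8), "float32"), ((10, 16), "float16"),
                     ((10, 10), "float16"), ((4, 3, 4), "float16")]:
    size = row_bytes(shape, dtype)
    print(shape, dtype, size, "aligned" if size % 32 == 0 else "rejected")
```

Here (10, 8) float32 and (10, 16) float16 both give 32-byte rows and pass. (10, 10) float16 gives 20 bytes and (4, 3, 4) float16 gives 24 bytes, so check_supported would return False for both and the AI CPU operator would be used.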
For optional inputs and outputs, first check whether the parameter is None, and only then check its shape, dtype, and so on.
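A minimal sketch of that pattern, for a hypothetical operator with a required input x, an optional input bias, and an output y. The operator and its rules are made up for illustration: x must be float16 or float32, and a bias, if present, must be 1-D with the same dtype as x. The point is that bias.get(...) is only called after bias has been confirmed not to be None.

```python
def check_supported(x, bias, y, kernel_name="my_op"):
    # Hypothetical operator: x and y are required, bias is optional
    x_dtype = x.get("dtype").lower()
    if x_dtype not in ("float16", "float32"):
        return False
    # An absent optional input is None; check that before calling .get()
    if bias is not None:
        if bias.get("dtype").lower() != x_dtype:
            return False
        if len(bias.get("shape")) != 1:
            return False
    return True


x_desc = {"dtype": "float16", "shape": (2, 3)}
print(check_supported(x_desc, None, x_desc))                                # True
print(check_supported(x_desc, {"dtype": "float32", "shape": (3,)}, x_desc))  # False
```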