Optional Format Inference and Argument Verification

Function op_select_format

In the operator information library definition, set dynamicFormat.flag to true. During operator fusion, the op_select_format function in the operator implementation file is then called automatically to infer the dtypes and formats of the operator inputs and outputs, so you do not need to list every supported dtype and format manually.

For details about how to configure the operator information library definition, see TBE Operator Information Library.

Declare the op_select_format function as follows:

def op_select_format(input_x1, input_x2, output_y, attribute1=None, attribute2=None, ..., kernel_name="xx"):

The arguments of the op_select_format call (the operator inputs, outputs, attributes, and kernel name) must be consistent with those of the operator API call. The function returns a JSON-formatted string that lists the dtypes and formats supported by each input and output. Each dtype and format value is a comma-separated list, and the entries at the same position across all inputs and outputs form one supported combination. For example:

{
    "input0": {
        "name": "x",
        "dtype": "float16,float16,int8,int8",
        "format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"
    },
    "input1": {
        "name": "y",
        "dtype": "float16,float16,int8,int8",
        "format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"
    },
    "output0": {
        "name": "z",
        "dtype": "float16,float16,int32,int32",
        "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
    }
}
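The following sketch reads a string like the one above and groups the i-th dtype and format entry of every tensor into one combination, which makes it easy to see exactly which combinations an op_select_format result declares. The helper name list_combinations is hypothetical and is not part of TBE; it only uses the standard json module.

```python
import json


def list_combinations(select_format_json):
    """Return one {tensor_name: (dtype, format)} dict per declared combination.

    Hypothetical helper for illustration only; it is not part of TBE.
    """
    spec = json.loads(select_format_json)
    columns = {}
    for key, desc in spec.items():
        # Split the comma-separated lists, tolerating spaces after commas.
        dtypes = [d.strip() for d in desc["dtype"].split(",")]
        formats = [f.strip() for f in desc["format"].split(",")]
        if len(dtypes) != len(formats):
            raise ValueError(f"{key}: dtype and format lists differ in length")
        columns[desc["name"]] = list(zip(dtypes, formats))

    counts = {len(pairs) for pairs in columns.values()}
    if len(counts) != 1:
        raise ValueError("all tensors must list the same number of entries")
    (count,) = counts
    # Entry i of every tensor together forms combination i.
    return [{name: pairs[i] for name, pairs in columns.items()}
            for i in range(count)]


example = json.dumps({
    "input0": {"name": "x", "dtype": "float16,float16,int8,int8",
               "format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"},
    "input1": {"name": "y", "dtype": "float16,float16,int8,int8",
               "format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"},
    "output0": {"name": "z", "dtype": "float16,float16,int32,int32",
                "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"},
})

for combo in list_combinations(example):
    print(combo)
```

For the example string, the first combination printed is x as float16/NC1HWC0_C04, y as float16/FRACTAL_Z_C04, and z as float16/NC1HWC0; a length mismatch between lists is reported as an error rather than silently truncated.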

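For example, op_select_format for the Conv2d operator can be implemented as follows. The C04 formats (NC1HWC0_C04 and FRACTAL_Z_C04) are offered only when the input has at most 4 channels and the kernel is not 1 x 1: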

import json


def op_select_format(inputs, weights, bias, offset_w, outputs, strides,
                     pads, dilations, groups=1, data_format='NHWC',
                     offset_x=0, kernel_name="conv2d"):
    # Map each axis letter of the original format to its size, for example
    # "NCHW" + (1, 3, 224, 224) -> {"N": 1, "C": 3, "H": 224, "W": 224}.
    shape_x = inputs.get("ori_shape")
    format_x = inputs.get("ori_format")
    shape_y = weights.get("ori_shape")
    format_y = weights.get("ori_format")
    x_dict = dict(zip(list(format_x), shape_x))
    y_dict = dict(zip(list(format_y), shape_y))

    # Use the C04 formats only for inputs with at most 4 channels
    # and kernels larger than 1 x 1.
    use_c04 = x_dict["C"] <= 4 and (y_dict["W"] != 1 or y_dict["H"] != 1)

    res = {}
    if use_c04:
        # Four combinations: (float16 | int8) x (standard | C04 format).
        res["input0"] = {
            "name": "x",
            "dtype": "float16,float16,int8,int8",
            "format": "NC1HWC0,NC1HWC0_C04,NC1HWC0,NC1HWC0_C04"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16,float16,int8,int8",
            "format": "FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16,float16,int32,int32",
            "format": "ND,ND,ND,ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8,int8,int8,int8",
            "format": "ND,ND,ND,ND"
        }
        # int8 convolution accumulates into int32 output.
        res["output0"] = {
            "name": "y",
            "dtype": "float16,float16,int32,int32",
            "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
        }
    else:
        # Two combinations: float16 and int8, standard formats only.
        res["input0"] = {
            "name": "x",
            "dtype": "float16,int8",
            "format": "NC1HWC0,NC1HWC0"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16,int8",
            "format": "FRACTAL_Z,FRACTAL_Z"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16,int32",
            "format": "ND,ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8,int8",
            "format": "ND,ND"
        }
        res["output0"] = {
            "name": "y",
            "dtype": "float16,int32",
            "format": "NC1HWC0,NC1HWC0"
        }

    # Serialize the result into the JSON string returned to the framework.
    return json.dumps(res, indent=4)
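The key decision in the Conv2d example is whether to offer the C04 formats. The sketch below isolates that check in a hypothetical helper, should_use_c04, so it can be tried with hand-built argument dicts. It assumes, as the example does, that each argument dict carries "ori_shape" and "ori_format"; the sample shapes are illustrative only. Mapping axis letters to sizes with dict(zip(...)) is what lets the same check work for both NCHW/NCHW and NHWC/HWCN layouts.

```python
def should_use_c04(inputs, weights):
    """Mirror the C04 condition of the Conv2d example (hypothetical helper).

    Only the "ori_shape" and "ori_format" keys of each dict are read.
    """
    # Map each axis letter of the original format to its size.
    x_dims = dict(zip(inputs["ori_format"], inputs["ori_shape"]))
    w_dims = dict(zip(weights["ori_format"], weights["ori_shape"]))
    # C04 applies when the input has at most 4 channels and the kernel is not 1 x 1.
    return x_dims["C"] <= 4 and (w_dims["H"] != 1 or w_dims["W"] != 1)


# Illustrative argument dicts (assumed shapes, not taken from a real model).
rgb_nchw = {"ori_shape": (1, 3, 224, 224), "ori_format": "NCHW"}
conv7x7 = {"ori_shape": (64, 3, 7, 7), "ori_format": "NCHW"}
conv1x1 = {"ori_shape": (64, 3, 1, 1), "ori_format": "NCHW"}
rgb_nhwc = {"ori_shape": (1, 224, 224, 3), "ori_format": "NHWC"}
conv3x3_hwcn = {"ori_shape": (3, 3, 3, 16), "ori_format": "HWCN"}

print(should_use_c04(rgb_nchw, conv7x7))       # True
print(should_use_c04(rgb_nchw, conv1x1))       # False
print(should_use_c04(rgb_nhwc, conv3x3_hwcn))  # True
```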

Function check_supported

To verify operator arguments during operator fusion, implement the check_supported function in the operator implementation file and set the flag parameter of the needCheckSupport option (needCheckSupport.flag) to true in the operator information library definition file. For details about configuring the operator information library definition, see TBE Operator Information Library.

If check_supported returns True, the operator arguments are supported on the AI Core, and the corresponding TBE operator is selected to run on the AI Core. Otherwise, the AI CPU operator is executed instead.

In check_supported, you can implement custom checks on the data types and shapes of the operator inputs and outputs. Declare the function as follows:

def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None, ..., kernel_name="xx"):

Keep the arguments passed to the check_supported call consistent with those passed to the operator API call (in terms of the operator inputs, outputs, attributes, and kernel name).

The function returns True if verification passes and False otherwise.

For example, check_supported for the InTopK operator can be implemented as follows to verify the input data types.

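def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'):
    # Normalize the dtype strings to lower case before comparing.
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    # Only float32 predictions are accepted on the AI Core.
    if prediction_dtype != "float32":
        return False
    # Only int32 targets are accepted on the AI Core.
    if target_dtype != "int32":
        return False

    return True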
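The same check can be written as a single boolean expression. The sketch below is an equivalent compact form, followed by sample calls. The argument dicts are illustrative assumptions that carry only the keys this check reads; real calls pass more fields.

```python
def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'):
    """Compact equivalent of the InTopK check: accept only float32
    predictions and int32 targets (dtype strings compared in lower case)."""
    return (predictions.get("dtype").lower() == "float32"
            and targets.get("dtype").lower() == "int32")


# Illustrative argument dicts; only "dtype" is read by the check.
preds = {"shape": (8, 10), "dtype": "float32"}
ok_targets = {"shape": (8,), "dtype": "int32"}
bad_targets = {"shape": (8,), "dtype": "int64"}
out = {"shape": (8,), "dtype": "bool"}

print(check_supported(preds, ok_targets, out, 3))   # True
print(check_supported(preds, bad_targets, out, 3))  # False
```

Returning the boolean expression directly keeps the AI Core/AI CPU decision in one place; the explicit if-statements in the original are easier to extend when more conditions are added.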

check_supported for the InplaceUpdate operator can be implemented as follows to verify the input data types and shapes.

def check_supported(x, indices, v, y, kernel_name="inplace_update"):
    shape_indices = indices.get("shape")
    shape_v = v.get("shape")
    dtype_v = v.get("dtype").lower()

    # Number of elements in one slice v[i] (all dimensions except the first).
    reg_v_len = 1
    for i in range(1, len(shape_v)):
        reg_v_len = reg_v_len * shape_v[i]

    # Element size in bytes: 4 for float32/int32, otherwise 2 (for example, float16).
    if dtype_v in ("float32", "int32"):
        dtype_size = 4
    else:
        dtype_size = 2
    reg_v_size = reg_v_len * dtype_size

    # indices must be 1-D, and each slice of v must occupy a multiple of 32 bytes.
    if len(shape_indices) != 1 or reg_v_size % 32 != 0:
        return False

    return True
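To make the size rule concrete, the sketch below computes the byte size of one slice v[i] with the same element sizes as the InplaceUpdate check. The helper name v_row_bytes and the sample shapes are hypothetical; math.prod requires Python 3.8 or later and returns 1 for an empty tuple, matching the loop in the check when v is 1-D.

```python
import math


def v_row_bytes(shape_v, dtype_v):
    """Bytes in one slice v[i] (all dimensions except the first), using the
    same element sizes as the InplaceUpdate check (hypothetical helper)."""
    dtype_size = 4 if dtype_v.lower() in ("float32", "int32") else 2
    return math.prod(shape_v[1:]) * dtype_size


# Illustrative shapes: 8 x 4 B = 32 B and 16 x 2 B = 32 B pass; 20 B and 4 B fail.
for shape, dtype in [((4, 8), "float32"), ((4, 16), "float16"),
                     ((4, 10), "float16"), ((6,), "int32")]:
    size = v_row_bytes(shape, dtype)
    print(shape, dtype, size, "aligned" if size % 32 == 0 else "not aligned")
```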

For optional inputs and outputs, check whether the argument is None before checking its shape and data type.
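A minimal sketch of that pattern, for a hypothetical operator with a required input x, an optional input bias, and an output y. The operator, its kernel name, and the specific dtype and shape rules are assumptions chosen only to illustrate skipping checks when the optional argument is None.

```python
def check_supported(x, bias, y, kernel_name="hypothetical_op"):
    """check_supported for a hypothetical operator whose bias input is optional."""
    dtype_x = x.get("dtype").lower()
    if dtype_x not in ("float16", "float32"):
        return False
    # An optional input that was not provided is None: check it only if present.
    if bias is not None:
        if bias.get("dtype").lower() != dtype_x:
            return False
        if len(bias.get("shape")) != 1:
            return False
    return True


x = {"shape": (2, 3), "dtype": "float16"}
y = {"shape": (2, 3), "dtype": "float16"}
print(check_supported(x, None, y))                                  # True
print(check_supported(x, {"shape": (3,), "dtype": "float32"}, y))   # False
```

Without the None test, calling bias.get(...) on an omitted optional input would raise AttributeError instead of returning a verdict.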