SetOpSelectFormat

Function Usage

To infer the dtype and format supported by the operator inputs and outputs, implement an inference callback function and register it by calling this API. Also set DynamicFormatFlag to true, so that the inference function is called automatically to set dtype and format during operator fusion. In this case, dtype and format do not need to be set during operator prototype registration.

Note that if dtype and format have been registered in the operator prototype, the registered dtype and format are used, even if the inference function is registered.

Prototype

OpAICoreDef &SetOpSelectFormat(optiling::OP_CHECK_FUNC func);

Parameters

Parameter: func

Input/Output: Input

Description: Function that infers the dtype and format supported by the operator inputs and outputs. The OP_CHECK_FUNC type is defined as follows:

using OP_CHECK_FUNC = ge::graphStatus (*)(const ge::Operator &op, ge::AscendString &result);

The input parameter op is the operator description, which includes the operator inputs, outputs, and attributes. The output parameter result is a JSON string that contains the dtype and format lists supported by each input and output of the current operator. An example of the string is as follows:

{
    "input0": {"name": "x","dtype": "float16,float32,int32","format": "ND,ND,ND"},
    "input1": {"name": "y","dtype": "float16,float32,int32","format": "ND,ND,ND"},
    "output0": {"name": "z","dtype": "float16,float32,int32","format": "ND,ND,ND"}
}
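Writing this string by hand is error-prone once an operator has several inputs and outputs. The following is a minimal sketch of assembling it from structured data using only the C++ standard library. TensorSupport, JoinWithComma, and BuildSelectFormatJson are hypothetical names introduced for illustration and are not part of the Ascend C API. The sketch emits only the "name", "dtype", and "format" fields shown above and does not escape special characters.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper type (not part of the Ascend C API): one input or
// output together with the dtype and format lists it supports.
struct TensorSupport {
    std::string key;                   // JSON key, for example "input0" or "output0"
    std::string name;                  // tensor name, for example "x"
    std::vector<std::string> dtypes;   // for example {"float16", "float32", "int32"}
    std::vector<std::string> formats;  // for example {"ND", "ND", "ND"}
};

// Joins the items with commas and no spaces: {"a", "b"} becomes "a,b".
static std::string JoinWithComma(const std::vector<std::string> &items)
{
    std::string out;
    for (std::size_t i = 0; i < items.size(); ++i) {
        if (i != 0) {
            out += ",";
        }
        out += items[i];
    }
    return out;
}

// Builds a JSON object with one entry per tensor, using the same field
// layout as the example string above.
static std::string BuildSelectFormatJson(const std::vector<TensorSupport> &tensors)
{
    std::string json = "{";
    for (std::size_t i = 0; i < tensors.size(); ++i) {
        const TensorSupport &t = tensors[i];
        if (i != 0) {
            json += ",";
        }
        json += "\"" + t.key + "\": {";
        json += "\"name\": \"" + t.name + "\",";
        json += "\"dtype\": \"" + JoinWithComma(t.dtypes) + "\",";
        json += "\"format\": \"" + JoinWithComma(t.formats) + "\"}";
    }
    json += "}";
    return json;
}
```

Inside an inference function, the returned std::string could then be passed to ge::AscendString through c_str().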

Returns

A reference to the OpAICoreDef operator definition, which allows further calls to be chained. For details, see OpAICoreDef.

Constraints

None

Example

The following examples show how to implement and register the inference function for a custom Add operator. An example of implementing the inference function is as follows:

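// This function is referenced as optiling::OpSelectFormat during registration, so it is assumed to be defined in the optiling namespace.
static ge::graphStatus OpSelectFormat(const ge::Operator &op, ge::AscendString &result)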
{
    std::string resultJsonStr;
    // If the first dimension of the first input's shape is less than or equal to 8, multiple dtypes (float16, float32, and int32) are supported. Otherwise, only int32 is supported.
    if (op.GetInputDesc(0).GetShape().GetDim(0) <= 8) {
        resultJsonStr = R"({
        "input0": {"name": "x","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"},
        "input1": {"name": "y","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"},
        "output0": {"name": "z","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"}
        })";
    } else {
        resultJsonStr = R"({
        "input0": {"name": "x","dtype": "int32","format": "ND","unknownshape_format": "ND"},
        "input1": {"name": "y","dtype": "int32","format": "ND","unknownshape_format": "ND"},
        "output0": {"name": "z","dtype": "int32","format": "ND","unknownshape_format": "ND"}
        })";
    }
    result = ge::AscendString(resultJsonStr.c_str());
    return ge::GRAPH_SUCCESS;
}
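In every string above, the "dtype" list and the "format" list of an entry contain the same number of comma-separated items. Assuming the two lists are meant to be read position by position (an inference from the examples, not something this page states), a small check such as the following sketch can catch mismatched lists before the string is returned. CountEntries and ListsAreAligned are hypothetical helper names and are not part of the Ascend C API.

```cpp
#include <cstddef>
#include <string>

// Counts comma-separated entries: "float16,float32,int32" yields 3.
// An empty string yields 0.
static std::size_t CountEntries(const std::string &list)
{
    if (list.empty()) {
        return 0;
    }
    std::size_t count = 1;
    for (char c : list) {
        if (c == ',') {
            ++count;
        }
    }
    return count;
}

// Returns true when both lists are non-empty and have the same number of
// entries, which is the pattern followed by the examples in this section.
static bool ListsAreAligned(const std::string &dtypeList, const std::string &formatList)
{
    const std::size_t n = CountEntries(dtypeList);
    return n != 0 && n == CountEntries(formatList);
}
```

For example, ListsAreAligned("float16,float32,int32", "ND,ND,ND") holds, while ListsAreAligned("float16,float32", "ND") does not.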

An example of registering the inference function is as follows:

class AddCustom : public OpDef {
public:
    AddCustom(const char* name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED);
        this->Input("y")
            .ParamType(REQUIRED);
        this->Output("z")
            .ParamType(REQUIRED);
        this->SetInferShape(ge::InferShape);
        this->AICore()
            .SetTiling(optiling::TilingFunc)
            .SetTilingParse(optiling::TilingPrepare)
            .SetOpSelectFormat(optiling::OpSelectFormat);

        OpAICoreConfig aicConfig;
        aicConfig.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true);
        // Note: Replace soc_version with the actual AI processor version.
        this->AICore().AddConfig("soc_version", aicConfig);
    }
};