SetOpSelectFormat

Function

To derive the data types and formats supported by the operator's inputs and outputs programmatically, implement a derivation callback function, register it with this API, and set DynamicFormatFlag to true. The derivation function is then called automatically during operator fusion to determine the supported data types and formats, so you do not need to configure them when registering the operator prototype.

Note that if data types and formats are already registered in the operator prototype, the registered values are used and the derivation function is not executed, even if one has been registered.
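The precedence between registered and derived formats can be summarized with a small model. This is a hypothetical sketch, not the framework's implementation: OpInfo and ResolveFormats are invented names, and modeling a disabled DynamicFormatFlag as "the callback is never consulted" is an assumption based on the requirement stated above.

```cpp
#include <functional>
#include <optional>
#include <string>

// Hypothetical model of an operator's format configuration (not a CANN type).
struct OpInfo {
    std::optional<std::string> registeredFormats;   // set if the prototype declares dtypes/formats
    std::function<std::string()> selectFormatFunc;  // models the callback passed to SetOpSelectFormat
    bool dynamicFormatFlag = false;                 // models OpAICoreConfig::DynamicFormatFlag
};

// Returns the format description that would be used under the rules above.
std::optional<std::string> ResolveFormats(const OpInfo &info) {
    // Formats registered in the prototype take precedence; the callback is not called.
    if (info.registeredFormats) {
        return info.registeredFormats;
    }
    // Otherwise the callback is used only when DynamicFormatFlag is enabled (assumption).
    if (info.dynamicFormatFlag && info.selectFormatFunc) {
        return info.selectFormatFunc();
    }
    return std::nullopt;  // nothing available to determine the formats
}
```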

Prototype

OpAICoreDef &SetOpSelectFormat(optiling::OP_CHECK_FUNC func)

Parameters

func (Input): Function that derives the data types and formats supported by the operator inputs and outputs. The OP_CHECK_FUNC type is defined as follows:

using OP_CHECK_FUNC = ge::graphStatus (*)(const ge::Operator &op, ge::AscendString &result);

The first parameter, op, describes the operator, including its inputs, outputs, and attributes. The second parameter, result, is an output string that lists the data types and formats supported by the operator inputs and outputs. The following is an example of the string format:

{
    "input0": {"name": "x","dtype": "float16,float32,int32","format": "ND,ND,ND"},
    "input1": {"name": "y","dtype": "float16,float32,int32","format": "ND,ND,ND"},
    "output0": {"name": "z","dtype": "float16,float32,int32","format": "ND,ND,ND"}
}
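Writing this string by hand is error-prone, because each dtype list and its format list are read position by position (in the example, the i-th data type pairs with the i-th format) and so should have the same length. The following sketch builds the string from a list of parameter descriptions and rejects mismatched lengths. ParamSpec and BuildSelectFormatJson are hypothetical helpers, not part of the CANN API, and the positional pairing is an assumption based on the example above.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical description of one input or output entry.
struct ParamSpec {
    std::string key;                   // e.g. "input0", "output0"
    std::string name;                  // parameter name, e.g. "x"
    std::vector<std::string> dtypes;   // e.g. {"float16", "float32"}
    std::vector<std::string> formats;  // one format per dtype, e.g. {"ND", "ND"}
};

// Joins items with commas: {"a", "b"} -> "a,b".
static std::string JoinComma(const std::vector<std::string> &items) {
    std::string out;
    for (size_t i = 0; i < items.size(); ++i) {
        if (i > 0) out += ',';
        out += items[i];
    }
    return out;
}

// Builds a string in the format shown above; throws if dtype/format counts differ.
std::string BuildSelectFormatJson(const std::vector<ParamSpec> &params) {
    std::ostringstream os;
    os << "{";
    for (size_t i = 0; i < params.size(); ++i) {
        const ParamSpec &p = params[i];
        if (p.dtypes.size() != p.formats.size()) {
            throw std::invalid_argument("dtype/format count mismatch for " + p.key);
        }
        if (i > 0) os << ",";
        os << "\"" << p.key << "\": {\"name\": \"" << p.name
           << "\",\"dtype\": \"" << JoinComma(p.dtypes)
           << "\",\"format\": \"" << JoinComma(p.formats) << "\"}";
    }
    os << "}";
    return os.str();
}
```

For example, a single entry {"input0", "x", {"float16", "int32"}, {"ND", "ND"}} produces {"input0": {"name": "x","dtype": "float16,int32","format": "ND,ND"}}.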

Returns

A reference to the OpAICoreDef operator definition, which allows calls to be chained. For details, see OpAICoreDef.
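Returning a reference is what allows the chained calls in the registration example (SetTiling, then SetTilingParse, then SetOpSelectFormat). The following self-contained sketch shows that pattern together with the OP_CHECK_FUNC pointer type. MiniAICoreDef, SetLabel, DemoSelectFormat, and the mock_ge namespace are hypothetical stand-ins, used only so the example compiles without the CANN headers.

```cpp
#include <cstdint>
#include <string>

namespace mock_ge {  // hypothetical stand-ins for the ge:: types
using graphStatus = uint32_t;
constexpr graphStatus GRAPH_SUCCESS = 0;
struct Operator {};                        // the real type describes inputs, outputs, attributes
struct AscendString { std::string str; };  // the real type wraps a C string
}  // namespace mock_ge

// Same shape as optiling::OP_CHECK_FUNC, expressed with the stand-in types.
using OP_CHECK_FUNC = mock_ge::graphStatus (*)(const mock_ge::Operator &op,
                                               mock_ge::AscendString &result);

// Hypothetical stand-in for OpAICoreDef: each setter stores its argument and
// returns *this by reference, which is what makes chained calls possible.
class MiniAICoreDef {
public:
    MiniAICoreDef &SetOpSelectFormat(OP_CHECK_FUNC func) {
        selectFormat_ = func;
        return *this;
    }
    MiniAICoreDef &SetLabel(const std::string &label) {  // hypothetical extra setter
        label_ = label;
        return *this;
    }
    OP_CHECK_FUNC selectFormat() const { return selectFormat_; }
    const std::string &label() const { return label_; }

private:
    OP_CHECK_FUNC selectFormat_ = nullptr;
    std::string label_;
};

// A callback whose signature matches OP_CHECK_FUNC; the function name
// converts implicitly to the function-pointer type.
static mock_ge::graphStatus DemoSelectFormat(const mock_ge::Operator &, mock_ge::AscendString &result) {
    result.str = R"({"input0": {"name": "x","dtype": "int32","format": "ND"}})";
    return mock_ge::GRAPH_SUCCESS;
}
```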

Constraints

None

Examples

The following example implements the format derivation function for a custom Add operator.

namespace optiling {
static ge::graphStatus OpSelectFormat(const ge::Operator &op, ge::AscendString &result)
{
    std::string resultJsonStr;
    // If dimension 0 of the first input's shape is less than or equal to 8, multiple data types are supported.
    // Otherwise, only int32 is supported.
    if (op.GetInputDesc(0).GetShape().GetDim(0) <= 8) {
        resultJsonStr = R"({
        "input0": {"name": "x","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"},
        "input1": {"name": "y","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"},
        "output0": {"name": "z","dtype": "float16,float32,int32","format": "ND,ND,ND","unknownshape_format": "ND,ND,ND"}
        })";
    } else {
        resultJsonStr = R"({
        "input0": {"name": "x","dtype": "int32","format": "ND","unknownshape_format": "ND"},
        "input1": {"name": "y","dtype": "int32","format": "ND","unknownshape_format": "ND"},
        "output0": {"name": "z","dtype": "int32","format": "ND","unknownshape_format": "ND"}
        })";
    }
    // Copy the JSON string into the output parameter.
    result = ge::AscendString(resultJsonStr.c_str());
    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling
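The branching in this example can be exercised outside the framework. The sketch below keeps the same decision rule but replaces ge::Operator with a plain vector of dimensions. SelectDtypes, its dtype-only return value, and the guard for an empty shape are simplifications for illustration and are not part of the original function.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Simplified version of the rule in OpSelectFormat: inspects dimension 0 of
// the first input's shape (given here as a plain vector) and returns the
// comma-separated dtype list that would be written into the JSON string.
std::string SelectDtypes(const std::vector<int64_t> &firstInputDims) {
    // The empty-shape guard is a hypothetical addition for this sketch.
    if (!firstInputDims.empty() && firstInputDims[0] <= 8) {
        return "float16,float32,int32";  // small leading dimension: several dtypes
    }
    return "int32";  // otherwise only int32
}
```

One consequence worth checking in a real operator: if dimension 0 is unknown, which GE typically encodes as a negative value such as -1, the `<= 8` test also succeeds, so the multi-dtype branch is taken.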

The derivation function is registered as follows:

class AddCustom : public OpDef {
public:
    AddCustom(const char* name) : OpDef(name)
    {
        // Data types and formats are not declared here; they are derived by OpSelectFormat.
        this->Input("x")
            .ParamType(REQUIRED);
        this->Input("y")
            .ParamType(REQUIRED);
        this->Output("z")
            .ParamType(REQUIRED);
        this->SetInferShape(ge::InferShape);
        this->AICore()
            .SetTiling(optiling::TilingFunc)
            .SetTilingParse(optiling::TilingPrepare)
            .SetOpSelectFormat(optiling::OpSelectFormat);

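        OpAICoreConfig aicConfig;
        // DynamicFormatFlag(true) is required for the function registered with SetOpSelectFormat to be called.
        aicConfig.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)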
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true);
        // Note: Replace soc_version with the actual AI Processor version.
        this->AICore().AddConfig("soc_version", aicConfig);
    }
};