设置输入输出tensor的format匹配模式。
1 | OpDef &FormatMatchMode(FormatCheckOption option) |
参数 |
输入/输出 |
说明 |
---|---|---|
option |
输入 |
匹配模式配置参数,类型为FormatCheckOption枚举类。支持以下几种取值:
|
OpDef算子定义,OpDef请参考OpDef。
不调用该接口的情况下,默认将NCHW/NHWC/DHWCN/NCDHW/NCL格式的输入输出转成ND格式进行处理。
下面示例中,算子AddCustom输入x只支持format为NCHW,输入y只支持foramt为NHWC,需要配置FormatMatchMode(FormatCheckOption::STRICT),如果不配置aclnn框架会转成ND格式传给算子tiling。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | AddCustom(const char* name) : OpDef(name) { this->Input("x") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT}) .FormatList({ge::FORMAT_NCHW}); this->Input("y") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT}) .FormatList({ge::FORMAT_NHWC}); this->Output("z") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT}) .FormatList({ge::FORMAT_ND}); this->AICore().SetTiling(optiling::TilingFunc); this->AICore().AddConfig("ascendxxx"); this->FormatMatchMode(FormatCheckOption::STRICT); } |