GE提供了针对前端框架导出模型的解析功能,若开发者已经获取到PyTorch或者TensorFlow框架导出的模型(.onnx模型或者.pb模型),可通过GE提供的ATC命令行工具或者C++语言的Parser接口,将这两类模型转化为适配昇腾的模型并基于图模式执行。
MindSpore框架网络导出的.air模型,也可以通过ATC命令行工具转换为适配昇腾的om模型,然后通过AscendCL接口加载执行。
下文仅针对如何将PyTorch或者TensorFlow框架导出的模型通过图模式执行进行详细介绍。
atc --model=$HOME/module/resnet50*.onnx --framework=5 --output=$HOME/module/out/onnx_resnet50 --soc_version=<soc_version>
atc --model=$HOME/module/resnet50_tensorflow*.pb --framework=3 --output=$HOME/module/out/tf_resnet50 --soc_version=<soc_version>
关键参数解释如下,ATC工具支持的详细参数及含义可参见《ATC工具使用指南》。
详细的模型加载与推理的方法可参见《CANN AscendCL应用软件开发指南(C&C++)》中的“推理应用开发流程”。
#include "onnx_parser.h" std::string onnxPath = "../data/onnx_test.onnx"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNX(onnxPath.c_str(), parser_params, graph1);
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
#include "onnx_parser.h" FILE *pFile = fopen("./onnx/resnet101.onnx", "rb" ); if(pFile==NULL) { fputs("File error",stderr); exit(1); } /* get the size of the file */ fseek(pFile, 0, SEEK_END); long lSize = ftell(pFile); rewind(pFile); /* assign memory buffer for the file*/ char *buffer =(char*) malloc(sizeof(char)*lSize); if(buffer == NULL) { fputs("Memory error", stderr); exit(2); } /* copy the file to buffer */ size_t result = fread(buffer, 1, lSize, pFile); if(result != lSize) { fputs("Reading error", stderr); exit(3); } std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNXFromMem(buffer, result, parser_params, graph1); |
开发者在调用解析接口时,还可以通过parser_params参数配置扩展参数,使用方法可参见《Ascend Graph开发指南》中的原始模型转换为Graph。
完成原始框架模型的解析后,会得到GE的Graph(即上述代码示例中的对象graph1),此时Graph保存在内存缓冲区中,开发者后续可以直接编译运行此Graph,图编译运行的流程可参见下一节Ascend Graph构图。