原始模型转换为Graph
除了可以使用算子原型直接构图外,CANN还提供了框架解析功能,将主流框架的模型格式转换成CANN模型格式。
功能介绍
目前业界开源的深度学习框架(例如TensorFlow,PyTorch、Caffe等),定义模型的格式各有不同,例如TensorFlow通过自定义pb描述静态shape图和模型,PyTorch通过ONNX规范描述,因此需要通过统一的框架解析功能隔离上层框架差异,通过Parser模块完成解析并转换成昇腾AI处理器支持的CANN模型格式。
涉及的主要接口为:
- 解析TensorFlow模型:aclgrphParseTensorFlow
- 解析Caffe模型:aclgrphParseCaffe
- 解析ONNX原始模型:aclgrphParseONNX
- 解析加载至内存的ONNX模型:aclgrphParseONNXFromMem
Parser层目前为用户开放了自定义OpParser和自定义TensorFlow Scope融合规则的能力,如果用户在Parser解析时需要对框架进行更灵活的适配,则可以自定义OpParser或自定义开发TensorFlow Scope融合规则。
- 自定义OpParser:
如果用户需要将原始框架中算子直接映射到CANN中已实现的Ascend C算子,可直接进行第三方框架的适配,具体请参见AI框架算子适配章节。
- 自定义TensorFlow Scope融合规则:基于TensorFlow构建的神经网络计算图通常由大量的小算子组成,为了实现高性能的计算,往往需要对子图中的小算子进行融合,使得融合后的大算子可以充分利用硬件加速资源。具体请参见《TensorFlow Parser Scope融合规则开发指南》。

原始模型转换为Graph时,如果Tensor的shape维度和format维度数量不一致,按照如下表格中的规则理解当前维度:
例如,shape只有1维为[16],format为4维,比如NHWC,该场景下可以理解为shape的1维为C轴,其他轴需要补维,补维后格式为[1,1,1,16];
shape为2维[16,16],format为4维,比如NHWC,该场景下可以理解为shape的2维为HW轴,其他轴需要补维,补维后格式为[1,16,16,1]。
实际维度数 |
format |
维度理解为 |
---|---|---|
1 |
NCHW NHWC HWCN CHWN NDHWC NCDHW DHWCN DHWNC |
C |
2 |
NCHW |
CH |
NHWC |
HW |
|
HWCN |
CN |
|
CHWN |
WN |
|
NDHWC |
WC |
|
NCDHW |
HW |
|
DHWCN |
CN |
|
DHWNC |
NC |
|
3 |
NCHW |
CHW |
NHWC |
HWC |
|
HWCN |
WCN |
|
CHWN |
HWN |
|
NDHWC |
HWC |
|
NCDHW |
DHW |
|
DHWCN |
WCN |
|
DHWNC |
WNC |
|
4 |
NDHWC |
DHWC |
NCDHW |
CDHW |
|
DHWCN |
HWCN |
|
DHWNC |
HWNC |
基于TensorFlow模型解析
包含的头文件:
1
|
#include "tensorflow_parser.h" |
1 2 3 |
std::string tfPath = "../data/tf_test.pb"; ge::Graph graph1; auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(),graph1); |
同时,支持用户指定parser_params:
1 2 3 4 5 6 |
std::string tfPath = "../data/tf_test.pb"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto tfStatus = ge::aclgrphParseTensorFlow(tfPath.c_str(), parser_params, graph1); |
基于Caffe模型解析
包含的头文件:
1
|
#include "caffe_parser.h" |
1 2 3 4 |
std::string caffePath = "../data/caffe_test.prototxt"; std::string weight = "../data/caffe_test.caffemodel"; ge::Graph graph1; auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), graph1); |
同时,支持用户指定parser_params:
1 2 3 4 5 6 7 |
std::string caffePath = "../data/caffe_test.prototxt"; std::string weight = "../data/caffe_test.caffemodel"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto caffeStatus = ge::aclgrphParseCaffe(caffePath.c_str(), weight.c_str(), parser_params, graph1); |
基于ONNX模型解析
包含的头文件:
1
|
#include "onnx_parser.h" |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
/* read file in binary format */ FILE *pFile = fopen("./onnx/resnet101.onnx", "rb" ); if(pFile==NULL) { fputs("File error",stderr); exit(1); } /* get the size of the file */ fseek(pFile, 0, SEEK_END); long lSize = ftell(pFile); rewind(pFile); /* assign memory buffer for the file*/ char *buffer =(char*) malloc(sizeof(char)*lSize); if(buffer == NULL) { fputs("Memory error", stderr); exit(2); } /* copy the file to buffer */ size_t result = fread(buffer, 1, lSize, pFile); if(result != lSize) { fputs("Reading error", stderr); exit(3); } std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNXFromMem(buffer, result, parser_params, graph1); |
通过aclgrphParseONNX接口将ONNX原始模型转换为Graph,此时Graph保存在内存缓冲区中。同时,支持用户指定parser_params:
1 2 3 4 5 6 |
std::string onnxPath = "../data/onnx_test.onnx"; std::map<ge::AscendString, ge::AscendString> parser_params= { {ge::AscendString(ge::ir_option::INPUT_FP16_NODES), ge::AscendString("input1;input2")}, {ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString("newIssue")}}; ge::Graph graph1; auto onnxStatus = ge::aclgrphParseONNX(onnxPath.c_str(), parser_params, graph1); |