导出ONNX模型
简介
PyTorch模型在昇腾AI处理器上的部署策略是基于PyTorch官方支持的ONNX模块实现的。ONNX是业内目前比较主流的模型格式,广泛用于模型交流及部署。本节主要介绍如何将Checkpoint文件通过torch.onnx.export()接口导出为ONNX模型。
.pth.tar文件导出ONNX模型
.pth.tar在导出ONNX模型时需要先确定保存时的信息,有时保存的节点名称和模型定义中的节点会有差异,例如会多出前缀和后缀。在进行转换的时候,可以对节点名称进行修改。
- 安装onnx。
1
pip3 install onnx
- 创建转换脚本,例如命名为“pth2onnx.py”,和ResNet50网络训练脚本main.py放置在同一目录下。
转换代码样例脚本如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
from collections import OrderedDict import torch import torch_npu import torch.onnx import torchvision.models as models # 如果.pth.tar文件保存时节点名加了前缀或后缀,则通过遍历删除。此处以遍历删除前缀"module."为例。若无前缀后缀则不影响。 def proc_nodes_module(checkpoint, AttrName): new_state_dict = OrderedDict() for key, value in checkpoint[AttrName].items(): if key == "module.features.0.0.weight": print(value) # 根据实际前缀后缀情况修改 if(key[0:7] == "module."): name = key[7:] else: name = key[0:] new_state_dict[name] = value return new_state_dict def convert(): # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型 checkpoint = torch.load("./checkpoint.pth.tar", map_location=torch.device('cpu')) # 根据实际文件名称修改 checkpoint['state_dict'] = proc_nodes_module(checkpoint,'state_dict') model = models.resnet50(pretrained = False) model.load_state_dict(checkpoint['state_dict']) model.eval() input_names = ["actual_input_1"] output_names = ["output1"] dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export(model, dummy_input, "resnet50.onnx", input_names = input_names, output_names = output_names, opset_version=11) # 输出文件名根据实际情况修改 if __name__ == "__main__": convert()
- 执行命令转换模型。
1
python3 pth2onnx.py
命令执行成功后,会在当前目录下生成“resnet50.onnx”模型文件。
父主题: 快速入门