导出ONNX模型
简介
昇腾AI处理器PyTorch模型的部署策略是基于PyTorch官方支持的ONNX模块实现的。ONNX是业内目前比较主流的模型格式,广泛用于模型交流及部署。本节主要介绍如何将Checkpoint文件通过torch.onnx.export()接口导出为ONNX模型。
.pth或.pt文件导出ONNX模型
保存的.pth或.pt文件可以通过PyTorch构建模型再加载权重的方法恢复,然后导出ONNX模型,样例如下:
import torch
import torch_npu
import torch.onnx
import torchvision.models as models
# 设置使用CPU导出模型
device = torch.device("cpu")
def convert():
# 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型
model = models.resnet50(pretrained = False)
resnet50_model = torch.load('resnet50.pth', map_location='cpu') #根据实际文件名称修改
model.load_state_dict(resnet50_model)
batch_size = 1 #批处理大小
input_shape = (3, 224, 224) #输入数据,改成自己的输入shape
# 模型设置为推理模式
model.eval()
dummy_input = torch.randn(batch_size, *input_shape) # 定义输入shape
torch.onnx.export(model,
dummy_input,
"resnet50_official.onnx",
input_names = ["input"], # 构造输入名
output_names = ["output"], # 构造输出名
opset_version=11, # ATC工具目前支持opset_version=9,10,11,12,13
dynamic_axes={"input":{0:"batch_size"}, "output":{0:"batch_size"}}) #支持输出动态轴
if __name__ == "__main__":
convert()
- 在导出ONNX模型之前,必须调用model.eval()来将dropout和batch normalization层设置为推理模式。
- 样例脚本中的model来自于torchvision模块中的定义,用户使用自己的模型时需自行指定。
- 构造输入输出需要对应训练时的输入输出,否则无法正常推理。
.pth.tar文件导出ONNX模型
.pth.tar在导出ONNX模型时需要先确定保存时的信息,有时保存的节点名称和模型定义中的节点会有差异,例如会多出前缀和后缀。在进行转换的时候,可以对节点名称进行修改。转换代码样例如下:
from collections import OrderedDict
import torch
import torch_npu
import torch.onnx
import torchvision.models as models
# 如果发现pth.tar文件保存时节点名加了前缀或后缀,则通过遍历删除。此处以遍历删除前缀"module."为例。若无前缀后缀则不影响。
def proc_nodes_module(checkpoint, AttrName):
new_state_dict = OrderedDict()
for key, value in checkpoint[AttrName].items():
if key == "module.features.0.0.weight":
print(value)
#根据实际前缀后缀情况修改
if(key[0:7] == "module."):
name = key[7:]
else:
name = key[0:]
new_state_dict[name] = value
return new_state_dict
def convert():
# 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型
checkpoint = torch.load("./resnet50.pth.tar", map_location=torch.device('cpu')) #根据实际文件名称修改
checkpoint['state_dict'] = proc_nodes_module(checkpoint,'state_dict')
model = models.resnet50(pretrained = False)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
input_names = ["actual_input_1"]
output_names = ["output1"]
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet50.onnx", input_names = input_names, output_names = output_names, opset_version=11) #输出文件名根据实际情况修改
if __name__ == "__main__":
convert()
父主题: 模型保存与导出