导出ONNX模型

简介

昇腾AI处理器PyTorch模型的部署策略是基于PyTorch官方支持的ONNX模块实现的。ONNX是业内目前比较主流的模型格式,广泛用于模型交流及部署。本节主要介绍如何将Checkpoint文件通过torch.onnx.export()接口导出为ONNX模型。

.pth或.pt文件导出ONNX模型

保存的.pth或.pt文件可以通过PyTorch构建模型再加载权重的方法恢复,然后导出ONNX模型,样例如下:

import torch
import torch_npu
import torch.onnx
import torchvision.models as models
# 设置使用CPU导出模型
device = torch.device("cpu") 

def convert():
    # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型
    model = models.resnet50(pretrained = False)  
    resnet50_model = torch.load('resnet50.pth', map_location='cpu')    #根据实际文件名称修改
    model.load_state_dict(resnet50_model) 

    batch_size = 1  #批处理大小
    input_shape = (3, 224, 224)   #输入数据,改成自己的输入shape

    # 模型设置为推理模式
    model.eval()

    dummy_input = torch.randn(batch_size, *input_shape) #  定义输入shape
    torch.onnx.export(model, 
                      dummy_input, 
                      "resnet50_official.onnx", 
                      input_names = ["input"],   # 构造输入名
                      output_names = ["output"],    # 构造输出名
                      opset_version=11,    # ATC工具目前支持opset_version=9,10,11,12,13
                      dynamic_axes={"input":{0:"batch_size"}, "output":{0:"batch_size"}})  #支持输出动态轴

if __name__ == "__main__":
    convert()
  • 在导出ONNX模型之前,必须调用model.eval() 来将dropout和batch normalization层设置为推理模式。
  • 样例脚本中的model来自于torchvision模块中的定义,用户使用自己的模型时需自行指定。
  • 构造输入输出需要对应训练时的输入输出,否则无法正常推理。

.pth.tar文件导出ONNX模型

.pth.tar在导出ONNX模型时需要先确定保存时的信息,有时保存的节点名称和模型定义中的节点会有差异,例如会多出前缀和后缀。在进行转换的时候,可以对节点名称进行修改。转换代码样例如下:

from collections import OrderedDict
import torch
import torch_npu
import torch.onnx
import torchvision.models as models

# 如果发现pth.tar文件保存时节点名加了前缀或后缀,则通过遍历删除。此处以遍历删除前缀"module."为例。若无前缀后缀则不影响。
def proc_nodes_module(checkpoint, AttrName):
    new_state_dict = OrderedDict()
    for key, value in checkpoint[AttrName].items():
        if key == "module.features.0.0.weight":
            print(value)
        #根据实际前缀后缀情况修改
        if(key[0:7] == "module."):
            name = key[7:]
        else:
            name = key[0:]

        new_state_dict[name] = value
    return new_state_dict

def convert():
    # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型
    checkpoint = torch.load("./resnet50.pth.tar", map_location=torch.device('cpu'))    #根据实际文件名称修改
    checkpoint['state_dict'] = proc_nodes_module(checkpoint,'state_dict')
    model = models.resnet50(pretrained = False)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    input_names = ["actual_input_1"]
    output_names = ["output1"]
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "resnet50.onnx", input_names = input_names, output_names = output_names, opset_version=11)    #输出文件名根据实际情况修改

if __name__ == "__main__":
    convert()

跨平台模型保存

PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,为了支持NPU训练出的模型权重或模型可以跨平台使用,需要在模型存储前将模型或tensor放在CPU上进行存储,示例如下:

# 将模型放置在cpu上
model = model.cpu()     
# 创建保存路径     
PATH = "state_dict_model.pt"     
# 保存模型     
torch.save(model.state_dict(), PATH)