昇腾社区首页
中文
注册

导出ONNX模型

简介

PyTorch模型在昇腾AI处理器上的部署策略是基于PyTorch官方支持的ONNX模块实现的。ONNX是业内目前比较主流的模型格式,广泛用于模型交流及部署。本节主要介绍如何将Checkpoint文件通过torch.onnx.export()接口导出为ONNX模型。

.pth.tar文件导出ONNX模型

.pth.tar在导出ONNX模型时需要先确定保存时的信息,有时保存的节点名称和模型定义中的节点会有差异,例如会多出前缀和后缀。在进行转换的时候,可以对节点名称进行修改。

  1. 安装onnx。
    1
    pip3 install onnx
    
  2. 创建转换脚本,例如命名为“pth2onnx.py”,和ResNet50网络训练脚本main.py放置在同一目录下。

    转换代码样例脚本如下:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    from collections import OrderedDict
    import torch
    import torch_npu
    import torch.onnx
    import torchvision.models as models
    
    # 如果.pth.tar文件保存时节点名加了前缀或后缀,则通过遍历删除。此处以遍历删除前缀"module."为例。若无前缀后缀则不影响。
    def proc_nodes_module(checkpoint, AttrName):
        new_state_dict = OrderedDict()
        for key, value in checkpoint[AttrName].items():
            if key == "module.features.0.0.weight":
                print(value)
            # 根据实际前缀后缀情况修改
            if(key[0:7] == "module."):
                name = key[7:]
            else:
                name = key[0:]
    
            new_state_dict[name] = value
        return new_state_dict
    
    def convert():
        # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型
        checkpoint = torch.load("./checkpoint.pth.tar", map_location=torch.device('cpu'))    # 根据实际文件名称修改
        checkpoint['state_dict'] = proc_nodes_module(checkpoint,'state_dict')
        model = models.resnet50(pretrained = False)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        input_names = ["actual_input_1"]
        output_names = ["output1"]
        dummy_input = torch.randn(1, 3, 224, 224)
        torch.onnx.export(model, dummy_input, "resnet50.onnx", input_names = input_names, output_names = output_names, opset_version=11)    # 输出文件名根据实际情况修改
    
    if __name__ == "__main__":
        convert()
    
  3. 执行命令转换模型。
    1
    python3 pth2onnx.py
    

    命令执行成功后,会在当前目录下生成“resnet50.onnx”模型文件。