ONNX Model Exporting
Overview
Deployment of PyTorch models on Ascend AI Processors is based on the ONNX export module that PyTorch officially supports. ONNX is a mainstream model format in the industry and is widely used for model sharing and deployment. This section describes how to export a checkpoint file as an ONNX model by using the torch.onnx.export() API.
Using the .pth.tar File to Export an ONNX Model
Before exporting an ONNX model from a .pth.tar file, check what information the file contains. The node (parameter) names saved in the checkpoint may differ from those in the model definition. For example, a prefix or suffix may have been added, such as the module. prefix that torch.nn.DataParallel adds when a wrapped model is saved. In that case, rename the keys during conversion so that they match the model definition.
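To check whether the saved names share a prefix, and to remove a prefix or suffix generically, you can apply a small pure-Python sketch like the one below to the dictionary returned by torch.load(). The helper names find_shared_prefix and strip_affixes are hypothetical. The example also assumes that the parameters are stored under a "state_dict" key, which is how the conversion script in this section reads them.

```python
from collections import OrderedDict


def find_shared_prefix(param_names):
    """Return the leading "xxx." component shared by every name, or "" if none.

    A prefix shared by *all* keys (for example "module." left by
    torch.nn.DataParallel) strongly suggests that the keys will not match
    the bare model definition. Treat the result as a hint only: a model with
    a single top-level submodule also yields a shared prefix.
    """
    names = list(param_names)
    if not names or not all("." in name for name in names):
        return ""
    heads = {name.split(".", 1)[0] + "." for name in names}
    return heads.pop() if len(heads) == 1 else ""


def strip_affixes(state_dict, prefix="", suffix=""):
    """Return a new OrderedDict with `prefix`/`suffix` removed from matching keys.

    Keys that do not carry the prefix or suffix are copied unchanged.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        name = key
        if prefix and name.startswith(prefix):
            name = name[len(prefix):]
        if suffix and name.endswith(suffix):
            name = name[: -len(suffix)]
        cleaned[name] = value
    return cleaned


if __name__ == "__main__":
    # Stand-in for torch.load("./checkpoint.pth.tar", map_location="cpu");
    # in a real checkpoint the values would be tensors, not integers.
    checkpoint = {
        "epoch": 90,
        "state_dict": {"module.conv1.weight": 1, "module.fc.bias": 2},
    }
    print("top-level keys:", list(checkpoint.keys()))
    prefix = find_shared_prefix(checkpoint["state_dict"])
    print("shared prefix:", repr(prefix))
    print("cleaned keys:", list(strip_affixes(checkpoint["state_dict"], prefix=prefix)))
```

Once the keys are cleaned, they can be passed to model.load_state_dict(). Because strip_affixes copies unmatched keys unchanged, it is safe to run on a checkpoint that has no prefix at all.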
- Install ONNX.
```
pip3 install onnx
```
- Create a conversion script, for example, pth2onnx.py, and place it in the same directory as the ResNet-50 network training script main.py.
A sample conversion script is as follows:
```python
from collections import OrderedDict

import torch
import torch_npu  # PyTorch adapter for Ascend AI Processors
import torch.onnx
import torchvision.models as models


# If a prefix or suffix was added to the node names when the .pth.tar file was saved,
# remove it by traversing the keys. This example removes the "module." prefix.
# If there is no prefix or suffix, skip this step.
def proc_nodes_module(checkpoint, AttrName):
    new_state_dict = OrderedDict()
    for key, value in checkpoint[AttrName].items():
        # Optional debug output: print one specific weight (adjust the key for your model).
        if key == "module.features.0.0.weight":
            print(value)
        # Make modifications based on the actual prefix or suffix.
        if key.startswith("module."):
            name = key[len("module."):]
        else:
            name = key
        new_state_dict[name] = value
    return new_state_dict


def convert():
    # The model definition comes from torchvision. The checkpoint in this example
    # was produced by training the ResNet-50 model.
    checkpoint = torch.load("./checkpoint.pth.tar", map_location=torch.device("cpu"))  # Replace the file name as required.
    checkpoint["state_dict"] = proc_nodes_module(checkpoint, "state_dict")
    # On torchvision >= 0.13, weights=None is the non-deprecated equivalent of pretrained=False.
    model = models.resnet50(pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # Switch BatchNorm and Dropout to inference behavior before exporting.
    input_names = ["actual_input_1"]
    output_names = ["output1"]
    # Dummy input in (batch, channels, height, width) layout; used only to trace the graph.
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy_input,
        "resnet50.onnx",  # Replace the output file name as required.
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
    )


if __name__ == "__main__":
    convert()
```
- Run the following command to convert the model:
```
python3 pth2onnx.py
```
After the command is executed successfully, the resnet50.onnx model file is generated in the current directory.