昇腾AI处理器PyTorch模型的部署策略是基于PyTorch官方支持的ONNX模块实现的。ONNX是业内目前比较主流的模型格式,广泛用于模型交流及部署。本节主要介绍如何将Checkpoint文件通过torch.onnx.export()接口导出为ONNX模型。
保存的.pth或.pt文件可以通过PyTorch构建模型再加载权重的方法恢复,然后导出ONNX模型,样例如下:
import torch import torch_npu import torch.onnx import torchvision.models as models # 设置使用CPU导出模型 device = torch.device("cpu") def convert(): # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型 model = models.resnet50(pretrained = False) resnet50_model = torch.load('resnet50.pth', map_location='cpu') #根据实际文件名称修改 model.load_state_dict(resnet50_model) batch_size = 1 #批处理大小 input_shape = (3, 224, 224) #输入数据,改成自己的输入shape # 模型设置为推理模式 model.eval() dummy_input = torch.randn(batch_size, *input_shape) # 定义输入shape torch.onnx.export(model, dummy_input, "resnet50_official.onnx", input_names = ["input"], # 构造输入名 output_names = ["output"], # 构造输出名 opset_version=11, # ATC工具目前支持opset_version=9,10,11,12,13 dynamic_axes={"input":{0:"batch_size"}, "output":{0:"batch_size"}}) #支持输出动态轴 if __name__ == "__main__": convert()
.pth.tar在导出ONNX模型时需要先确定保存时的信息,有时保存的节点名称和模型定义中的节点会有差异,例如会多出前缀和后缀。在进行转换的时候,可以对节点名称进行修改。转换代码样例如下:
from collections import OrderedDict import torch import torch_npu import torch.onnx import torchvision.models as models # 如果发现pth.tar文件保存时节点名加了前缀或后缀,则通过遍历删除。此处以遍历删除前缀"module."为例。若无前缀后缀则不影响。 def proc_nodes_module(checkpoint, AttrName): new_state_dict = OrderedDict() for key, value in checkpoint[AttrName].items(): if key == "module.features.0.0.weight": print(value) #根据实际前缀后缀情况修改 if(key[0:7] == "module."): name = key[7:] else: name = key[0:] new_state_dict[name] = value return new_state_dict def convert(): # 模型定义来自于torchvision,样例生成的模型文件是基于resnet50模型 checkpoint = torch.load("./resnet50.pth.tar", map_location=torch.device('cpu')) #根据实际文件名称修改 checkpoint['state_dict'] = proc_nodes_module(checkpoint,'state_dict') model = models.resnet50(pretrained = False) model.load_state_dict(checkpoint['state_dict']) model.eval() input_names = ["actual_input_1"] output_names = ["output1"] dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export(model, dummy_input, "resnet50.onnx", input_names = input_names, output_names = output_names, opset_version=11) #输出文件名根据实际情况修改 if __name__ == "__main__": convert()