保存模型

PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,根据模型文件的后续用途会保存为两种格式的模型文件(pth文件和pth.tar文件),以便用于在线推理。

  • 在昇腾PyTorch1.11.0版本中,NPU模型在使用torch.save()进行存储的时候会保存NPU特有的设备信息和数据格式,以便于更好的支持断点训练,这使得保存的pth、pt和pth.tar扩展名文件存在跨平台兼容性问题。为了支持NPU训练出的模型权重或模型可以跨平台使用,需要在模型存储前将模型或tensor放在CPU上进行存储,示例如下:
    # 将模型放置在cpu上 
    model = model.cpu()  
  • PyTorch2.1.0及以后版本已支持跨设备读取权重,不需要模型或tensor放在CPU上进行存储。