PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,根据模型文件的后续用途会保存为两种格式的模型文件(pth文件和pth.tar文件),以便用于在线推理。
通过state_dict来保存和加载模型。
1 2 3 4 |
# 创建保存路径 save_pt_path = "state_dict_model.pt" # 保存模型 torch.save(model.state_dict(), save_pt_path) # model为前面训练定义的模型变量 |
1 2 3 4 5 6 7 8 9 |
# 模型文件保存路径 save_pt_path = "state_dict_model.pt" model = TheModelClass(*args, **kwargs) #根据实际模型定义填写函数和参数 # 以模型脚本与启动脚本配置中介绍的简单模型为例: model = ToyModel() # 加载模型 model.load_state_dict(torch.load(save_pt_path)) model.eval() |
保存为后缀是.pth/.pt的文件时,需要提供模型定义文件,否则后续模型无法部署。
1 2 3 4 5 6 7 8 |
checkpoint_path = "checkpoint.pth.tar" torch.save({ 'epoch': epoch, 'loss': loss, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict(), ... }, checkpoint_path) |
1 2 3 4 5 6 7 8 9 10 11 12 |
model = TheModelClass(*args, **kwargs) #根据实际模型定义填写函数和参数 optimizer = TheOptimizerClass(*args, **kwargs) #根据实际优化器填写函数和参数 checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train() |
通常情况下,训练图和推理图中对同一个算子处理方式不同(例如BatchNorm和dropout等算子),在输入格式上也有差别。因此在运行推理或导出ONNX模型之前,必须调用model.eval()将dropout和batch normalization层设置为推理模式。