PyTorch在训练过程中,通常使用torch.save()来保存Checkpoint文件,根据模型文件的后续用途会保存为两种格式的模型文件。
通过state_dict来保存和加载模型。
# 创建保存路径 PATH = "state_dict_model.pt" # 保存模型 torch.save(model.state_dict(), PATH)
# 模型文件保存路径 PATH = "state_dict_model.pt" model = TheModelClass(*args, **kwargs) #根据实际模型定义填写函数和参数 # 加载模型 model.load_state_dict(torch.load(PATH)) model.eval()
保存.pth或.pt文件扩展名的文件时要提供模型定义文件,否则无法保存。
PATH = "checkpoint.pth.tar" torch.save({ 'epoch': epoch, 'loss': loss, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict(), ... }, PATH)
model = TheModelClass(*args, **kwargs) #根据实际模型定义填写函数和参数 optimizer = TheOptimizerClass(*args, **kwargs) #根据实际优化器填写函数和参数 checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train()
通常情况下,训练图和推理图中对同一个算子处理方式不同(例如BatchNorm和dropout等算子),在输入格式上也有差别。因此在运行推理或导出ONNX模型之前,必须调用model.eval() 来将dropout和batch normalization层设置为推理模式。