加载权重时遇到报错“load state_dict error.”
2023/06/06
174
问题信息
问题来源 | 产品大类 | 关键字 |
---|---|---|
官方 | 模型训练 | -- |
问题现象描述
- 报错截图
- 报错文本
…… RuntimeError: Error(s) in loading state_dict for ShuffleNetV2: Missing keys in state_dict: ……
原因分析
模型训练后保存的state_dict的key值与加载时state_dict的key值不一致,保存时会在每个key的最前面增加一个module前缀。
解决措施
加载权重时先遍历state_dict字典,修改key值,并使用新建的字典。具体用例参考以下内容:
ckpt = torch.load("checkpoint.pth", map_location=loc) # model.load_state_dict(ckpt['state_dict']) state_dict_old = ckpt['state_dict'] state_dict = {} for key, value in state_dict_old.items(): key = key[7:] state_dict[key] = value model.load_state_dict(state_dict)
本页内容