使用MMEngine进行断点续训时报错
2025/03/18
20
问题信息
问题来源 | 产品大类 | 产品子类 | 关键字 |
---|---|---|---|
官方 | 模型训练 | PyTorch | -- |
问题现象描述
当PyTorch版本为2.1.0,在NPU上执行多卡训练,通过MMEngine进行断点续训时报如下错误:
RuntimeError: Attempted to set the storage of a tensor on device "npu:X" to a storage on different device "npu:0"
原因分析
PyTorch2.1.0多卡训练断点续训加载权重时,处理自定义设备会默认将权重都放到0卡上。
解决措施
- 根据报错堆栈找到MMEngine中加载预训练权重的代码,如下所示:
checkpoint = self.load_checkpoint(filename, map_location=device)
- 将1中代码修改为如下代码。
import os device_id = os.environ['LOCAL_RANK'] device = get_device() checkpoint = self.load_checkpoint(filename, map_location=f"{device}:{device_id}")