下载
中文
注册
使用MMEngine进行断点续训时报错

使用MMEngine进行断点续训时报错

2025/03/18

20

暂无评分
我要评分

问题信息

问题来源产品大类产品子类关键字
官方模型训练PyTorch--

问题现象描述

当PyTorch版本为2.1.0,在NPU上执行多卡训练,通过MMEngine进行断点续训时报如下错误:

RuntimeError: Attempted to set the storage of a tensor on device "npu:X" to a storage on different device "npu:0"

原因分析

PyTorch2.1.0多卡训练断点续训加载权重时,处理自定义设备会默认将权重都放到0卡上。

解决措施

  1. 根据报错堆栈找到MMEngine中加载预训练权重的代码,如下所示:
    checkpoint = self.load_checkpoint(filename, map_location=device)
  2. 1中代码修改为如下代码。
    import os
    device_id = os.environ['LOCAL_RANK']
    device = get_device()
    checkpoint = self.load_checkpoint(filename, map_location=f"{device}:{device_id}")

本页内容