昇腾故障案例详情页

PyTorch断点续训精度溢出异常

更新时间: 2024/02/21

暂无评分

问题信息

问题来源产品大类产品子类关键字
官方模型训练PyTorch--

问题现象描述

MT线上集群多机训练时,偶现训练进程挂掉(各种原因),此时需要根据最近保存的ckpt继续训练,拉起后发现训练精度溢出,无法正常训练。

现象截图:

原因分析

NPU为了提高性能,内部包含许多种数据格式。训练时模型CKPT只会保留权重信息、格式,不会保存优化器(optimizer)的信息和格式。此时断点续训并载入优化器,会使用默认的数据格式,导致NPU无法识别,使得后续计算全部溢出。

解决措施

修改保存、载入ckpt的逻辑如下,即在保存ckpt时,额外保存一份data_format.pt用来存储数据格式。载入ckpt时,读取该.pt文件。

代码详情如下所示:

if 'optimizer_format' in checkpoint:
logger.info("Loading optimizer weights format from checkpoint...")
checkpoint['optimizer']['state'] = checkpoint_npu_format_cast(
checkpoint['optimizer']['state'], checkpoint['optimizer_format'])
elif os.path.exists(os.path.join(config.OUTPUT, 'optim_format.pt')):
logger.info(
"Loading optimizer weights format from optim_format.pt...")
ckpt_optimize_format = torch.load(
os.path.join(config.OUTPUT, 'optim_format.pt'))
checkpoint['optimizer']['state'] = checkpoint_npu_format_cast(
checkpoint['optimizer']['state'], ckpt_optimize_format)

本页内容

该页面对您有帮助吗?
我要评分