恢复时间优化(PyTorch)

本章节介绍在PyTorch框架上使用断点续训特性时,用户可以选择使用的缩短断点续训时间的相关功能,包括故障检测时间优化重调度时间优化集合通信初始化时间优化训练回滚及加载checkpoint时间优化算子编译时间优化

故障检测时间优化

由于集群中出现的参数面网络故障不一定会影响训练任务,因此集群调度组件不会强制中断任务;当参数面网络故障影响训练任务时,会触发集合通信的网络超时等待机制,在等待时间(通常默认为30分钟)后,集群调度组件才能感知到该故障,从而触发断点续训。针对该问题,PyTorch Adapter插件(torch_npu)提供watchdog故障检测功能,可用于检测训练任务是否受到影响,缩短故障检测时间,该功能的详细说明请参考表1

表1 watchdog故障检测功能说明

功能名称

watchdog故障检测。

功能特点

训练启动时,同时启动一个监测线程不断获取通信异常以及task执行异常。监测到故障发生后,快速抛出异常并终止训练任务进程,触发重调度。

使用说明

仅支持PyTorch 2.1.0及以上版本;PyTorch Adapter插件(torch_npu)版本必须高于6.0.RC1。

关键操作

PyTorch 2.1.0及以上版本默认开启watchdog故障检测,无需手动配置环境变量

(可选)如需关闭watchdog故障检测,需在训练的shell启动脚本(例如train_start.sh)中,修改以下环境变量。

...
# env for breakpoint ckpt
export RESUME_MODE_ENABLE=1

export HCCL_ASYNC_ERROR_HANDLING=0            # 取值为1,表示为开启watchdog功能;取值为0,表示关闭watchdog

重调度时间优化

断点续训故障处理的重调度模式默认使用Job级别重调度,需要在每次故障时销毁所有Pod,重新创建和调度全部Pod。销毁、创建和调度Pod的过程将浪费大量时间。针对该问题,集群调度组件提供Pod级别重调度功能,该功能的介绍请参考表2

表2 Pod级别重调度功能说明

功能名称

Pod级别重调度。

功能特点

每次故障只停止故障相关的Pod,重新创建并重调度故障相关的Pod后,重启训练任务。

使用说明

仅支持6.0.RC2及以上版本的集群调度组件。

关键操作

  1. 使用Dockerfile构建容器镜像,新增启动命令。
    ...
    # MindCluster无损失断点续训适配脚本
    RUN pip install $MINDX_ELASTIC_PKG
    RUN pip install $MINDIO_TTP_PKG
    
    # 可选,使用优雅容错、Pod级别重调度或进程级别重调度时必须配置以下命令
    RUN sed -i '/import logging/i import mindx_elastic.api' $(pip3 show torch | grep Location | awk -F ' ' '{print $2}')/torch/distributed/run.py
    
    ...
    说明:

    自6.0.0.SPC1版本起,删除Elastic Agent path相关说明;使用优雅容错、Pod级别重调度或进程级别重调度时新增配置命令。

  2. 在任务yaml中,新增以下字段,开启Pod级别重调度。
    ...
      labels:
        ...
        pod-rescheduling: "on"
    ...
  3. (可选)用户可以在启动训练的shell脚本(例如train_start.sh)中,新增max_restarts和monitor_interval参数,示例如下。
    ...
      logger "server id is: ""${server_id}"
      if [ "${framework}" == "PyTorch" ]; then
        get_env_for_pytorch_multi_node_job
        DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT --max_restarts 5 --monitor_interval 10 "
    ...

    参数说明:

    • max_restarts:配置容器内最大允许触发的故障次数,取值为整数。超出次数后PyTorch训练进程会直接退出训练,不配置该参数时默认为32767次。
    • monitor_interval:配置监测训练进程状态的时间间隔,单位为秒,取值为整数。不配置该参数时默认为5秒。

集合通信初始化时间优化

训练回滚及加载checkpoint时间优化

算子编译时间优化

断点续训过程中拉起训练需要重新执行算子时,算子编译需要消耗大量时间。针对该问题,可选择算子二进制或算子编译缓存降低编译时间,详细说明请参考下表。

算子二进制和算子编译缓存二者不兼容,请选择其中之一进行使用。

表9 算子二进制功能说明

功能名称

使用算子二进制。

功能特点

算子编译时提前加载预置的算子二进制,直接免编译执行算子。

使用说明

仅支持CANN 8.0.RC2及以上版本。

关键操作

Python启动脚本中,添加算子二进制配置命令,开启算子二进制。

torch.npu.set_compile_mode(jit_compile=False)
表10 算子编译缓存功能说明

功能名称

算子编译缓存。

功能特点

算子编译时加载存储上保存的算子编译缓存文件,加载后可降低编译时间。

使用说明

仅支持CANN 8.0.RC2及以上版本。

关键操作

  1. Python启动脚本中,添加算子二进制编译缓存配置命令,开启算子二进制编编译缓存。
    torch.npu.set_compile_mode(jit_compile=True)
  2. 在训练的shell启动脚本中(例如train_start.sh),添加如下环境变量。
    export ASCEND_CACHE_PATH=xxx   # 添加共享存储路径
    export ASCEND_MAX_OP_CACHE_SIZE=-1    # 使用共享存储时建议开启,可解决多节点读取共享存储缓存资源争抢严重问题