断点续训时间优化(PyTorch)
本章节介绍在PyTorch框架上使用断点续训特性时,用户可以选择使用的缩短断点续训时间的相关功能,包括缩短故障检测时间、重调度时间、集合通信初始化时间、训练回滚及加载checkpoint时间和算子编译时间。
故障检测时间优化
集群训练中出现参数面网络故障后,由于该故障不一定会影响训练任务,因此集群调度组件不会强制中断任务。当参数面网络故障影响训练任务时,由于集合通信的网络超时等待机制,将导致额外需要等待一段时间(通常默认为30分钟)后,集群调度组件才能感知到该故障从而触发断点续训。针对该问题,PyTorch Adapter插件(torch_npu)提供watchdog故障检测功能,该功能的详细说明请参考表1。
功能名称 |
watchdog故障检测。 |
|---|---|
功能特点 |
训练启动时,同时启动一个监测线程不断获取通信异常以及task执行异常。监测到故障发生后,快速抛出异常并终止训练任务进程,触发重调度。 |
使用说明 |
仅支持PyTorch 1.11.0、2.1.0及以上版本;PyTorch Adapter插件(torch_npu)版本必须高于6.0.RC1。 |
关键操作 |
在训练的shell启动脚本(例如train_start.sh)中,新增以下加粗的环境变量。PyTorch 2.1.0及以上版本默认开启watchdog故障检测,无需手动配置环境变量。 ... # env for breakpoint ckpt export RESUME_MODE_ENABLE=1 export HCCL_ASYNC_ERROR_HANDLING=1 # 开启watchdog功能,默认取值为0,表示不开启watchdog,取值为1表示开启watchdog |
重调度时间优化
重调度模式默认为Job级别重调度,每次故障时需要销毁所有Pod,然后重新创建并调度全部Pod。销毁、创建和调度Pod的过程将浪费大量时间。针对该问题,集群调度组件提供Pod级别重调度功能,该功能的介绍请参考表2。
功能名称 |
Pod级别重调度。 |
|---|---|
功能特点 |
每次故障只停止故障相关的Pod,重新创建并重调度故障相关的Pod后,重启训练任务。 |
使用说明 |
仅支持6.0.RC2及以上版本的集群调度组件。 |
关键操作 |
|
集合通信初始化时间优化
- PyTorch框架创建通信组时,使用TCP Store进行信息交换。任务规模变大后原生TCP Store的处理信息性能较差,导致创建通信组时间过长。针对该问题,PyTorch Adapter插件支持使用原生TCP Store的优化版本Parallel Store,详细说明请参考表3。
表3 Parallel Store功能说明 功能名称
Parallel Store。
功能特点
多线程处理建链请求,减少建链请求队列等待时间,降低总体建链时间。
使用说明
仅支持PyTorch 1.11.0版本;PyTorchPyTorch Adapter插件(torch_npu)版本必须高于6.0.RC1。
关键操作
将启动训练的shell脚本(例如train_start.sh)中,torchrun启动命令修改为torch_npu_run。
比如将
torchrun train.py --train_parameter=xxx ....
修改为
torch_npu_run train.py --train_parameter=xxx ....
- PyTorch框架NPU侧交换集合通信信息后,进行NPU卡间连接建链。当任务规模变大后,建链时间也大幅度增加。针对该问题,CANN对原生HCCL建链进行了性能优化,详细说明请参考表4。
训练回滚及加载checkpoint时间优化
- 训练时会每隔一段时间保存checkpoint文件(用于保存参数信息),每次保存checkpoint文件将浪费一定的训练时间,为了保证训练效率,保存checkpoint文件的间隔时间通常较大。每次故障时都需要从上一次保存的checkpoint回滚恢复训练,保存间隔越大每次故障时训练回滚浪费的时间就越长。针对该问题,集群调度组件支持MindIO ACP异步保存checkpoint,详细说明请参考表5。
- 训练回滚恢复时,通常需要从存储中加载之前保存的checkpoint,由于checkpoint数据量较大,直接从存储读取的耗时较长,导致加载checkpoint的时间较长。针对该问题,集群调度组件支持MindIO checkpoint高效恢复,详细说明请参考表6。
算子编译时间优化
功能名称 |
算子二进制。 |
|---|---|
功能特点 |
算子编译时提前加载预置的算子二进制,直接免编译执行算子。 |
使用说明 |
仅支持CANN 8.0.RC2及以上版本。 |
关键操作 |
在Python启动脚本中,添加算子二进制配置命令,开启算子二进制。 torch.npu.set_compile_mode(jit_compile=False) |
功能名称 |
算子编译缓存。 |
|---|---|
功能特点 |
算子编译时加载存储上保存的算子编译缓存文件,加载后可降低编译时间。 |
使用说明 |
仅支持CANN 8.0.RC2及以上版本。 |
关键操作 |
|