平台适配
重调度模式
Volcano中重调度代码示例如下:
// PreStartAction performs the rescheduling preparation work at the start of a
// scheduling session: it rebuilds the fault caches, records fault nodes and
// fault jobs, restarts fault jobs, and saves the node/rank mapping used by the
// later allocation process.
//
// The receiver, the session and the kube client must all be non-nil;
// otherwise an ArgumentError is returned.
func (tp *module) PreStartAction(ssn *framework.Session) error {
	moduleFullName := util.NPUCardName + util.ModuleAcceleratorType
	klog.V(util.LogInfoLev).Infof("Entering PreStartAction of %s...", moduleFullName)
	defer klog.V(util.LogInfoLev).Infof("Leaving PreStartAction of %s", moduleFullName)
	if tp == nil || ssn == nil || tp.FrameAttr.KubeClient == nil {
		return fmt.Errorf("%s handler not enabled or ssn is nil: %s", moduleFullName, util.ArgumentError)
	}
	tp.reHandle = rescheduling.New(&tp.ScheduleEnv, rescheduling.CmFaultJobKind)
	if tp.reHandle == nil {
		klog.V(util.LogErrorLev).Infof("create new fault handler failed.")
		return fmt.Errorf("%s reSchedule not enabled: %s", moduleFullName, util.ArgumentError)
	}
	tp.reHandle.NewReScheduler()
	// Synchronize the cached fault nodes/jobs and rank occupancy with the
	// current session before making any rescheduling decisions.
	tp.reHandle.SynCacheFaultNodeWithSession(util.NPUCardName)
	tp.reHandle.AddFaultNodeWithSession(util.NPUCardName)
	tp.reHandle.SynCacheFaultJobWithSession(ssn, util.NPUCardName, util.NPUCardNamePre)
	tp.reHandle.SynCacheNodeRankOccMapWithSession(ssn)
	// 1. Restart fault jobs that are recorded in cache; a failure here is
	// logged but does not abort the pre-start action.
	if restartErr := tp.reHandle.RestartNeedForceDeleteJobs(ssn); restartErr != nil {
		klog.V(util.LogErrorLev).Infof("%s RestartNeedForceDeleteJobs: %s", moduleFullName, restartErr.Error())
	}
	// 2. Get all the running jobs in the session.
	runningJobs, getRunErr := tp.reHandle.GetRunningJobs(ssn, util.NPUCardName, util.ModuleAcceleratorType)
	if getRunErr != nil {
		klog.V(util.LogInfoLev).Infof("%s GetRunningJobs: %s", moduleFullName, getRunErr.Error())
	}
	// 3. Record fault jobs found among the running jobs.
	if err := tp.reHandle.AddFaultJobWithSession(runningJobs, util.NPUCardName, util.NPUCardNamePre); err != nil {
		// Include the underlying error so the failure cause is not lost.
		klog.V(util.LogErrorLev).Infof("%s AddFaultJobWithSession: %s", moduleFullName, err.Error())
	}
	// 4. Restart the fault jobs; this failure does abort the pre-start action.
	if restartErr := tp.reHandle.RestartFaultJobs(ssn); restartErr != nil {
		klog.V(util.LogErrorLev).Infof("%s RestartFaultJobs: %s", moduleFullName, restartErr.Error())
		return restartErr
	}
	// 5. Save the structure for the later allocation process.
	tp.reHandle.GenerateNodeRankIndexTaskMap()
	return nil
}
故障包括了内部的节点故障、芯片故障、参数面网络故障、业务面故障,将作为对外的信息放在K8s的ConfigMap中,以供外部查询和使用。查询命令为kubectl describe cm -n volcano-system vcjob-fault-npu-cm,命令回显的参数说明见表1。
| 参数名 | 描述 |
|---|---|
| fault-node | 节点维度的故障信息 |
| NodeName | 节点名称 |
| FaultDeviceList | 故障列表 |
| - fault_type | 故障类型对象,对象包含fault_type、npu_name、large_model_fault_level、fault_level、fault_handling和fault_code等6个字段 |
| - npu_name | 故障的芯片名称,节点故障时为空 |
| - large_model_fault_level | 故障处理类型,节点故障时取值为空。说明:large_model_fault_level、fault_level和fault_handling参数功能一致,推荐使用fault_handling。 |
| - fault_level | 同large_model_fault_level |
| - fault_handling | 同large_model_fault_level |
| - fault_code | 故障码,由英文逗号拼接而成的字符串 |
| FaultTasks | 任务维度的故障信息列表,包含Reason字段 |
| - Reason | 故障原因,字段就是故障列表下的五个字段组成的字符串 |
- 故障类型为NodeUnhealthy,即节点故障时,会直接重新执行训练或推理任务,并隔离故障节点,将任务重新调度到其他节点执行任务,如果没有冗余的节点,则任务会持续处于Pending状态。
- 故障类型为CardUnhealthy和CardNetworkUnhealthy,即芯片故障和参数面网络故障时,会根据具体故障类型决定是否重新执行任务、隔离芯片、复位芯片等,具体参考“fault_handling”字段的执行策略。
- 故障类型为业务面故障时,Volcano会检测是否开启无条件重试功能,开启后会重新调度到本节点并重新执行训练或推理任务,重试次数减1;当重试次数为0或者没有开启无条件重试功能时,不会对业务容器故障进行处理。
优雅容错模式
优雅容错模式基于故障重调度模式,集成优雅容错模式前请先完成故障重调度模式集成。集群调度组件提供管理进程样例,用户可直接在训练脚本中执行该样例启动管理进程,以下为执行参考样例。
训练脚本启动管理进程reset_process示例:
# Single-node multi-card and distributed training launcher.
# Usage: main.sh device_count server_count rank_table_file server_id dataset config_file_path
# Starts one background training process per local device, then runs the
# fault-management process (reset_process.py) and waits for everything.
if [ $# == 6 ]; then
# Export the cluster topology parameters used by the training framework.
export DEVICE_NUM=$1
export SERVER_NUM=$2
export RANK_SIZE=$1
export RANK_TABLE_FILE=$3
export SERVER_ID=$4
# Number of devices on each server, and the first global rank on this server.
device_each_server=$((DEVICE_NUM / SERVER_NUM))
rank_start=$((${device_each_server} * SERVER_ID))
DATA_PATH=$5
CONFIG_PATH=$6
# Start the background tasks first and keep one foreground task at the end
# to watch the log output.
for((i=$((${device_each_server}-1)); i>=0; i--))
do
rankid=$((rank_start + i))
export DEVICE_ID=${i}
export RANK_ID=${rankid}
# Prepare a clean per-rank working directory with the scripts and sources.
rm -rf ${ROOT_PATH}/train_parallel${rankid}
mkdir ${ROOT_PATH}/train_parallel${rankid}
cp ${ROOT_PATH}/../*.py ${ROOT_PATH}/train_parallel${rankid}
cp ${ROOT_PATH}/*.sh ${ROOT_PATH}/train_parallel${rankid}
cp -r ${ROOT_PATH}/../src ${ROOT_PATH}/train_parallel${rankid}
cd ${ROOT_PATH}/train_parallel${rankid} || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
# Launch the training process in the background and record its PID so the
# management process can monitor/restart it.
python ${ROOT_PATH}/../train.py --run_distribute=True --device_num=${RANK_SIZE} --data_path=${DATA_PATH} --config_path=${CONFIG_PATH} &> log &
train_pids[$i]=$!
done
else
echo "Invalid input parameter, usage: main.sh device_count server_count rank_table_file server_id dataset config_file_path" | tee -a log
exit 1
fi
# Run the management process with the training PIDs, then wait for all children.
python -u ${ROOT_PATH}/reset_process.py -p "${train_pids[@]}"
wait
Reset_Process基于Python语言编写,其中关键能力模块示例如下,用户可参考关键模块能力实现该管理进程能力。
# Reset_Process中关键模块示例如下,用户可参考不同模块实现自己的管理进程
# Obtain the RANKs of faulty cards and of recovered cards.
def get_fault_ranks(self):
    """Read the reset ConfigMap and return the list of unrecovered (fault) ranks.

    A non-empty result is also cached on ``self.fault_rank_list``.
    """
    ranks = self._get_ranks_from_cm(self.reset_cm_path, "unrecovered")
    if ranks:
        self.fault_rank_list = ranks
    return ranks
def get_recover_ranks(self):
    """Read the reset ConfigMap and return the list of recovered ranks.

    A non-empty result is also cached on ``self.recover_rank_list``.
    """
    ranks = self._get_ranks_from_cm(self.reset_cm_path, "recovered")
    if ranks:
        self.recover_rank_list = ranks
    return ranks
# Stop training processes.
def _kill_abnormal_process(self, abnormal_rank_list: list):
    """Kill the training processes of the abnormal ranks, at most once.

    Idempotent via ``self.killed_abnormal``; on any unexpected error the
    recovery process is terminated through ``self.exit_recover_process()``.
    """
    if not self.killed_abnormal:
        try:
            logger.info(f"to kill abnormal rank {abnormal_rank_list}")
            self._process_manager.kill_fault_process(abnormal_rank_list)
            self.killed_abnormal = True
        except Exception as e:
            logger.error(f"an unexpected error {e} occur when kill abnormal process")
            self.exit_recover_process()
def _kill_normal_process(self, normal_rank_list: list):
    """Kill the training processes of the normal ranks, at most once.

    Idempotent via ``self.killed_normal``; on any unexpected error the
    recovery process is terminated through ``self.exit_recover_process()``.
    """
    if not self.killed_normal:
        try:
            logger.info(f"to kill normal rank {normal_rank_list}")
            self._process_manager.kill_fault_process(normal_rank_list)
            self.killed_normal = True
        except Exception as e:
            logger.error(f"an unexpected error {e} occur when kill normal process")
            self.exit_recover_process()
# Restart training processes.
def restore_train_process(self):
    """
    Recover all target processes in this node.

    Only acts when every monitored process has stopped and no restart has
    happened yet; returns a manager-backed list holding the new PIDs
    (empty when no restart was performed).
    """
    new_pid_list = multiprocessing.Manager().list()
    if not self.all_stopped() or self._restart:
        return new_pid_list
    workers = []
    # Relaunch every rank with its recorded command, working directory and
    # environment; each launch runs in its own helper process.
    for rank, command in self._cmd_dict.items():
        env_info = self._env_dict[rank]
        worker = multiprocessing.Process(
            target=_run_recover,
            args=(command, env_info['PWD'], env_info, new_pid_list))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
    self._restart = True
    logger.info(f"new pids are:{new_pid_list}")
    return new_pid_list
父主题: 故障处理