mindx_elastic.torch_exception_module.exception_checkpoint
功能说明
是mindx_elastic.torch_exception_module.ExceptionHandler的实例化对象,实例化过程中会完成参数初始化及临终信号函数注册。其中ExceptionHandler的成员属性说明如下:
- _save_func (函数对象) - 指定模型保存的函数,初始为"None"。
- _args (tuple) - 指定模型保存时的相关参数,初始为"()"。
- _group_info (List) - 指定混合并行模式下模型切分信息,初始为"None"。
- _partial_save (bool) - 指定是否开启部分保存,初始为"False"。
- _replicas (int) - 指定部分保存的副本数量,初始为"1"。
使用示例
from mindx_elastic.torch_exception_module import exception_checkpoint ... def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False, pipeline_parallel=None, tensor_rank=None, pipeline_rank=None): """Determine the directory name for this rank's checkpoint.""" args = get_args() ... if use_distributed_optimizer: model_name = os.path.join(common_path, "model_rng.pt") optim_name = os.path.join( common_path + "_%03d" % mpu.get_data_parallel_rank(), "optim.pt") else: fault_npus = exception_checkpoint._get_fault_npus() data_paral_size = len(exception_checkpoint._group_info) if str(args.rank) in fault_npus: for i in exception_checkpoint._group_info[int(args.rank)%data_paral_size]: if str(i) not in fault_npus : model_name = optim_name = os.path.join(common_path, f"model_optim_rng_{i}.pt") break else: model_name = optim_name = os.path.join(common_path, f"model_optim_rng_{args.rank}.pt") return model_name, optim_name ... iteration, release = read_metadata(tracker_filename) if not exception_checkpoint.check_resumable_model(): iteration = iteration // args.save_interval * args.save_interval if rank0: ...
set_config(save_func,args,partial_save,group_info,replicas)
设置内部参数。
参数 |
类型 |
说明 |
---|---|---|
save_func |
函数对象 |
保存ckpt的方法名。 |
args |
tuple |
save_func的参数,需要根据实际情况填写。 |
partial_save |
bool |
是否保存部分ckpt,默认为False。 |
group_info |
List |
混合并行模式下的权重分割策略,数据并行模式下不需要传递该参数。 |
replicas |
int |
保存ckpt的分数,开启partial_save的情况下才生效,默认值为1。 |
check_resumable_model()
检查能否完整恢复模型,返回值类型为bool。
使用示例如下:
from mindx_elastic.torch_exception_module import exception_checkpoint if not exception_checkpoint.check_resumable_model(): #回退iteration
_get_fault_npus()
获取本节点上故障芯片的rank id,返回值类型为List。
使用示例如下:
from mindx_elastic.torch_exception_module import exception_checkpoint fault_npus = exception_checkpoint._get_fault_npus()