昇腾社区首页
中文
注册

mindx_elastic.torch_exception_module.exception_checkpoint

功能说明

是mindx_elastic.torch_exception_module.ExceptionHandler的实例化对象,实例化过程中会完成参数初始化及临终信号函数注册。其中ExceptionHandler的成员属性说明如下:

  • _save_func (函数对象) - 指定模型保存的函数,初始为"None"。
  • _args (tuple) - 指定模型保存时的相关参数,初始为"()"。
  • _group_info (List) - 指定混合并行模式下模型切分信息,初始为"None"。
  • _partial_save (bool) - 指定是否开启部分保存,初始为"False"。
  • _replicas (int) - 指定部分保存的副本数量,初始为"1"。

使用示例

from mindx_elastic.torch_exception_module import exception_checkpoint
...
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
"""Determine the directory name for this rank's checkpoint."""
    args = get_args()
    ...
    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
        optim_name = os.path.join(
            common_path + "_%03d" % mpu.get_data_parallel_rank(),
            "optim.pt")
    else:
        fault_npus = exception_checkpoint._get_fault_npus()
        data_paral_size = len(exception_checkpoint._group_info)
        if str(args.rank) in fault_npus:
            for i in exception_checkpoint._group_info[int(args.rank)%data_paral_size]:
                if str(i) not in fault_npus :
                    model_name = optim_name = os.path.join(common_path, f"model_optim_rng_{i}.pt")
                       break
        else:
            model_name = optim_name = os.path.join(common_path, f"model_optim_rng_{args.rank}.pt")
    return model_name, optim_name
...
    iteration, release = read_metadata(tracker_filename)
    if not exception_checkpoint.check_resumable_model():
        iteration = iteration // args.save_interval * args.save_interval
    if rank0:

...

set_config(save_func,args,partial_save,group_info,replicas)

设置内部参数。

表1 参数说明

参数

类型

说明

save_func

函数对象

保存ckpt的方法名。

args

tuple

save_func的参数,需要根据实际情况填写。

partial_save

bool

是否保存部分ckpt,默认为False。

group_info

List

混合并行模式下的权重分割策略,数据并行模式下不需要传递该参数。

replicas

int

保存ckpt的分数,开启partial_save的情况下才生效,默认值为1。

check_resumable_model()

检查能否完整恢复模型,返回值类型为bool。

使用示例如下:
from mindx_elastic.torch_exception_module import exception_checkpoint
if not exception_checkpoint.check_resumable_model():
  #回退iteration

_get_fault_npus()

获取本节点上故障芯片的rank id,返回值类型为List。

使用示例如下:
from mindx_elastic.torch_exception_module import exception_checkpoint
fault_npus = exception_checkpoint._get_fault_npus()