CheckPoint文件格式转换示例(Torch)

对于使用PyTorch框架的用户,在大模型训练结束后,CheckPoint文件需要用于推理。这里举例说明如何将MindIO ACP保存的CheckPoint文件转换成Torch原生格式的文件。运行示例代码前,请先按实际情况修改以下变量:

  • load_dir:替换为真实的CheckPoint保存目录。
  • new_dir:替换为CheckPoint转换后新保存的目录,建议为空目录。
  • iteration:指定要转换的迭代周期,该迭代周期下的所有CheckPoint文件都会被转换。其值会被格式化为"iter_"加7位数字(不足补零,例如iter_0002000)的子目录名,并与load_dir进行拼接。
#  Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import os
import torch_mindio


def main():
    """Convert all MindIO ACP checkpoint files of one iteration.

    Recursively walks ``<load_dir>/iter_<iteration>`` (the iteration number
    is zero-padded to 7 digits), mirrors its directory layout under
    ``new_dir`` and calls ``torch_mindio.convert`` on every file found,
    printing the source path, destination path and returned result.

    The three variables at the top of the function are placeholders that
    must be edited before running. Problems with them are reported on
    stdout and the function returns without converting anything.
    """
    load_dir = ""  # Replace with the actual checkpoint directory path
    new_dir = ""  # Replace with the actual new directory path
    iteration = 2000  # Replace with the actual iteration number

    directory = 'iter_{:07d}'.format(iteration)
    common_path = os.path.join(load_dir, directory)

    # Use isdir rather than exists: a regular file at this path would make
    # os.walk yield nothing, so the script would silently convert no files.
    if not os.path.isdir(common_path):
        print(f"Source directory {common_path} does not exist.")
        return

    # os.makedirs("") raises FileNotFoundError, so an unset placeholder is
    # reported the same way as a missing source directory.
    if not new_dir:
        print("Target directory new_dir is empty, replace it with the actual new directory path.")
        return

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(new_dir, exist_ok=True)

    for root, _, files in os.walk(common_path):
        # Mirror the current source sub-directory under the target directory.
        relative_path = os.path.relpath(root, common_path)
        target_dir = os.path.join(new_dir, relative_path)
        os.makedirs(target_dir, exist_ok=True)

        # Convert all files in the current directory.
        for file_name in files:
            src_file = os.path.join(root, file_name)
            dst_file = os.path.join(target_dir, file_name)
            # The result is only printed, not checked here.
            # NOTE(review): confirm its meaning against the torch_mindio docs.
            res = torch_mindio.convert(src_file, dst_file)
            print(f"Convert {src_file} to {dst_file}, result: {res}")


# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()