昇腾社区首页
中文
注册

MindSpore训练场景

前提条件

完成训练前配置检查

执行采集

  1. 创建配置文件。
    以在训练脚本所在目录创建config.json配置文件为例,文件内容拷贝如下示例配置。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    {
        "task": "tensor",
        "dump_path": "./dump_data",
        "rank": [],
        "step": [],
        "level": "L1",
    
        "tensor": {
            "scope": [], 
            "list": [],
            "data_mode": ["all"]
        }
    }
    
  2. 在训练脚本(mindspore_main.py文件)中添加工具,如下所示。

    完整代码请参见MindSpore精度数据采集代码样例

    添加以下接口前,建议注释前面步骤的非训练脚本接口。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    from msprobe.mindspore import PrecisionDebugger
    debugger = PrecisionDebugger(config_path="./config.json")
    ...
    if __name__ == "__main__":
        step = 0
        # Train Model
        for data, label in ds.GeneratorDataset(generator_net(), ["data", "label"]):
            debugger.start(model)
            train_step(data, label)
            print(f"train step {step}")
            step += 1
            debugger.stop()
            debugger.step()
        print("train finish")
    

    精度数据会占据一定的磁盘空间,可能存在磁盘写满导致服务器不可用的风险。精度数据所需空间跟模型的参数、采集开关配置、采集的迭代数量有较大关系,须用户自行保证落盘目录下的可用磁盘空间。

  3. 执行训练脚本命令,工具会采集模型训练过程中的精度数据。
    python mindspore_main.py

    日志打印出现如下示例信息表示数据采集成功,完成采集后即可查看数据。

    1
    2
    3
    4
    5
    The cell hook function is successfully mounted to the model.
    The module statistics hook function is successfully mounted to the model.
    msprobe: debugger.start() is set successfully
    Dump switch is turned on at step 0.
    Dump data will be saved in /home/user1/dump/dump_data/step0.
    

结果查看

dump_path参数指定的路径下会出现如下目录结构,可以根据需求选择合适的数据进行分析。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
dump_data/
├── step0
    └── rank
        ├── construct.json           # 保存Module的层级关系信息,当前场景为空
        ├── dump.json                # 保存前反向API的输入输出的统计量信息和溢出信息等
        ├── dump_tensor_data         # 保存前反向API的输入输出tensor的真实数据信息等
           ├── Jit.Momentum.0.forward.input.1.0.npy
           ├── Primitive.matmul.MatMul.1.forward.input.1.npy
           ├── Mint.add.1.backward.input.0.npy
           ├── Primitive.matmul.MatMul.1.forward.output.0.npy
        ...
        └── stack.json               # 保存API的调用栈信息
├── step1
...

采集后的数据需要用精度预检精度比对等工具进行进一步分析。