模块级精度数据dump

大模型场景下,通常不是简单的利用自动迁移能力实现GPU到NPU的训练脚本迁移,而是会对NPU网络进行一系列针对性的适配,因此,常常会造成迁移后的NPU模型存在部分子结构不能与GPU原始模型完全对应。模型结构不一致导致API调用类型及数量不一致,若直接按照API粒度进行精度数据dump和比对,则无法完全比对所有的API。

本节介绍的功能是对模型中的大粒度模块进行数据dump,使其比对时,对于无法以API粒度比对的模块可以直接以模块粒度进行比对。

模块指的是继承自nn.Module类模块,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump数据时以模块为粒度进行dump。

示例代码如下:

# 根据需要import包
import os
import torch
import torch.nn as nn
import torch_npu
import torch.nn.functional as F
from ptdbg_ascend import *

torch.npu.set_device("npu:0")
# 定义一个简单的网络
class ModuleOP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features=8, out_features=4)
        self.linear_2 = nn.Linear(in_features=4, out_features=2)
    def forward(self, x):
        x1 = self.linear_1(x)
        x2 = self.linear_2(x1)
        r1 = F.relu(x2)
        return r1

if __name__ == "__main__":
    module = ModuleOP()

    # 注册工具
    pdbg = PrecisionDebugger("./dump_data/npu", hook_name="dump")
    pdbg.start()

    x = torch.randn(10, 8)
    module_dump(module, "MyModuleOP")    # 开启模块级精度数据dump
    out = module(x)
    module_dump_end()    # 结束模块级精度数据dump
    loss = out.sum()
    loss.backward()
    pdbg.stop()