溢出检测定位问题案例

问题现象

在昇腾AI处理器上进行BERT-CRF模型训练,发现Loss收敛异常。

总体思路

在PyTorch模型迁移与训练过程中出现的精度问题,可以从以下两方面依次进行定位。

  1. 若用户在训练过程中怀疑网络存在溢出问题,用户可以使用溢出检测。检测到有溢出结果之后,可以参考以下步骤规避或解决。
  2. 部分溢出属于过程溢出,可能不影响最终精度。因此检测到发生溢出的API后,可以通过dump比对验证该溢出是否导致了精度问题。这时可以用工具的list模式在NPU上dump溢出API的输入与输出,与标杆侧(CPU/GPU/NPU)做比对,观察溢出是否导致了精度问题。如果比对结果显示没有精度问题,可以考虑不进行下一步。
  3. 对于造成了精度问题的溢出,可将该API的计算用表示范围更大的数据类型来规避溢出,或联系华为工程师求助,可进入昇腾开源社区使用issue进行沟通。

环境准备

请参考ptdbg_ascend工具使用说明安装ptdbg_ascend精度工具。BERT-CRF模型训练工程可点击获取链接-ModelZoo,下载BERT-CRF模型训练脚本并准备训练环境和训练数据集。

溢出检测

  1. 进入模型脚本所在目录并打开。

    cd examples/sequence_labeling
    vi task_sequence_labeling_ner_crf.py

  2. 在训练脚本中导入精度工具包,使能精度工具溢出定位。

    from ptdbg_ascend import register_hook, overflow_check, seed_all, set_dump_path, set_dump_switch, acc_cmp_dump

  3. 在模型定义后,训练循环开始前,添加溢出检测函数,其中overflow_check为使能溢出检测开关,overflow_nums为检测到溢出抛出异常的阈值次数,超过这个阈值就会退出训练。

    model = Model().to(device) 
    
    print(model)
    
    if 'npu' in device:
        optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=args.lr)
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, \
             loss_scale=256, combine_grad=True, combine_ddp=True if distributed else False)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    
    set_dump_path("./data/dump", dump_tag='npu_overflow')    # 设置dump路径,最终数据保存在此路径下
    
    register_hook(model, overflow_check, overflow_nums=1)    # 使能溢出检测
    
    updates_total = len(train_dataloader) * args.train_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, \
         num_warmup_steps=warm_factor*updates_total, num_training_steps=updates_total)

  4. 参考ModelZoo BERT-CRF模型获取界面的说明,拉起训练。样例代码如下:

    bash ./test/train_full_1p.sh --data_path=$data_path    # data_path请根据实际情况设置

  5. 检测到溢出,抛出异常并结束训练。样例异常信息如下图。

    图1 溢出异常

  6. 在预设或默认路径中找到包含溢出信息的pkl文件。

    图2 包含溢出信息的pkl文件
    分析.pkl日志文件。日志文件中只会记录发生溢出的API。发现溢出发生在第145次调用线性函数的反向传播output,由此可确定出现问题的API为Functional_linear_145_backward,接下来可以用工具的list模式,dump比对NPU侧和标杆(GPU/NPU/CPU)的数据,确定溢出是否导致了精度问题。list模式和dump比对可以参考dump模式说明数据dump比对场景。如果确定有精度问题,可将该API的计算用表示范围更大的数据类型来规避溢出,或联系华为工程师求助,可进入昇腾开源社区使用issue进行沟通。
    图3 日志信息