溢出检测定位问题案例

问题现象

在昇腾AI处理器上进行BERT-CRF模型训练,发现Loss收敛异常。

总体思路

在PyTorch模型迁移与训练过程中出现的精度问题,可以从以下两方面依次进行定位。

  1. 若用户在训练过程中怀疑网络存在溢出问题,用户可以使用溢出检测。检测到有溢出结果之后,可以参考以下步骤规避或解决。
  2. 部分溢出属于过程溢出,可能不影响最终精度。因此检测到发生溢出的API后,可以通过dump比对验证该溢出是否导致了精度问题。这时可以用工具的list模式在NPU上dump溢出API的输入与输出,与标杆侧(CPU/GPU/NPU)做比对,观察溢出是否导致了精度问题。如果比对结果显示没有精度问题,可以考虑不进行下一步。
  3. 对于造成了精度问题的溢出,可将该API的计算用表示范围更大的数据类型来规避溢出,或联系华为工程师求助,可进入昇腾开源社区使用issue进行沟通。

环境准备

请参考ptdbg_ascend工具使用说明安装ptdbg_ascend精度工具。BERT-CRF模型训练工程可点击获取链接-ModelZoo,下载BERT-CRF模型训练脚本并准备训练环境和训练数据集。

溢出检测

  1. 进入模型脚本所在目录并打开。

    cd examples/sequence_labeling
    vi task_sequence_labeling_ner_crf.py

  2. 在训练脚本中导入精度工具包,使能精度工具溢出定位;在模型定义后,训练循环开始前,添加溢出检测函数,其中overflow_check为使能溢出检测开关,overflow_nums为检测到溢出抛出异常的阈值次数,超过这个阈值就会退出训练。

    from ptdbg_ascend import PrecisionDebugger
    
    ...
    
    model = Model().to(device) 
    
    print(model)
    
    if 'npu' in device:
        optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=args.lr)
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, \
             loss_scale=256, combine_grad=True, combine_ddp=True if distributed else False)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    
    debugger = PrecisionDebugger(dump_path="./overflow", hook_name="overflow_check")
    
    # 模型初始化
    # 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop()
    debugger.start()
    
    # 需要dump的代码片段1
    
    debugger.stop()
    debugger.start()
    
    # 需要dump的代码片段2
    
    debugger.stop()
    debugger.step()

  3. 参考ModelZoo BERT-CRF模型获取界面的说明,拉起训练。样例代码如下:

    bash ./test/train_full_1p.sh --data_path=$data_path    # data_path请根据实际情况设置

  4. 检测到溢出,抛出异常并结束训练。样例异常信息如下图。

    图1 溢出异常

  5. 找到图1中含溢出信息的pkl文件进行分析。

    分析.pkl日志文件。日志文件中只会记录发生溢出的API。发现溢出发生在第145次调用线性函数的反向传播output,由此可确定出现问题的API为Functional_linear_145_backward,接下来可以用工具的list模式,dump比对NPU侧和标杆(GPU/NPU/CPU)的数据,确定溢出是否导致了精度问题。list模式和dump比对可以参考configure_hook的dump模式和数据dump比对场景。如果确定有精度问题,可将该API的计算用表示范围更大的数据类型来规避溢出,或联系华为工程师求助,可进入昇腾开源社区使用issue进行沟通。
    图2 日志信息