总体思路

精度问题排查需要找出具体引起问题的原因,主要根据以下方面进行精度问题排查:

  1. 模型网络计算错误。
    • 定位思路:在网络中加入hook进行排查判断是哪个地方有较大嫌疑,然后构建单算子样例构建逐渐缩小错误范围,证明该算子在当前网络场景下计算有误,可以对比CPU或GPU结果证明。
    • 规避方案:使用同等语义其他算子替代。
    • 解决方案:改进算子精度或功能问题。
  2. loss计算错误。
    • 定位思路:由于loss的特殊性和可以自定义,在判断loss计算错误后,建议dump网络中loss的输入来测试,而不是使用相同shape的随机tensor测试,这样才能更好地复现证明。
    • 规避方案:使用同等语义其他算子替代。
    • 解决方案:改进算子精度或功能问题(loss也是由算子构成)。
  3. 参数更新错误。
    • 定位思路:在每个optim.step()前对网络中的参数逐个打印其grad进行排查判断,然后构建单算子样例构建逐渐缩小错误范围,证明该算子在当前网络场景下梯度计算有误,可以对比CPU或GPU结果证明。该项优先级应低于12,因为上述两步骤的错误同样可以造成grad异常。
    • 规避方案:使用同等语义其他算子替代。
    • 解决方案:改进计算grad的算子精度或功能问题。
  4. 多卡计算错误。
    • 定位思路:在保证单卡精度OK的前提下,稳定复现多卡不收敛。
    • 解决方案:建议联系华为方支撑人员,提供稳定复现的单P和多P脚本。
  5. 其他要点。
    • 若不收敛或精度相差较大,着重在于算子精度问题,使用精度比对工具查看详细信息。
    • 若精度无法对齐,仅差少量几个点,可以调小loss_scale。
    • 若精度波动较大且收敛慢,可以少量调小学习率。
    • 若训练精度正常,eval精度为0或较低,则优先排查在eval时是否正确执行了model.eval(),DDP中broadcast_buffers是否已被置为False。
    • 若loss较大或部分loss异常,可以在loss函数中打点打印,确认异常的具体位置。