Loss Scale

概述

在混合精度计算中使用float16数据格式数据动态范围降低,造成梯度计算出现浮点溢出,会导致部分参数更新失败。为了保证部分模型训练在混合精度训练过程中收敛,需要配置Loss Scale的方法。

Loss Scale方法通过在前向计算所得的loss乘以loss scale系数S,起到在反向梯度计算过程中达到放大梯度的作用,从而最大程度规避浮点计算中较小梯度值无法用FP16表达而出现的溢出问题。在参数梯度聚合之后以及优化器更新参数之前,将聚合后的参数梯度值除以loss scale系数S还原。

动态Loss Scale通过在训练过程中检查梯度中浮点计算异常状态,自动动态选取loss scale系数S以适应训练过程中梯度变化,从而解决人工选取loss scale系数S和训练过程中自适应调整的问题。

在具体实现中,昇腾AI处理器由于浮点计算特性不同,在计算过程中的浮点异常检查等部分与GPU存在差异。

实现原理

使用Loss Scale

如果原始网络中使用了Loss Scale功能,工具会进行自动迁移,将TensorFlow的LossScaleManager迁移为NPU的ExponentialUpdateLossScaleManager或FixedLossScaleManager。如果原始网络中没有使用Loss Scale功能,用户可以根据需要参考Loss Scale自行添加。

由于NPU计算特性与GPU混合精度计算特性存在差异,LossScaleManager超参也往往需要进行适当的调整以保证精度。当用户模型基于默认Loss Scale参数训练产生溢出的迭代过多,影响最终精度时,需要对Loss Scale参数进行适当调整,减少发生浮点异常的次数。

具体方法为:参考打印loss scale值打印loss scale值,根据loss scale值观察溢出次数,调整LossScaleManager参数。

打印loss scale值

Estimator模式下,可以通过添加hook的方式实现对loss scale值进行打印:

class _LogSessionRunHook(tf.train.SessionRunHook):
   def before_run(self, run_context):
       return tf.estimator.SessionRunArgs(
               fetches=['overflow_status_reduce_all:0', 'loss_scale:0'])
 
   def after_run(self, run_context, run_values):
       if not run_values.results[0]:
           print('Find overflow in this step, skip apply gradients, loss scale value=%d' % run_values.results[1], flush=True)
       else:
           print('Apply gradients, loss scale value=%d' % run_values.results[1], flush=True)
  
...

if 'train' in params.exec_mode:
    training_hooks = get_hooks(params, logger)
    training_hooks.append(_LogSessionRunHook())
    estimator.train(
        input_fn = dataset.train_fn,
        steps = max_steps,
        hooks = training_hooks)

需要注意的是,以上hook无法适用所有网络,原因是loss scale值是根据算子名称打印的,如果用户使用了scope等指定网络中部分算子的名称,则该hook需要相应更改为需要获取的算子名称。

sess.run模式下,可以通过调用get_loss_scale接口从NPU的lossscale优化器获取loss scale的值。

# 原始代码示例
for step in range(restore_step, FLAGS.max_steps):
    data = next(data_generator)
    inputs_padded = data[0]
    bbox_padded = pad_bbox(data[1],FLAGS.num_bbox)
    input_image_np = inputs_padded
    input_bbox_np = bbox_padded

    ml, tl,ce_loss, bbox_loss, _, summary_str = sess.run([
                                       model_loss,
                                       total_loss, 
                                       rpn_cross_entropy,
                                       rpn_loss_box,
                                       train_op, summary_op],
                                       feed_dict={input_image: input_image_np,input_bbox: input_bbox_np})
    summary_writer.add_summary(summary_str, global_step=step)

# 修改后的代码示例
for step in range(restore_step, FLAGS.max_steps):
    data = next(data_generator)
    inputs_padded = data[0]
    bbox_padded = pad_bbox(data[1],FLAGS.num_bbox)
    input_image_np = inputs_padded
    input_bbox_np = bbox_padded

    lossScale = loss_scale_manager.get_loss_scale()
    overflow_status_reduce_all = tf.get_default_graph().get_tensor_by_name("overflow_status_reduce_all:0")

    l_s, overflow_status_reduce_all, global_steppp, ml, tl,ce_loss, bbox_loss, _, summary_str = sess.run(
                                      [lossScale, 
                                       overflow_status_reduce_all, 
                                       global_step,
                                       model_loss,
                                       total_loss,
                                       rpn_cross_entropy,
                                       rpn_loss_box,
                                       train_op, summary_op],
                                       feed_dict={input_image: input_image_np, input_bbox: input_bbox_np})
    summary_writer.add_summary(summary_str, global_step=step)
    print('loss_scale is: ', l_s)
    print("overflow_status_reduce_all:", overflow_status_reduce_all)
    print("global_step:", global_steppp)