APEX模块使用

  1. 使用APEX混合精度模块需要首先从APEX库中导入AMP,代码如下:
    from apex import amp
  2. 导入AMP模块后,需要初始化AMP,使其能对模型、优化器以及PyTorch内部函数进行必要的改动,初始化代码如下:
    model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
  3. 标记反向传播.backward()发生的位置,这样AMP就可以进行Loss Scaling并清除每次迭代的状态,原始代码如下:
    loss = criterion(…) 
    loss.backward() 
    optimizer.step()

    修改以支持Loss Scaling后的代码如下:

    loss = criterion(…) 
    with amp.scale_loss(loss, optimizer) as scaled_loss:     
        scaled_loss.backward() 
    optimizer.step()
  4. 更多混合精度模块的使用可参见官方文档