昇腾社区首页
中文
注册

使用

  1. 从APEX库中导入AMP。
    from apex import amp
  2. 初始化AMP,使其能对模型、优化器以及PyTorch内部函数进行必要的改动。
    model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
  3. 标记反向传播.backward()发生的位置,这样AMP就可以进行Loss Scaling并清除每次迭代的状态。
    原始代码如下:
    loss = criterion(…) 
    loss.backward() 
    optimizer.step()
    修改后支持Loss Scaling的代码如下:
    loss = criterion(…) 
    with amp.scale_loss(loss, optimizer) as scaled_loss:     
        scaled_loss.backward() 
    optimizer.step()

更多混合精度模块的使用可参见官方文档