from apex import amp
model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
loss = criterion(…)
loss.backward()
optimizer.step()
After modification to support loss scaling, the code is as follows:
loss = criterion(…)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
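
To see why scaling the loss before the backward pass can help, here is a minimal standalone sketch of the underflow problem in half precision. It does not use apex and is not how amp is implemented internally; the helper to_fp16, the gradient value 1e-8, and the scale factor 1024 are all hypothetical, chosen only for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Hypothetical helper for illustration; it uses the standard library's
    'e' (binary16) struct format instead of a real FP16 tensor.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8      # hypothetical tiny gradient, below FP16's smallest subnormal
scale = 1024.0   # hypothetical loss scale, an arbitrary illustrative value

# Without scaling, the gradient underflows to zero when stored in FP16.
unscaled = to_fp16(grad)

# Scaling the loss scales every gradient by the same factor (chain rule),
# so the stored value is now large enough to be representable in FP16.
scaled = to_fp16(grad * scale)

# Before the optimizer step, the gradient is divided by the scale again,
# in higher precision (a Python float here, FP32 in typical training).
recovered = scaled / scale

print(unscaled, scaled, recovered)
```

In this sketch the unscaled gradient is lost entirely, while the scaled one survives the FP16 round trip and is recovered to within a small rounding error after unscaling.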