混合精度开启
在迁移完成后,为保证模型的性能,需要开启混合精度。用户可以根据场景选择引入APEX混合精度模块(推荐)或使用PyTorch框架(1.8.1版本及以上)自带的AMP功能。APEX模块的安装请参考相关README文档进行编译安装。
单卡训练开启混合精度
- (推荐)使用APEX混合精度模块。
- 导入混合精度模块。
from apex import amp
- 在模型和优化器定义之后初始化APEX模块。
model = ... optimizer = ... model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
- 改写梯度反向传播loss.backward()。
loss = criterion(…) #将loss.backward()替换为如下形式 with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step()
- 导入混合精度模块。
- 使用框架自带AMP功能(PyTorch 1.8.1版本及以上)。
model = ... optimizer = ... #在模型、优化器定义之后,使用AMP功能。 scaler=GradScaler() #创建缩放器 for epoch in epochs: for input, target in data: optimizer.zero_grad() with autocast(): output = model(input) loss = loss_fn(output, target) ...... scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
多卡训练开启混合精度
- (推荐)使用APEX混合精度并关闭combine_ddp开关,并开启PyTorch框架自带的DistributedDataParallel(DDP)模式,DDP模式可参考PyTorch官方文档使用。
from torch.nn.parallel import DistributedDataParallel as DDP from apex import amp ... model = ... optimizer = ... #在模型、优化器定义之后,初始化APEX模块。 model, optimizer = amp.initialize(model, optimizer, combine_ddp=False)
- 使用APEX混合精度并开启combine_ddp开关,并关闭PyTorch框架自带的DDP模式。
from apex import amp ... model = ... optimizer = ... #在模型、优化器定义之后,初始化APEX模块。 model, optimizer = amp.initialize(model, optimizer, combine_ddp=True)
- 使用AMP功能并开启PyTorch框架自带的DDP模式。
from torch.nn.parallel import DistributedDataParallel as DDP ... model = ... optimizer = ... #在模型、优化器定义之后,使用AMP功能。 scaler=GradScaler() model, optimizer = amp.initialize(model, optimizer, opt_level="O1") ... with autocast(): output = model(input) loss = loss_fn(output, target) ... scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
更多具体混合精度使用说明请参见混合精度说明。
父主题: 模型迁移与训练