昇腾社区首页
中文
注册

混合精度开启

在迁移完成后,为保证模型的性能,需要开启混合精度。用户可以根据场景选择引入APEX混合精度模块(推荐)或使用PyTorch框架(1.8.1版本及以上)自带的AMP功能。APEX模块的安装请参考相关README文档进行编译安装。

单卡训练开启混合精度

  • (推荐)使用APEX混合精度模块。
    1. 导入混合精度模块。
      from apex import amp
    2. 在模型和优化器定义之后初始化APEX模块。
      model = ...
      optimizer = ...
      model, optimizer = amp.initialize(model, optimizer, combine_grad=True)
    3. 改写梯度反向传播loss.backward()
      loss = criterion(…) 
      #将loss.backward()替换为如下形式
      with amp.scale_loss(loss, optimizer) as scaled_loss:     
          scaled_loss.backward() 
      optimizer.step()
  • 使用框架自带AMP功能(PyTorch 1.8.1版本及以上)。
    model = ...
    optimizer = ...
    #在模型、优化器定义之后,使用AMP功能。
    scaler=GradScaler()
    #创建缩放器
    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)
            ......
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

多卡训练开启混合精度

  • (推荐)使用APEX混合精度并关闭combine_ddp开关,并开启PyTorch框架自带的DistributedDataParallel(DDP)模式,DDP模式可参考PyTorch官方文档使用。
    from torch.nn.parallel import DistributedDataParallel as DDP
    from apex import amp
    ...
    model = ...
    optimizer = ...
    #在模型、优化器定义之后,初始化APEX模块。
    model, optimizer = amp.initialize(model, optimizer, combine_ddp=False)
  • 使用APEX混合精度并开启combine_ddp开关,并关闭PyTorch框架自带的DDP模式。
    from apex import amp
    ...
    model = ...
    optimizer = ...
    #在模型、优化器定义之后,初始化APEX模块。
    model, optimizer = amp.initialize(model, optimizer, combine_ddp=True)
  • 使用AMP功能并开启PyTorch框架自带的DDP模式。
    from torch.nn.parallel import DistributedDataParallel as DDP
    ...
    model = ...
    optimizer = ...
    #在模型、优化器定义之后,使用AMP功能。
    scaler=GradScaler()
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    ...
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)
    ...
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

更多具体混合精度使用说明请参见混合精度说明