Adaptation Example (Multiple Models, Loss Functions, and Optimizers)
This scenario refers to a neural network training setup in which multiple models, multiple loss functions, and multiple optimizers are used at the same time.
Importing the AMP Module
Import the AMP module and define two simple neural networks.
import time
import torch
import torch.nn as nn
import torch_npu
from torch_npu.npu import amp
from torch.utils.data import Dataset, DataLoader
import torchvision

device = torch.device('npu:0')  # Define the training device as needed

# Define a simple neural network
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)

# Define a second, similar network with one extra convolutional layer
class CNN_2(nn.Module):
    def __init__(self):
        super(CNN_2, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)
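The nn.Linear(32 * 7 * 7, 16) layer assumes 28x28 MNIST inputs that the two pooling stages reduce to 7x7 feature maps with 32 channels. A minimal sketch to verify the shapes on CPU before moving the models to the NPU (the dummy tensor and prints are illustrative only, not part of the training flow):

# Hypothetical shape check with a dummy MNIST-shaped batch
dummy = torch.randn(1, 1, 28, 28)
print(CNN()(dummy).shape)    # expected: torch.Size([1, 10])
print(CNN_2()(dummy).shape)  # expected: torch.Size([1, 10])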
Defining the Loss Scaling Parameter
After the models and optimizers are defined, create the GradScaler provided by the AMP feature.
# Obtain the dataset
train_data = torchvision.datasets.MNIST(
    root='mnist',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor()
)
batch_size = 64
model0 = CNN().to(device)    # Move model 0 to the specified NPU
model1 = CNN_2().to(device)  # Move model 1 to the specified NPU
train_dataloader = DataLoader(train_data, batch_size=batch_size)  # Define the DataLoader
loss_func = nn.CrossEntropyLoss().to(device)  # Define the loss function
optimizer0 = torch.optim.SGD(model0.parameters(), lr=0.1)  # Define optimizer 0
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)  # Define optimizer 1
scaler = amp.GradScaler()  # Define the GradScaler after the models and optimizers
epochs = 10  # Set the number of training epochs
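The GradScaler defaults are usually sufficient, but the dynamic loss-scaling behavior can be tuned. A hedged sketch, assuming torch_npu's amp.GradScaler accepts the same keyword arguments as torch.cuda.amp.GradScaler (the values shown are the PyTorch defaults, given for illustration rather than as recommendations):

# Illustrative only: tune the dynamic loss scale
scaler = amp.GradScaler(
    init_scale=2.**16,    # starting loss-scale value
    growth_factor=2.0,    # multiply the scale by this after growth_interval clean steps
    backoff_factor=0.5,   # multiply the scale by this when inf/NaN gradients appear
    growth_interval=2000  # consecutive clean steps required before growing the scale
)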
Adapting AMP and Training
Add the AMP-related code to the training loop to enable AMP, computing the multiple loss functions and stepping both optimizers.
for epoch in range(epochs):
    for imgs, labels in train_dataloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        with amp.autocast():
            outputs0 = model0(imgs)  # Forward pass
            outputs1 = model1(imgs)
            loss0 = loss_func(2 * outputs0 + 3 * outputs1, labels)  # Compute the losses
            loss1 = loss_func(3 * outputs0 - 5 * outputs1, labels)
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        # Scale the losses, run backward, then update the parameters
        scaler.scale(loss0).backward(retain_graph=True)  # Scale loss0 and backpropagate; retain the graph because loss1 reuses it
        scaler.scale(loss1).backward()
        scaler.step(optimizer0)  # Update the parameters (gradients are unscaled automatically)
        scaler.step(optimizer1)
        scaler.update()  # Update the loss-scaling factor based on the dynamic loss scale
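When checkpointing in this scenario, the scaler's state should be saved alongside both models and optimizers so the dynamic loss scale survives a restart. A minimal sketch, assuming torch_npu's GradScaler provides the same state_dict()/load_state_dict() methods as torch.cuda.amp.GradScaler (the file path and dictionary keys are illustrative assumptions):

# Illustrative checkpointing sketch: persist both models, both optimizers, and the scaler
torch.save({
    'model0': model0.state_dict(),
    'model1': model1.state_dict(),
    'optimizer0': optimizer0.state_dict(),
    'optimizer1': optimizer1.state_dict(),
    'scaler': scaler.state_dict()  # preserves the current loss-scale state
}, 'checkpoint.pth')  # hypothetical path

# To resume, restore the scaler state before training continues
ckpt = torch.load('checkpoint.pth')
scaler.load_state_dict(ckpt['scaler'])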
Parent topic: Mixed Precision Adaptation (Optional)