拉起多卡训练脚本示例

构建模型脚本

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# 导入依赖和库 
import torch 
from torch import nn 
import torch_npu 
import torch.distributed as dist 
from torch.utils.data import DataLoader 
from torchvision import datasets 
from torchvision.transforms import ToTensor 
import time 
import torch.multiprocessing as mp 
import os 
 
torch.manual_seed(0) 
# 下载训练数据 
training_data = datasets.FashionMNIST( 
    root="./data", 
    train=True, 
    download=True, 
    transform=ToTensor(), 
) 
 
# 下载测试数据 
test_data = datasets.FashionMNIST( 
    root="./data", 
    train=False, 
    download=True, 
    transform=ToTensor(), 
) 
 
# 构建模型 
class NeuralNetwork(nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.flatten = nn.Flatten() 
        self.linear_relu_stack = nn.Sequential( 
            nn.Linear(28*28, 512), 
            nn.ReLU(), 
            nn.Linear(512, 512), 
            nn.ReLU(), 
            nn.Linear(512, 10) 
        ) 
 
    def forward(self, x): 
        x = self.flatten(x) 
        logits = self.linear_relu_stack(x) 
        return logits 
 
def test(dataloader, model, loss_fn): 
    size = len(dataloader.dataset) 
    num_batches = len(dataloader) 
    model.eval() 
    test_loss, correct = 0, 0 
    with torch.no_grad(): 
        for X, y in dataloader: 
            X, y = X.to(device), y.to(device) 
            pred = model(X) 
            test_loss += loss_fn(pred, y).item() 
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() 
    test_loss /= num_batches 
    correct /= size 
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

获取分布式超参数

在模型脚本中,构建主函数main,在其中获取分布式训练所需的超参数。

设置地址和端口号

在模型脚本中设置地址与端口号,用于拉起分布式训练。由于昇腾AI处理器初始化进程组时initmethod只支持env:// (即环境变量初始化方式),所以在初始化前需要根据实际情况配置MASTER_ADDR、MASTER_PORT等参数。

添加分布式逻辑

不同的拉起训练方式下,device号的获取方式不同:

用户需根据自己选择的方式对代码做不同的修改。下面是修改代码示例:

配置传参逻辑

在模型脚本中,根据拉起方式不同,需要传入不同的参数,传参配置逻辑如下(此处使用argparse逻辑):

拉起单机八卡训练

我们给出了每种方式的拉起命令示例,用户可根据实际情况自行更改。

当屏幕打印/定向日志中出现模型加载、训练等正常运行日志时,说明拉起多卡训练成功,如图1所示。

图1 屏幕运行日志截图

拉起双机16卡训练

启动命令需要在集群每台机器分别执行:

当屏幕打印/定向日志中出现模型加载、训练等正常运行日志时,说明拉起双机训练成功。