torch.distributed.distributed_c10d._world.default_pg._get_backend(torch.device("npu")).get_hccl_comm_name
功能描述
从初始化完成的集合通信域中获取集合通信域名字。
接口原型
torch.distributed.distributed_c10d._world.default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rankid->int) -> String
注:接口为pytorch的ProcessGroup类,backend为npu backend的方法。ProcessGroup可以是default_pg,也可以是torch.distributed.distributed_c10d.new_group创建的非default_pg。
调用该接口时,需要保证当前current device被设置为正确。
参数说明
rankid: 集合通信对应device的rankid。传入的rankid为全局的rankid,多机间device具有唯一的rankid。
输出说明
string类型的集合通信域的名字。
约束说明
- 使用该接口前确保init_process_group已被调用,且初始化的backend为hccl。
- Pytorch2.1及以后版本与Pytorch2.1之前的版本对该接口调用方式不同,见调用示例。
支持的PyTorch版本
- PyTorch 2.2.0
- PyTorch 2.1.0
- PyTorch 1.11.0
支持的型号
- Atlas 训练系列产品
- Atlas A2 训练系列产品
调用示例
import torch
import torch_npu
import torch.multiprocessing as mp
import os
from torch.distributed.distributed_c10d import _get_default_group
import torch.distributed as dist
def example(rank, world_size):
torch.npu.set_device("npu:" + str(rank))
dist.init_process_group("hccl", rank=rank, world_size=world_size)
default_pg = _get_default_group()
if torch.__version__ > '2.0':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
print(hcomm_info)
def main():
world_size = 2
mp.spawn(example,
args=(world_size, ),
nprocs=world_size,
join=True)
if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29505"
main()
父主题: Distributed