集合通信算子入图
功能简介
集合通信算子入图能够从整图层面进行通信与计算并行优化,而原生PyTorch 2.1版本暂不支持集合通信算子入图,从PyTorch 2.2版本及更高版本开始陆续支持相关算子入图。
原生PyTorch的dynamo trace入图逻辑是只有算子和python内置方法才会入图,普通python函数会被dynamo拆解,将算子记录到图中。而早期PyTorch提供的集合通信能力以API形式呈现,无法直接将其转换入图,为此PyTorch社区新增了以torch.ops.c10_functional开头的集合通信算子解决该问题。由于用户使用torch.ops.c10_functional系列集合通信算子,需要修改网络模型脚本中关于集合通信的部分,而这部分往往被如Deepspeed等框架所封装,因此社区在dynamo内部将集合通信API转换为torch.ops.c10_functional系列算子,达到早期集合通信API方式调用脚本也能trace入图的目的。然而,社区在dynamo内部将集合通信API转换为torch.ops.c10_functional系列算子的行为,在PyTorch 2.3版本才逐渐补齐相关能力,PyTorch 2.1版本处于不可用状态。
基于PyTorch现状,TorchAir提供了如下解决方案:
- 针对原生暂不支持入图的集合通信算子,TorchAir提供工具方法供用户显式调用,完成集合通信算子入图。
- 针对原生已经支持入图的集合通信算子,TorchAir补齐相关入图算子的converter后,用户无需显式调用工具方法。
使用方法
假设用户PyTorch分布式脚本已在PyTorch 2.1版本中能以Eager模式(即单算子调用模式)运行正常,如需开启图模式,可采用如下方法:
- 方法1(推荐):手动修改脚本,将网络中使用的集合通信API替换为对应的torch.ops.c10_functional系列集合通信算子(TorchAir已补齐相关converter,如表1所示)。
- 方法2:
针对无法手动修改脚本或者原生PyTorch通信算子不支持的情况,可通过TorchAir提供的工具方法解决原生图模式在PyTorch 2.1版本中不可用的问题。
工具方法的原理是TorchAir新增自定义NPU通信算子torch.ops.npu_difine.allreduce,将原生的torch.distributed.all_reduce API通过补丁(Monkey Patch)方式替换为该算子。
# 导入patch_for_hcom_allreduce包 import torch_npu import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
- TorchAir中自定义的npu_difine集合通信算子allreduce,目前只支持默认的集合通信域,不支持指定集合通信域,只支持SUM这一种计算类型。
- 由于Deepspeed框架对原生allreduce API的封装使其无法入图,TorchAir可以同时patch Deepspeed框架中allreduce封装函数,以解决入图问题。
- 方法3:
可以升级PyTorch版本至2.3或者更高版本,在原生版本补齐相关dynamo trace能力后无需其他操作。