集合通信算子入图

功能简介

集合通信算子入图能够从整图层面进行通信与计算并行优化,原生PyTorch 2.1版本暂不支持集合通信算子入图,从PyTorch 2.2及更高版本开始陆续支持。

原生PyTorch的dynamo trace入图仅支持算子和python内置函数入图,普通python函数会被dynamo拆解,将算子记录到图中。早期PyTorch提供的集合通信能力以API形式呈现,无法直接将其转换入图,PyTorch社区新增了一系列torch.ops.c10_functional前缀的集合通信算子解决该问题。当用户使用torch.ops.c10_functional算子时,需手动修改模型脚本中的集合通信部分,该部分通常被Deepspeed等框架封装,因此社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子,实现了早期集合通信API调用脚本也能trace入图。然而,社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子的行为,在PyTorch 2.3版本才逐渐补齐相关能力,PyTorch 2.1版本处于不可用状态。

针对PyTorch 2.1版本不支持入图的集合通信算子,TorchAir提供了补丁方案(patch),封装了一系列NPU通信算子,实现集合通信算子入图。

使用方法

假设训练/推理脚本在PyTorch 2.1版本能以Eager模式正常运行,用户无需修改脚本,直接导入补丁包(Monkey-Patch)即可完成通信算子入图,方案如下:

分布式场景下,由于Deepspeed框架对原生allreduce API的封装使其无法入图,TorchAir可同时Patch Deepspeed框架中allreduce封装函数,以解决入图问题。