消减非必要的异步集合通信算子

基本原理

集合通信算子的输入输出内存在计算流上申请,会被集合通信流依赖,因此会存在延迟释放。跨流内存复用优化是通过在handle.wait()时及时释放被通信流依赖的内存达到内存优化的目的。集合通信算子异步的场景,会延迟wait调用抵消跨流内存复用的优化。

使用场景

集合通信算子异步场景。

操作步骤

需要排查集合通信算子asyncOp=True是否合理,并行代码表达是否必要。如果为了提高性能做计算通信并行,可以评估权衡内存收益和性能收益。非必要的异步,这里的all_reduce没有与计算并行,因此可以改为同步:

handles = []

for bucket in buckets:
    handles.append(dist_group.all_reduce(npu_tensor1, async_op=True))

for handle in handles:
    handle.wait()