消减内存碎片

基本原理

频繁地进行内存申请、释放将会产生内存碎片。可以通过调整训练脚本,让内存申请逻辑尽量亲和PyTorch内存池逻辑,以减少内存碎片产生。长生命周期内存在训练开始时优先申请,如模型权重、梯度、优化器状态等;或优先申请大内存后申请小内存以提高内存复用率;功能等效的情况下尽量串行申请释放内存,避免批量申请释放内存,提高内存复用。

使用场景

针对Pytorch内存管理特点,修改脚本使内存申请与PyTorch内存管理更亲和。

操作步骤

  1. 采集内存profiling,从单次内存较大的内存分析是否可以优化。
  2. 将长生命周期内存放在训练代码的最开始申请,对短生命周期内存尽量做从大到小的申请释放逻辑处理。
  3. 调整网络脚本,避免不必要的内存生命周期并行,通过串行的方式提高内存复用率。如下代码,可以先申请一个bucket做完集合通信后在申请下一个,避免一次申请多个导致的内存峰值推高。

    for bucket in buckets:
        bucket.apply()
    
    for bucket in buckets:
        dist_group.broadcast(bucket)
    
    for bucket in buckets:
        bucket.release()