ReduceScatterOperation
产品支持情况
| 
          硬件型号  | 
        
          是否支持  | 
        
          特殊说明  | 
       
|---|---|---|
| 
          | 
        
          √  | 
        
          -  | 
       
| 
          | 
        
          √  | 
        
          不支持  | 
       
| 
          | 
        
          √  | 
        
          -  | 
       
| 
          | 
        
          √  | 
        
          -  | 
       
| 
          | 
        
          x  | 
        
          -  | 
       
功能
多对多。在集群内的所有节点上都按维度执行相同的Reduce规约运算,再将结果发散到集群内所有的节点上。
    注:硬件类型仅支持
使用场景
- 可应用于数据并行和模型并行。
 - 数据并行allReduce里的 ReduceScatter+ Allgather组合里的ReduceScatter操作。
 - 模型并行里在前向allgather后的反向计算里的ReduceScatter。
 
使用示例
以rankSize=2为例
reduceType=sum
输入tensor的shape为(2 * 3 * 4)
输入tensor:
input_tensor_0 = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                  [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
input_tensor_1 = [[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]],
                  [[37, 38, 39, 40], [41, 42, 43, 44], [45, 46, 47, 48]]]
    输出tensor的shape为(1 * 3 * 4)
输出tensor:
output_tensor_0 = [[[26, 28, 30, 32],[34, 36, 38, 40],[42, 44, 46, 48]]] output_tensor_1 = [[[50, 52, 54, 56],[58, 60, 62, 64],[66, 68, 70, 72]]]
定义
          1 2 3 4 5 6 7 8 9 10 11 12  | 
         
          struct ReduceScatterParam{ int rank = 0; int rankSize = 0; int rankRoot = 0; std::string reduceType = "sum"; HcclComm hcclComm = nullptr; CommMode commMode = COMM_MULTI_PROCESS; std::string backend = "lccl"; std::string rankTableFile; std::string commDomain; uint8_t rsv[64] = {0}; };  | 
        
参数列表
| 
          成员名称  | 
        
          类型  | 
        
          默认值  | 
        
          取值范围  | 
        
          是否必选  | 
        
          描述  | 
       
|---|---|---|---|---|---|
| 
          rank  | 
        
          int  | 
        
          0  | 
        
          [0, rankSize-1]  | 
        
          是  | 
        
          当前卡所属通信编号。  | 
       
| 
          rankSize  | 
        
          int  | 
        
          0  | 
        
          -  | 
        
          是  | 
        
          通信的卡的数量。  | 
       
| 
          rankRoot  | 
        
          int  | 
        
          0  | 
        
          [0, rankSize-1]  | 
        
          是  | 
        
          主通信编号。  | 
       
| 
          reduceType  | 
        
          string  | 
        
          "sum"  | 
        
          sum prod max min  | 
        
          是  | 
        
          通信计算类型。 支持“sum”(相加),“prod”(相乘),“max”(取最大)和“min”(取最小)。  | 
       
| 
          backend  | 
        
          string  | 
        
          “lccl”  | 
        
          lccl/hccl  | 
        
          是  | 
        
          通信计算类型,仅支持“hccl”和“lccl”。 
  | 
       
| 
          hcclComm  | 
        
          HcclComm  | 
        
          nullptr  | 
        
          -  | 
        
          否  | 
        
          HCCL通信域指针。 默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。  | 
       
| 
          commMode  | 
        
          CommMode  | 
        
          COMM_MULTI_PROCESS  | 
        
          COMM_MULTI_PROCESS/COMM_MULTI_THREAD  | 
        
          否  | 
        
          通信模式,CommMode类型枚举值。  | 
       
| 
          rankTableFile  | 
        
          string  | 
        
          无  | 
        
          -  | 
        
          否  | 
        
          集群信息的配置文件路径。  | 
       
| 
          commDomain  | 
        
          string  | 
        
          无  | 
        
          -  | 
        
          否  | 
        
          通信device组用通信域名标识,多通信域时使用。当backend为lccl时,commMode为多进程时,commDomain需要设置0-65535,支持设置通信域的内存大小,用于性能优化,配置方式通信域:通信域大小(MB),如“0:400”,当同一个通信域中多个算子配置冲突时,以第一个执行到的该通信域算子配置为准。 限制:默认值为200,当前配置大小不支持小于200。  | 
       
| 
          rsv[64]  | 
        
          uint8_t  | 
        
          {0}  | 
        
          [0]  | 
        
          否  | 
        
          预留参数。  | 
       
输入
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          x  | 
        
          [rankSize*n, dim_0, dim_1, ..., dim_n]  | 
        
          "lccl": float16/float/int8/int16/int32/bf16 "hccl": float16/float/int8/int16/int32/bf16  | 
        
          ND  | 
        
          输入tensor,维度小于等于8,第一维的大小为rankSize的n倍(n为正整数)。  | 
       
输出
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          output  | 
        
          [n, dim_0, dim_1, ..., dim_n]  | 
        
          "lccl": float16/float/int8/int16/int32/bf16 "hccl": float16/float/int8/int16/int32/bf16  | 
        
          ND  | 
        
          输出tensor,维度小于等于8,第一维的大小n=x[0]/rankSize,从第二维开始输出output的维数与输入x的维度以及维度值相等,数据类型也相同。  | 
       
规格约束
- rank、rankSize、rankRoot需满足以下条件。
      
- 0 ≤ rank < rankSize
 - 0 ≤ rankRoot < rankSize
 
 
