昇腾社区首页
中文
注册

ReduceScatterVOperation

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

x

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

Atlas 200I/500 A2 推理产品

x

功能说明

将多个通信卡上的数据进行计算,支持相加、取最大、最小三种计算,然后不等地发送到每张卡上。

使用场景

推理场景中会出现batch size不能被TP数整除的情况,reducescatter后续的计算算子需要按照batch维度处理数据,再将处理数据进行allgather,如图图1所示。

图1 算子上文示意图

对于原始的ReduceScatter来说,要求在每张卡的数据一样多,对于ReduceScatterV来说,假设有5份数据,可以在一张卡上分三份,一张卡上分两份

使用示例:

  • 输入
    [[[0], [1], [2], [3], [4]], [[0], [1], [2], [3], [4]]]
    [[2, 3], [2, 3]]
    [[0, 2], [0, 2]]
    [2, 3]
  • 输出

    NPU0 的输出

    [[0], [2]]

    NPU1 的输出

    [[4], [6], [8]]

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
struct ReduceScatterVParam {
    int rank = 0;
    int rankSize = 0;
    int rankRoot = 0;
    std::vector<int64_t> sendCounts;
    std::vector<int64_t> sdispls;
    std::int64_t recvCount = 0;
    std::string reduceType = "sum";
    HcclComm hcclComm = nullptr;
    CommMode commMode = COMM_MULTI_PROCESS;
    std::string backend = "hccl";
    std::string rankTableFile;
    std::string commDomain;
    uint8_t rsv[64] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

rank

int

0

[0, rankSize-1]

当前卡所属通信编号。

rankSize

int

0

-

通信的卡的数量。

rankRoot

int

0

[0, rankSize-1]

主通信编号。

reduceType

string

"sum"

sum

prod

max

min

通信计算类型。

支持sum(相加)、prod(相乘)、max(取最大)和min(取最小)。

backend

string

“hccl”

hccl

通信计算类型,仅支持hccl。
  • 支持sum、prod、max和min。
  • “reduceType”为prod时,不支持数据类型为int16和bf16。
  • Atlas 推理系列产品不支持int64、bf16,int16只有“reduceType”为sum时支持。

hcclComm

HcclComm

nullptr

hccl

HCCL通信域指针。当前算子仅支持hccl,此参数为预留参数。默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。

commMode

CommMode

COMM_MULTI_PROCESS

COMM_MULTI_PROCESS/COMM_MULTI_THREAD

通信模式,CommMode类型枚举值。

rankTableFile

string

-

集群信息的配置文件路径。

commDomain

string

-

通信device组用通信域名标识。

rsv[64]

uint8_t

{0}

[0]

预留参数。

输入

参数

维度

数据类型

格式

是否必选

描述

x

[dim_0, dim_1]

"hccl": float16/int8

ND

输入tensor,dim_0,dim_1大小没有限制。

sendCounts

[rankSize]

int64

ND

一维张量,其长度等于rankSize卡数,张量里的每个索引的值代表着每张卡分配的数据量,例如sendCounts[0]:rank0分配的数据量。张量中值的和等于x的dim_0。

sdispls

[rankSize]

int64

ND

一维张量 ,其长度等于ranksize卡数,张量里的每个索引的值代表着从对应索引卡号接收到的数据量的偏移,sdispls [0] = n表示rank0从相对于输入起始位置的偏移量为n的位置开始接收sendCounts[0]的数据量。张量中的值需小于x的dim_0。

recvCount

[1]

int64

ND

一维张量,该维度大小就是该固定1,recvCount 代表着该卡接收到的数据量。数值与sendCounts保持一致。

y

[dim_0]

float16

ND

一维张量,该张量用来推导output的shape。数值与recvCount相等。

输出

参数

维度

数据类型

格式

是否必选

描述

output

[dim_0, dim_1]

"hccl": float16/int8

ND

输出tensor,dim_0为y的shape。dim_0等于recvCount.shape[rank],dim_1等于x.shape[1]。

约束说明

  • rank、rankSize、rankRoot需满足以下条件。
    • 0 ≤ rank < rankSize
    • 0 ≤ rankRoot < rankSize