LinearParallelOperation
功能
多卡并行Linear计算处理。
该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见《环境变量参考》中 章节。配置TLS等相关操作可查看对应设备的《HCCN Tool 接口参考》。
约束
- 输入x / weight矩阵维度,通过transWeight配置要满足矩阵乘的维度关系。
- rank、rankSize、rankRoot需满足以下条件。
- 0 ≤ rank < rankSize
- 0 ≤ rankRoot < rankSize
定义
struct LinearParallelParam { bool transWeight = false; int rank = 0; int rankSize = 0; int rankRoot = 0; std::string bias = ""; std::string parallelType = "RowParallel"; std::string backend = "hccl"; HcclComm hcclComm = nullptr; };
成员
成员名称 |
描述 |
---|---|
transWeight |
权重是否不需要转置,默认为false(即需要转置)。 |
rank |
每个进程的编号。 |
rankSize |
总的进程数。 |
rankRoot |
主进程编号。 |
bias |
是否叠加偏置。配置为"None"时不叠加偏置,否则为叠加偏置。默认叠加偏置。 |
parallelType |
权重并行方式。仅支持“RowParallel”。 |
backend |
通信后端指示。仅支持“hccl”和“lccl”。 |
hcclComm |
HCCL通信域接口获取的地址指针,仅当hcclComm不为nullptr时可用。 |
输入
输出
参数 |
维度 |
数据类型 |
格式 |
---|---|---|---|
output |
根据以上输入维度,输出维度为:
|
float16 |
ND |