昇腾社区首页
中文
注册

LinearParallelOperation

功能

多卡并行Linear计算处理。

该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见环境变量参考执行相关 > 集合通信与分布式训练 > 集合通信相关配置章节。配置TLS等相关操作可查看对应设备的《HCCN Tool 接口参考》

约束

  • 输入x / weight矩阵维度,通过transWeight配置要满足矩阵乘的维度关系。
  • rank、rankSize、rankRoot需满足以下条件。
    • 0 ≤ rank < rankSize
    • 0 ≤ rankRoot < rankSize

定义

struct LinearParallelParam {
    bool transWeight = false;
    int rank = 0;
    int rankSize = 0;
    int rankRoot = 0;
    std::string bias = "";
    std::string parallelType = "RowParallel";
    std::string backend = "hccl";
    HcclComm hcclComm = nullptr; 
};

成员

成员名称

描述

transWeight

权重是否需要转置,默认为false(即需要转置)。

rank

每个进程的编号。

rankSize

总的进程数。

rankRoot

主进程编号。

bias

是否叠加偏置。配置为"None"时不叠加偏置,否则为叠加偏置。默认叠加偏置。

parallelType

权重并行方式。仅支持“RowParallel”。

backend

通信后端指示。仅支持“hccl”和“lccl”。

hcclComm

HCCL通信域接口获取的地址指针,仅当hcclComm不为nullptr时可用。

输入

参数

维度

数据类型

格式

x

  1. x: [batch, m, k]

    weight: [k, n]

  2. x: [m,k]

    weight: [k, n]

float16

ND

weight

float16

ND

bias

[1, n] / [n]

float16

ND

输出

参数

维度

数据类型

格式

output

根据以上输入维度,输出维度为:

  1. output: [batch, m, n]
  2. output: [m ,n]

float16

ND