LinearParallelOperation
功能
通信计算并行算子,该算子功能为linear和通信算子组合,通信和计算是并行处理,与串行处理相较性能大幅度提升。
该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见《CANN 环境变量参考》中的“集合通信相关配置”章节。配置TLS等相关操作可查看对应设备的《HCCN Tool 接口参考》。
约束
- 输入x / weight矩阵维度,通过transWeight配置需满足矩阵乘的维度关系。
- rank、rankSize、rankRoot需满足以下条件。
- 0 ≤ rank < rankSize
- 0 ≤ rankRoot < rankSize
定义
struct LinearParallelParam {
enum ParallelType : int {
UNDEFINED = -1,
LINEAR_ALL_REDUCE = 0,
LINEAR_REDUCE_SCATTER = 1,
ALL_GATHER_LINEAR = 2,
PURE_LINEAR = 3,
MAX = 4,
};
enum QuantType : int {
QUANT_TYPE_UNDEFINED = -1,
QUANT_TYPE_PER_TENSOR = 0,
QUANT_TYPE_PER_CHANNEL = 1,
QUANT_TYPE_PER_GROUP = 2,
QUANT_TYPE_MAX = 3,
};
bool transWeight = true;
int rank = 0;
int rankSize = 0;
int rankRoot = 0;
bool hasResidual = false;
std::string backend = "hccl";
HcclComm hcclComm = nullptr;
CommMode commMode = COMM_MULTI_PROCESS;
std::string rankTableFile;
ParallelType type = LINEAR_ALL_REDUCE;
bool keepIntermediate = false;
QuantType quantType = QUANT_TYPE_UNDEFINED;
int32_t quantGroupSize = 0;
aclDataType outDataType = ACL_DT_UNDEFINED;
std::string commDomain;
};
成员
|
成员名称 |
描述 |
|---|---|
|
ParallelType |
通信类型。
|
|
QuantType |
QuantType类型。
|
|
transWeight |
权重是否需要转置,默认为true。 |
|
rank |
每张卡所属通信编号。 |
|
rankSize |
通信的卡的数量。 |
|
rankRoot |
主通信编号。 |
|
hasResidual |
是否叠加残差。配置为“false”时不叠加残差,为“true”时叠加残差。默认不叠加残差。 |
|
backend |
通信后端指示。支持“hccl”,“lccl”,“lcoc”。 |
|
hcclComm |
HCCL通信域指针。 默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。 |
|
commMode |
通信模式,CommMode类型枚举值。HCCL多线程只支持外部传入通信域方式。 |
|
rankTableFile |
集群信息的配置文件路径,适用单机以及多机通信场景,当前仅支持hccl后端场景。 配置请参见。 |
|
type |
权重并行类型。仅在“backend”为“lcoc”时生效。 |
|
keepIntermediate |
是否返回中间结果,仅在“ParallelType”使用“ALL_GATHER_LINEAR”时生效。 |
|
quantType |
量化类型。仅在“backend”为“lcoc”时生效。 |
|
quantGroupSize |
量化类型为“QUANT_TYPE_PER_GROUP”时有效。 |
|
outDataType |
|
|
commDomain |
通信Device组用通信域名标识,多通信域时使用,当前仅支持“hccl”。 |
输入
|
参数 |
维度 |
数据类型 |
格式 |
描述 |
|---|---|---|---|---|
|
input |
[m, k]/[batch, m, k] |
|
ND |
矩阵乘运算的A矩阵。 k为32的整数倍。 |
|
weight |
[k, n] |
|
ND/NZ |
权重,矩阵乘的B矩阵。“backend”为“lcoc”时float16/bf16支持NZ;“backend”为“hccl”或“lccl”时仅float16支持NZ。 |
|
bias |
|
|
ND |
叠加的偏置矩阵。 n为16的整数倍。 |
|
deqScale |
|
|
ND |
反量化的scale。 量化时输入。 |
|
residual |
[n] |
float16/bfloat16 |
ND |
残差,用于叠加到最后的输出结果上。 |
输出
|
参数 |
维度 |
数据类型 |
格式 |
描述 |
|---|---|---|---|---|
|
output |
|
float16/bfloat16 |
ND |
输出tensor,维度数与x一致。 |
|
intermediateOutput |
[m*rankSize, n]/[batch*rankSize, m, n] |
float16/bfloat16 |
ND |
输出tensor,维度数与x一致。 |
