LinearParallelOperation
功能
通信计算并行算子,该算子功能为linear和通信算子组合,通信和计算是并行处理,与串行处理相较性能大幅度提升。
该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见《CANN 环境变量参考》中的“集合通信相关配置”章节。配置TLS等相关操作可查看对应设备的《HCCN Tool 接口参考》。
约束
- 输入x / weight矩阵维度,通过transWeight配置需满足矩阵乘的维度关系。
- rank、rankSize、rankRoot需满足以下条件。
- 0 ≤ rank < rankSize
- 0 ≤ rankRoot < rankSize
定义
struct LinearParallelParam { enum ParallelType : int { UNDEFINED = -1, LINEAR_ALL_REDUCE = 0, LINEAR_REDUCE_SCATTER = 1, ALL_GATHER_LINEAR = 2, PURE_LINEAR = 3, MAX = 4, }; enum QuantType : int { QUANT_TYPE_UNDEFINED = -1, QUANT_TYPE_PER_TENSOR = 0, QUANT_TYPE_PER_CHANNEL = 1, QUANT_TYPE_PER_GROUP = 2, QUANT_TYPE_MAX = 3, }; bool transWeight = true; int rank = 0; int rankSize = 0; int rankRoot = 0; bool hasResidual = false; std::string backend = "hccl"; HcclComm hcclComm = nullptr; CommMode commMode = COMM_MULTI_PROCESS; std::string rankTableFile; ParallelType type = LINEAR_ALL_REDUCE; bool keepIntermediate = false; QuantType quantType = QUANT_TYPE_UNDEFINED; int32_t quantGroupSize = 0; aclDataType outDataType = ACL_DT_UNDEFINED; std::string commDomain; };
成员
成员名称 |
描述 |
---|---|
ParallelType |
通信类型。
|
QuantType |
QuantType类型。
|
transWeight |
权重是否需要转置,默认为true。 |
rank |
每张卡所属通信编号。 |
rankSize |
通信的卡的数量。 |
rankRoot |
主通信编号。 |
hasResidual |
是否叠加残差。配置为“false”时不叠加残差,为“true”时叠加残差。默认不叠加残差。 |
backend |
通信后端指示。支持“hccl”,“lccl”,“lcoc”。 |
hcclComm |
HCCL通信域指针。 默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。 |
commMode |
通信模式,CommMode类型枚举值。HCCL多线程只支持外部传入通信域方式。 |
rankTableFile |
集群信息的配置文件路径,适用单机以及多机通信场景,当前仅支持hccl后端场景。 |
type |
权重并行类型。仅在“backend”为“lcoc”时生效。 |
keepIntermediate |
是否返回中间结果,仅在“ParallelType”使用“ALL_GATHER_LINEAR”时生效。 |
quantType |
量化类型。仅在“backend”为“lcoc”时生效。 |
quantGroupSize |
量化类型为“QUANT_TYPE_PER_GROUP”时有效。 |
outDataType |
|
commDomain |
通信Device组用通信域名标识,多通信域时使用,当前仅支持“hccl”。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
input |
[m, k]/[batch, m, k] |
|
ND |
矩阵乘运算的A矩阵。 k为32的整数倍。 |
weight |
[k, n] |
|
ND/NZ |
权重,矩阵乘的B矩阵。“backend”为“lcoc”时float16/bf16支持NZ;“backend”为“hccl”或“lccl”时仅float16支持NZ。 |
bias |
|
|
ND |
叠加的偏置矩阵。 n为16的整数倍。 |
deqScale |
|
|
ND |
反量化的scale。 量化时输入。 |
residual |
[n] |
float16/bfloat16 |
ND |
残差,用于叠加到最后的输出结果上。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
|
float16/bfloat16 |
ND |
输出tensor,维度数与x一致。 |
intermediateOutput |
[m*rankSize, n]/[batch*rankSize, m, n] |
float16/bfloat16 |
ND |
输出tensor,维度数与x一致。 |