LinearParallelOperation

功能

通信计算并行算子,该算子功能为linear和通信算子组合,通信和计算是并行处理,与串行处理相较性能大幅度提升。

该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见《CANN 环境变量参考》中的“集合通信”章节

定义

struct LinearParallelParam {
    enum ParallelType : int {
        UNDEFINED = -1,
        LINEAR_ALL_REDUCE = 0,
        LINEAR_REDUCE_SCATTER = 1,
        ALL_GATHER_LINEAR = 2,
        PURE_LINEAR = 3,
        MAX = 4,
    };
    enum QuantType : int {
        QUANT_TYPE_UNDEFINED = -1,
        QUANT_TYPE_PER_TENSOR = 0,
        QUANT_TYPE_PER_CHANNEL = 1,
        QUANT_TYPE_PER_GROUP = 2,
        QUANT_TYPE_MAX = 3,
    };
    bool transWeight = true;                    
    int rank = 0;                              
    int rankSize = 0;                          
    int rankRoot = 0;                          
    bool hasResidual = false;
    std::string backend = "hccl";              
    HcclComm hcclComm = nullptr;               
    CommMode commMode = COMM_MULTI_PROCESS;
    std::string rankTableFile;
    ParallelType type = LINEAR_ALL_REDUCE;    
    bool keepIntermediate = false;            
    QuantType quantType = QUANT_TYPE_UNDEFINED; 
    int32_t quantGroupSize = 0;                
    aclDataType outDataType = ACL_DT_UNDEFINED;
    std::string commDomain;
};

参数列表

成员名称

类型

默认值

描述

transWeight

bool

true

权重是否需要转置,默认为true。

rank

int

0

当前卡所属通信编号。

rankSize

int

0

通信的卡的数量。

rankRoot

int

0

主通信编号。

hasResidual

bool

false

是否叠加残差。配置为“false”时不叠加残差,为“true”时叠加残差。默认不叠加残差。

backend

std::string

"hccl"

通信后端指示。支持“hccl”,“lccl”,“lcoc”。

Atlas 推理系列产品 仅支持“backend”为“hccl”。

hcclComm

HcclComm

nullptr

HCCL通信域指针。

默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。

commMode

CommMode

COMM_MULTI_PROCESS

通信模式,CommMode类型枚举值。hccl多线程只支持外部传入通信域方式。

rankTableFile

std::string

-

集群信息的配置文件路径,适用单机以及多机通信场景,当前仅支持hccl后端场景。

配置请参见《TensorFlow 1.15模型迁移指南》的“准备ranktable资源配置文件

type

ParallelType

LINEAR_ALL_REDUCE

权重并行类型。仅在“backend”为“lcoc”时生效。

  • UNDEFINED:默认值
  • LINEAR_ALL_REDUCE:linear + AllReduce。
  • LINEAR_REDUCE_SCATTER:linear + reduce_scatter。
  • ALL_GATHER_LINEAR:AllGather + linear。
  • PURE_LINEAR:linear。
  • MAX :枚举类型最大值。

keepIntermediate

bool

false

是否返回中间结果,仅在“ParallelType”使用“ALL_GATHER_LINEAR”时生效。

quantType

QuantType

QUANT_TYPE_UNDEFINED

量化类型。仅在“backend”为“lcoc”时生效。

  • QUANT_TYPE_UNDEFINED:默认值。
  • QUANT_TYPE_PER_TENSOR:对整个张量进行量化。
  • QUANT_TYPE_PER_CHANNEL:对张量中每个channel分别进行量化。
  • QUANT_TYPE_PER_GROUP:将张量按quantGroupSize划分后,分别进行量化。
  • QUANT_TYPE_MAX:枚举类型最大值。

quantGroupSize

int32_t

0

量化类型为“QUANT_TYPE_PER_GROUP”时有效。

outDataType

aclDataType

ACL_DT_UNDEFINED

  • 若为浮点linear,参数“outDataType”配置为ACL_DT_UNDEFINED,表示输出tensor的数据类型与输入tensor一致。
  • 若为量化linear,输出tensor的数据类型与输入tensor不一致,则参数“outDataType”配置为用户预期输出tensor的数据类型, 如ACL_FLOAT16/ACL_BF16。

commDomain

std::string

-

通信Device组用通信域名标识,多通信域时使用,当前仅支持“hccl”

输入

参数

维度

数据类型

格式

描述

input

[m, k]/[batch, m, k]

  • 浮点:float16/bf16
  • 量化:float16/bf16/int8

ND

矩阵乘运算的A矩阵。

k为32的整数倍。

weight

[k, n]

NZ:浮点额外支持[1, n/16, k, 16]

  • 浮点:float16/bf16
  • 量化:int8
  • 浮点:ND/NZ
  • 量化:ND

权重,矩阵乘的B矩阵。“backend”为“lcoc”时float16/bf16支持NZ;“backend”为“hccl”或“lccl”时仅float16支持NZ。

bias

  • “quantType”“per_tensor”时支持:[1]
  • “quantType”“per_channel”时支持:[1, n]/[n]
  • “quantType”“为”“per_group”时支持:[k/quantGroupSize, n]
  • W8A16量化场景:float16/bf16
  • W8A8量化场景:int32

ND

叠加的偏置矩阵。

n为16的整数倍。

deqScale

  • “quantType”“per_tensor”时支持:[1]
  • “quantType”“per_channel”时支持:[1, n]/[n]
  • “quantType”“为”“per_group”时支持:[k/quantGroupSize, n]
  • W8A16量化场景:float16/bf16
  • W8A8量化场景:int64

ND

反量化的scale。

量化时输入。

residual

[n]

float16/bf16

ND

残差,用于叠加到最后的输出结果上。

输出

参数

维度

数据类型

格式

描述

output

  • 当type为linear_all_reduce/pure_linear:[m, n]/[batch, m, n]
  • 当type为linear_reduce_scatter:[m/rankSize, n]/[batch/rankSize, m, n]
  • 当type为all_gather_linear:[m*rankSize, n]/[batch*rankSize, m, n];

float16/bf16

ND

输出tensor,维度数与x一致。

intermediateOutput

[m*rankSize, n]/[batch*rankSize, m, n]

float16/bf16

ND

输出tensor,维度数与x一致。

规格约束

  • 多用户使用时需要使用ATB_SHARE_MEMORY_NAME_SUFFIX环境变量(请参见Transformer加速库环境变量说明)进行共享内存的区分,以进行初始化信息同步。
  • 当使用加速库的通信算子异常退出时,需要清空残留数据,避免影响之后的使用,命令参考如下:
    rm -rf /dev/shm/sem.lccl*
    rm -rf /dev/shm/sem.hccl*
    ipcrm -a