LinearParallelOperation

功能

多卡并行Linear计算处理。

该算子涉及多卡相关操作,可根据实际需求配置HCCL相关环境变量,具体请参见《环境变量参考》执行相关 > 集合通信与分布式训练 > 集合通信相关配置章节。配置TLS等相关操作可查看对应设备的。

约束

当使用加速库的通信算子异常退出时,需要清空残留数据,避免影响之后的使用,命令参考如下:

rm -rf /dev/shm/sem.lccl*
rm -rf /dev/shm/sem.hccl*
ipcrm -a

定义

struct LinearParallelParam {
    enum ParallelType : int {
        UNDEFINED = -1,
        LINEAR_ALL_REDUCE = 0,
        LINEAR_REDUCE_SCATTER = 1,
        ALL_GATHER_LINEAR = 2,
        PURE_LINEAR = 3,
        MAX = 4,
    };
    enum QuantType : int {
        QUANT_TYPE_UNDEFINED = -1,
        QUANT_TYPE_PER_TENSOR = 0,
        QUANT_TYPE_PER_CHANNEL = 1,
        QUANT_TYPE_PER_GROUP = 2,
        QUANT_TYPE_MAX = 3,
    };
    bool transWeight = true;                    // 权重是否需要转置
    int rank = 0;                               // 每个进程的编号
    int rankSize = 0;                           // 总的进程数
    int rankRoot = 0;                           // 主进程编号
    bool hasResidual = false;                                               
    std::string backend = "hccl";               // 通信后端指示
    HcclComm hcclComm = nullptr;                // only effect when hcclComm is not null
    CommMode commMode = COMM_MULTI_PROCESS;
    std::string rankTableFile;
    ParallelType type = LINEAR_ALL_REDUCE;      // 当前仅在backend为lcoc时支持使用
    bool keepIntermediate = false;              // 是否返回中间结果
    QuantType quantType = QUANT_TYPE_UNDEFINED; // 量化类型,当前仅当type为LINEAR_ALL_REDUCE时支持
    int32_t quantGroupSize = 0;                 // 量化类型为QUANT_TYPE_PER_GROUP时生效
};

成员

成员名称

描述

transWeight

权重是否需要转置,默认为true。

rank

每个进程的编号。

rankSize

总的进程数。

rankRoot

主进程编号。

hasResidual

是否叠加偏置。默认为false,不叠加偏置。

backend

通信后端指示。支持“hccl”,“lccl”,“lcoc”。

hcclComm

HCCL通信域接口获取的地址指针,仅当“hcclComm”不为“nullptr”时可用。

commMode

通信模式,CommMode类型枚举值。Atlas 推理系列产品(配置Ascend 310P AI处理器)不支持lccl通信模式。

rankTableFile

多机通信时使用的计算卡对应的IP信息,仅支持“hccl”。

type

权重并行类型。

keepIntermediate

是否返回中间结果。

quantType

量化类型。

quantGroupSize

量化类型为QUANT_TYPE_PER_GROUP时生效。

输入

参数

维度

数据类型

格式

x

当y为ND格式时,支持以下维度输入

  1. x: [batch, m, k]

    weight: [k, n]

  2. x: [m,k]

    weight: [k, n]

当y为NZ格式时,支持以下维度输入:

  1. x: [m, k]

    weight: [1, n/32, k, 32]

  2. x: [batch, m, k]

    weight: [1, n/32, k, 32]

float16

ND

weight

float16

ND/NZ

bias

[1, n] / [n]

float16

ND

输出

参数

维度

数据类型

格式

output

根据以上输入维度,ND/NZ输出维度均为:

  1. output: [batch, m, n]
  2. output: [m ,n]

float16

ND