参数列表

成员名称

类型

默认值

描述

transWeight

bool

true

权重是否需要转置,默认为true。

rank

int

0

当前卡所属通信编号。

rankSize

int

0

通信的卡的数量。

rankRoot

int

0

主通信编号。

hasResidual

bool

false

是否叠加残差。配置为“false”时不叠加残差,为“true”时叠加残差。默认不叠加残差。

backend

std::string

"hccl"

通信后端指示。支持“hccl”,“lccl”,“lcoc”。

Atlas 推理系列产品仅支持“backend”为“hccl”。

hcclComm

HcclComm

nullptr

HCCL通信域指针。

默认为空,加速库为用户创建;若用户想要自己管理通信域,则需要传入该通信域指针,加速库使用传入的通信域指针来执行通信算子。

commMode

CommMode

COMM_MULTI_PROCESS

通信模式,CommMode类型枚举值。hccl多线程只支持外部传入通信域方式。

rankTableFile

std::string

-

集群信息的配置文件路径,适用单机以及多机通信场景,当前仅支持hccl后端场景。

配置请参见《TensorFlow 1.15模型迁移指南》的“模型训练>执行分布式训练>准备ranktable资源配置文件”章节

type

ParallelType

LINEAR_ALL_REDUCE

权重并行类型。仅在“backend”为“lcoc”时生效。

  • UNDEFINED:未定义。
  • LINEAR_ALL_REDUCE:linear + AllReduce,默认值。
  • LINEAR_REDUCE_SCATTER:linear + reduce_scatter。
  • ALL_GATHER_LINEAR:AllGather + linear。
  • PURE_LINEAR:linear。
  • ALL_GATHER_LINEAR_REDUCE_SCATTER:AllGather + linear + reduce_scatter。
  • MAX :枚举类型最大值。

keepIntermediate

bool

false

是否返回中间结果,仅在“ParallelType”使用“ALL_GATHER_LINEAR”时生效。

quantType

QuantType

QUANT_TYPE_UNQUANT

量化类型。仅在“backend”为“lcoc”时生效。

  • QUANT_TYPE_UNDEFINED:默认值。
  • QUANT_TYPE_UNQUANT:默认值。
  • QUANT_TYPE_PER_TENSOR:对整个张量进行量化。
  • QUANT_TYPE_PER_CHANNEL:对张量中每个channel分别进行量化。
  • QUANT_TYPE_PER_GROUP:将张量按quantGroupSize划分后,分别进行量化。
  • QUANT_TYPE_MAX:枚举类型最大值。

quantGroupSize

int32_t

0

量化类型为“QUANT_TYPE_PER_GROUP”时有效。

outDataType

aclDataType

ACL_DT_UNDEFINED

  • 若为浮点linear,参数“outDataType”配置为ACL_DT_UNDEFINED,表示输出tensor的数据类型与输入tensor一致。
  • 若为量化linear,输出tensor的数据类型与输入tensor不一致,则参数“outDataType”配置为用户预期输出tensor的数据类型, 如ACL_FLOAT16/ACL_BF16。

commDomain

std::string

-

通信device组用通信域名标识,多通信域时使用。当backend为lccl时,commMode为多进程时,commDomain需要设置0-63。

twoDimTPInfo

TwoDimTPInfo

-

AllGather_Matmul_ReduceScatter算子参数。

rsv[56]

uint8_t

{0}

预留参数。

表1 TwoDimTPInfo成员

成员名称

类型

默认值

描述

agDim

uint16_t

0

表示allGather轴卡数,规定x轴方向是非连续卡号。

rsDim

uint16_t

0

表示reduceScatter轴卡数,规定y轴方向是连续卡号。

innerDimIsAg

uint8_t

1

allGather通信的rank是否连续,1表示true,0表示false。

rsv[3]

uint8_t

{0}

填充满8字节。