昇腾社区首页
中文
注册

GroupMatmulReduceScatterAlltoallvc输入输出

输入

参数

维度

数据类型

格式

描述

input

[m',k]

浮点:float16/bf16

量化:int8

ND

矩阵乘的A矩阵。

weight

[localExpertNums, k,n]

浮点:float16/bf16

量化:int8

ND

矩阵乘的B矩阵,权重。

bias

[n]

量化:int32

ND

叠加的偏置矩阵。hasBias为true时输入。

deqScale

[n*localExpertNums]

量化:int64/float32

ND

反量化的scale。仅量化时需要该参数。

输出为float16时只支持int64,输出为bf16时同时支持int64和float32。

residual

[n]

float16/bf16

ND

残差,用于叠加到最后的输出结果上。

dequantPerTokenScale

[m']

量化:float32

ND

perToken反量化的scale。仅量化且量化为perToken时需要该参数。

globalTokensPerExpertMatrix

[ep, ep*localExpertNums]

int32

ND

Alltoallvc通信矩阵。

maxOutputSize

[m]

int32

ND

用于推测输出的Shape。

输出

参数

维度

数据类型

格式

描述

output

[m,n]

float16/bf16

ND

输出tensor。