GroupMatmulReduceScatterAlltoallvc输入输出
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
input  | 
[m',k]  | 
浮点:float16/bf16 量化:int8  | 
ND  | 
矩阵乘的A矩阵。  | 
weight  | 
[localExpertNums, k,n]  | 
浮点:float16/bf16 量化:int8  | 
ND  | 
矩阵乘的B矩阵,权重。  | 
bias  | 
[n]  | 
量化:int32  | 
ND  | 
叠加的偏置矩阵。hasBias为true时输入。  | 
deqScale  | 
[n*localExpertNums]  | 
量化:int64/float32  | 
ND  | 
反量化的scale。仅量化时需要该参数。 输出为float16时只支持int64,输出为bf16时同时支持int64和float32。  | 
residual  | 
[n]  | 
float16/bf16  | 
ND  | 
残差,用于叠加到最后的输出结果上。  | 
dequantPerTokenScale  | 
[m']  | 
量化:float32  | 
ND  | 
perToken反量化的scale。仅量化且量化为perToken时需要该参数。  | 
globalTokensPerExpertMatrix  | 
[ep, ep*localExpertNums]  | 
int32  | 
ND  | 
Alltoallvc通信矩阵。  | 
maxOutputSize  | 
[m]  | 
int32  | 
ND  | 
用于推测输出的Shape。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[m,n]  | 
float16/bf16  | 
ND  | 
输出tensor。  | 
父主题: 输入输出列表