GroupMatmulReduceScatterAlltoallvc输入输出
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
input |
[m',k] |
浮点:float16/bf16 量化:int8 |
ND |
矩阵乘的A矩阵。 |
weight |
[localExpertNums, k,n] |
浮点:float16/bf16 量化:int8 |
ND |
矩阵乘的B矩阵,权重。 |
bias |
[n] |
量化:int32 |
ND |
叠加的偏置矩阵。hasBias为true时输入。 |
deqScale |
[n*localExpertNums] |
量化:int64/float32 |
ND |
反量化的scale。仅量化时需要该参数。 输出为float16时只支持int64,输出为bf16时同时支持int64和float32。 |
residual |
[n] |
float16/bf16 |
ND |
残差,用于叠加到最后的输出结果上。 |
dequantPerTokenScale |
[m'] |
量化:float32 |
ND |
perToken反量化的scale。仅量化且量化为perToken时需要该参数。 |
globalTokensPerExpertMatrix |
[ep, ep*localExpertNums] |
int32 |
ND |
Alltoallvc通信矩阵。 |
maxOutputSize |
[m] |
int32 |
ND |
用于推测输出的Shape。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[m,n] |
float16/bf16 |
ND |
输出tensor。 |
父主题: 输入输出列表