对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。
struct StridedBatchMatmulParam { bool transposeA = false; bool transposeB = false; int32_t batch = 1; int32_t headNum = 1; std::vector<int32_t> m; std::vector<int32_t> n; std::vector<int32_t> k; std::vector<int32_t> lda; std::vector<int32_t> ldb; std::vector<int32_t> ldc; std::vector<int32_t> strideA; std::vector<int32_t> strideB; std::vector<int32_t> strideC; bool operator==(const StridedBatchMatmulParam &other) const { return this->transposeA == other.transposeA && this->transposeB == other.transposeB && this->batch == other.batch && this->headNum == other.headNum && this->m == other.m && this->n == other.n && this->k == other.k && this->lda == other.lda && this->ldb == other.ldb && this->ldc == other.ldc && this->strideA == other.strideA && this->strideB == other.strideB && this->strideC == other.strideC; } };
成员名称 |
描述 |
---|---|
transposeA |
是否转置A矩阵。 |
transposeB |
是否转置B矩阵。 |
batch |
batch个数batchSize。 |
headNum |
多头注意力机制的head数。 |
m |
A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
n |
B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
k |
C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
lda |
表示矩阵A的列数。元素个数为batchSize。 |
ldb |
表示矩阵B的列数。元素个数为batchSize。 |
ldc |
表示矩阵C的列数。元素个数为batchSize。 |
strideA |
矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
strideB |
矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
strideC |
矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
A |
1) bmm1、bmm1_grad2、bmm2_grad1: [nTokens, hiddenSize] 2) bmm1_grad1、bmm2、bmm2_grad2: [nSquareTokens] |
float16 |
ND |
|
B |
1) bmm1、bmm1_grad2、bmm2_grad1: [nTokens, hiddenSize] 2) bmm1_grad1、bmm2、bmm2_grad2: [nSquareTokens] |
float16 |
ND |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[outdims] |
float16 |
ND |