StridedBatchMatmulOperation
功能
对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。
定义
struct StridedBatchMatmulParam { bool transposeA = false; bool transposeB = false; int32_t batch = 1; int32_t headNum = 1; std::vector<int32_t> m; std::vector<int32_t> n; std::vector<int32_t> k; std::vector<int32_t> lda; std::vector<int32_t> ldb; std::vector<int32_t> ldc; std::vector<int32_t> strideA; std::vector<int32_t> strideB; std::vector<int32_t> strideC; };
参数列表
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
transposeA |
bool |
false |
是否转置A矩阵。 |
transposeB |
bool |
false |
是否转置B矩阵。 |
batch |
int32_t |
1 |
batch个数batchSize。 |
headNum |
int32_t |
1 |
多头注意力机制的head数。 |
m |
std::vector< int32_t > |
- |
A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
n |
std::vector< int32_t > |
- |
B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
k |
std::vector< int32_t > |
- |
C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。 |
lda |
std::vector< int32_t > |
- |
表示矩阵A的列数。元素个数为batchSize。 |
ldb |
std::vector< int32_t > |
- |
表示矩阵B的列数。元素个数为batchSize。 |
ldc |
std::vector< int32_t > |
- |
表示矩阵C的列数。元素个数为batchSize。 |
strideA |
std::vector< int32_t > |
- |
矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
strideB |
std::vector< int32_t > |
- |
矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
strideC |
std::vector< int32_t > |
- |
矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
A |
|
float16 |
ND |
输入tensor。 |
B |
|
float16 |
ND |
输入tensor。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[outdims] |
float16 |
ND |
输出tensor。 |
规格约束
当前只支持