StridedBatchMatmulOperation
产品支持情况
产品  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能说明
对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。
定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16  | struct StridedBatchMatmulParam { bool transposeA = false; bool transposeB = false; int32_t batch = 1; int32_t headNum = 1; std::vector<int32_t> m; std::vector<int32_t> n; std::vector<int32_t> k; std::vector<int32_t> lda; std::vector<int32_t> ldb; std::vector<int32_t> ldc; std::vector<int32_t> strideA; std::vector<int32_t> strideB; std::vector<int32_t> strideC; uint8_t rsv[8] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
描述  | 
|---|---|---|---|
transposeA  | 
bool  | 
false  | 
是否转置A矩阵。  | 
transposeB  | 
bool  | 
false  | 
是否转置B矩阵。  | 
batch  | 
int32_t  | 
1  | 
batch个数batchSize。  | 
headNum  | 
int32_t  | 
1  | 
多头注意力机制的head数。  | 
m  | 
std::vector< int32_t >  | 
-  | 
A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。  | 
n  | 
std::vector< int32_t >  | 
-  | 
B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。  | 
k  | 
std::vector< int32_t >  | 
-  | 
C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。  | 
lda  | 
std::vector< int32_t >  | 
-  | 
表示矩阵A的列数。元素个数为batchSize。  | 
ldb  | 
std::vector< int32_t >  | 
-  | 
表示矩阵B的列数。元素个数为batchSize。  | 
ldc  | 
std::vector< int32_t >  | 
-  | 
表示矩阵C的列数。元素个数为batchSize。  | 
strideA  | 
std::vector< int32_t >  | 
-  | 
矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。  | 
strideB  | 
std::vector< int32_t >  | 
-  | 
矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。  | 
strideC  | 
std::vector< int32_t >  | 
-  | 
矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。  | 
rsv[8]  | 
uint8_t  | 
{0}  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
A  | 
  | 
float16  | 
ND  | 
输入tensor。  | 
B  | 
  | 
float16  | 
ND  | 
输入tensor。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[outdims]  | 
float16  | 
ND  | 
输出tensor。  | 
约束说明
无