Strmm
产品支持情况
硬件型号 |
支持情况 |
---|---|
|
不支持 |
|
不支持 |
|
不支持 |
|
支持 |
|
支持 |
功能描述
接口功能
- asdBlasMakeStrmmPlan:初始化该句柄对应的Strmm算子配置。
- asdBlasStrmm:单精度,其功能是将一个三角矩阵A乘一个矩阵B,得到一个新的矩阵C。
计算公式
函数原型
- AspbStatus asdBlasMakeStrmmPlan(asdBlasHandle handle)
- AspbStatus asdBlasStrmm(asdBlasHandle handle, asdBlasSideMode_t side, asdBlasFillMode_t uplo, asdBlasOperation_t trans,asdBlasDiagType_t diag, const int64_t m, const int64_t n, const float *alpha, aclTensor *A,const int64_t lda, aclTensor *B, const int64_t ldb, aclTensor *C, const int64_t ldc)
参数说明
- asdBlasMakeStrmmPlan
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
Strmm算子的句柄。
- asdBlasStrmm
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
Strmm算子的句柄。
side
Input
asdBlasSideMode_t
指定矩阵A是乘法左侧还是右侧。
ASDBLAS_SIDE_LEFT // 左侧 ASDBLAS_SIDE_RIGHT // 右侧
uplo
Input
asdBlasFillMode_t
指定矩阵A的存储方式。
ASDBLAS_FILL_MODE_LOWER // 下三角 ASDBLAS_FILL_MODE_UPPER // 上三角
trans
Input
asdBlasOperation_t
指定是否对矩阵A进行转置。
ASDBLAS_OP_N // 不转置 ASDBLAS_OP_T // 转置
diag
Input
asdBlasDiagType_t
指定是否假定矩阵A的对角线元素为1。
ASDBLAS_DIAG_NON_UNIT // 不假定为1 ASDBLAS_DIAG_UNIT // 假定为1
m
Input
const int64_t
矩阵B和C的行数。
n
Input
const int64_t
矩阵B和C的列数。
alpha
Input
const float *
公式中的alpha,用于计算矩阵乘法的系数。
A
Input
aclTensor *
公式中的A,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,当A为乘法左矩阵时,shape为[m,m],当A为乘法右矩阵时,shape为[n,n]。
lda
Input
const int64_t
表示张量A中元素的间隔(当前约束为m/n,当side=ASDBLAS_SIDE_LEFT时为m,ASDBLAS_SIDE_RIGHT时为n)。
B
Input
aclTensor *
公式中的B,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,shape为[m,n]。
ldb
Input
const int64_t
表示张量B中元素的间隔(当前约束为m)。
C
Output
aclTensor *
公式中的C,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,shape为[m,n]。
ldc
Input
const int64_t
表示张量C中元素的间隔(当前约束为m)。
约束说明
- asdBlasMakeStrmmPlan:无。
- asdBlasStrmm
- 输入的元素个数m,n当前覆盖支持[1,8193];
- 当side = ASDBLAS_SIDE_LEFT时,算子输入shape为[m,m]、[m,n],输出shape为[m,n];
- 当side = ASDBLAS_SIDE_RIGHT时,算子输入shape为[n,n]、[m,n],输出shape为[m,n];
- 算子实际计算时,不支持ND高维度运算(不支持维度≥3的运算)。
调用示例
算子的调用示例参见Strmm。