Strmm
功能描述
计算公式
接口原型
- AspbStatus asdBlasMakeStrmmPlan(asdBlasHandle handle)
- AspbStatus asdBlasStrmm(asdBlasHandle handle, asdBlasSideMode_t side, asdBlasFillMode_t uplo, asdBlasOperation_t trans,asdBlasDiagType_t diag, const int64_t m, const int64_t n, const float *alpha, aclTensor *A,const int64_t lda, aclTensor *B, const int64_t ldb, aclTensor *C, const int64_t ldc)
参数列表
- asdBlasMakeStrmmPlan
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
strmm算子的句柄。
- asdBlasStrmm
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
strmm算子的句柄。
side
Input
asdBlasSideMode_t
指定是矩阵A是乘法左侧还是右侧。
ASDBLAS_SIDE_LEFT // 左侧 ASDBLAS_SIDE_RIGHT // 右侧
uplo
Input
asdBlasFillMode_t
指定矩阵A的存储方式。
ASDBLAS_FILL_MODE_LOWER // 下三角 ASDBLAS_FILL_MODE_UPPER // 上三角
trans
Input
asdBlasOperation_t
指定是否对矩阵A进行转置。
ASDBLAS_OP_N // 不转置 ASDBLAS_OP_T // 转置
diag
Input
asdBlasDiagType_t
指定是否假定矩阵A的对角线元素为1。
ASDBLAS_DIAG_NON_UNIT // 不假定为1 ASDBLAS_DIAG_UNIT // 假定为1
m
Input
const int64_t
矩阵B和C的行数。
n
Input
const int64_t
矩阵B和C的列数。
alpha
Input
const float *
公式中的alpha,用于计算矩阵乘法的系数。
A
Input
aclTensor *
公式中的A,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,当A为乘法左矩阵时,shape为[m,m],当A为乘法右矩阵时,shape为[n,n]。
lda
Input
const int64_t
表示张量A中元素的间隔(当前约束为m/n,当side=ASDBLAS_SIDE_LEFT时为m,ASDBLAS_SIDE_RIGHT时为n)。
B
Input
aclTensor *
公式中的B,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,shape为[m,n]。
ldb
Input
const int64_t
表示张量B中元素的间隔(当前约束为m)。
C
Output
aclTensor *
公式中的C,Device侧的Tensor,数据类型仅支持FLOAT32,数据格式支持ND,shape为[m,n]。
ldc
Input
const int64_t
表示张量C中元素的间隔(当前约束为m)。
规格约束
- asdBlasMakeStrmmPlan:无。
- asdBlasStrmm
- 输入的元素个数m,n当前覆盖支持[1,8193];
- 当side = ASDBLAS_SIDE_LEFT时,算子输入shape为[m,m]、[m,n],输出shape为[m,n];
- 当side = ASDBLAS_SIDE_RIGHT时,算子输入shape为[n,n]、[m,n],输出shape为[m,n];
- 算子实际计算时,不支持ND高维度运算(不支持维度≥3的运算)。