HcgemmBatched
产品支持情况
硬件型号 |
支持情况 |
---|---|
|
不支持 |
|
不支持 |
|
不支持 |
|
支持 |
|
支持 |
功能描述
接口功能
- asdBlasMakeHCgemmBatchedPlan:初始化该句柄对应的算子配置。
- asdBlasHCgemmBatched:用于计算两批复数矩阵的乘积。
计算公式
asdBlasHCgemmBatched的计算公式:
- 示例:
[ [ 1+i, 1+2i ], [ 1+3i, 1+4i ] ]
输入“inTensorB[i]”为:
[ [ 2+i, 2+2i ], [ 2+3i, 2+4i ] ]
输入“inTensorC[i]”为:
[ [ 3+i, 3+2i ], [ 3+3i, 3+4i ] ]
输入“transa”为: N,输入“transb”为:T,
输入“m”为:2,输入“n”为: 2,输入“k”为:2,输入“alpha”为:1+i,“beta”为:2+2i,
输入“lda”为: 2,输入“ldb”为:2,输入“ldc”为:2,
输入“batchCount”为:1
调用“Cgemm”算子后,
输出“C”为:
[ [ -15+19i, -27+19i ], [ -37+21i, -57+13i ] ]
函数原型
- AspbStatus asdBlasMakeHCgemmBatchedPlan(asdBlasHandle handle);
- AspbStatus asdBlasHCgemmBatched(asdBlasHandle handle, asdBlasOperation_t transa, asdBlasOperation_t transb, const int64_t m,const int64_t n, const int64_t k, const std::complex<op::fp16_t> &alpha, aclTensor *A,const int64_t lda, aclTensor *B, const int64_t ldb, const std::complex<op::fp16_t> &beta,aclTensor *C, const int64_t ldc, const int64_t batchCount);
参数说明
- asdBlasMakeHCgemmBatchedPlan
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
HCgemmBatched算子的句柄。
- asdBlasHCgemmBatched
参数名称
Input/Output
类型
描述
handle
Input
asdBlasHandle
HCgemmBatched算子的句柄。
transa
Input
asdBlasOperation_t
指定矩阵A是否需要转置,取值必须为ASDBLAS_OP_N。
transb
Input
asdBlasOperation_t
指定矩阵B是否需要转置,取值必须为ASDBLAS_OP_N。
m
Input
const int64_t
矩阵C的行数,取值范围为:{1-32}。
n
Input
const int64_t
矩阵C的列数,取值范围为:{1-32}。
k
Input
const int64_t
矩阵A和B的公共维度,取值范围为:{1-32}。
alpha
Input
const std::complex<op::fp16_t> &
公式中的alpha,复数标量,用于乘以矩阵乘法的结果,取值必须为1+0j。
A
Input
aclTensor *
公式中的A,列主序,Device侧的Tensor,数据类型仅支持COMPLEX32,数据格式支持ND,shape为[batchCount, m, k]。
lda
Input
const int64_t
A左右相邻元素间的内存地址偏移量,取值和k相等。
B
Input
aclTensor *
公式中的B,Device侧的Tensor,数据类型仅支持COMPLEX32,数据格式支持ND,shape为[batchCount, k, n]。
ldb
Input
const int64_t
B左右相邻元素间的内存地址偏移量,取值和n相等。
beta
Input
const std::complex<op::fp16_t> &
公式中的beta,复数标量,用于乘以矩阵C。取值必须为 0+0j。
C
Input
aclTensor *
公式中的C,Device侧的Tensor,数据类型仅支持COMPLEX32,数据格式支持ND,shape为[batchCount, m, n]。
ldc
Input
const int64_t
C左右相邻元素间的内存地址偏移量,取值和n相等。
batchCount
Input
const int64_t
批次数量。取值范围为{12 - 26208}。
约束说明
- asdBlasMakeHCgemmBatchedPlan:无。
- asdBlasHCgemmBatched
- 支持的CANN版本为CANN8.0及以上。
- 输入支持的数据类型为COMPLEX32。
- 输出支持的数据类型为COMPLEX32。
- 算子实际计算时,只支持3维ND运算。
- 算子输入数据为行主序,输入shape为[batchCount, m,k,]、[batchCount, k,n]、[batchCount, m,n],输出shape为[batchCount, m,n];
调用示例
算子的调用示例参见HcgemmBatched。