aclnnTransposeBatchMatMul
产品支持情况
功能说明
接口功能:完成张量x1与张量x2的矩阵乘计算。仅支持三维的Tensor传入。Tensor支持转置,转置序列根据传入的数列进行变更。permX1代表张量x1的转置序列,permX2代表张量x2的转置序列,序列值为0的是batch维度,其余两个维度做矩阵乘法。
示例:
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor等于1时,计算输出out的shape是(M, B, N)。
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale不为None,batchSplitFactor等于1时,计算输出out的shape是(M, 1, B * N)。
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor大于1时,计算输出out的shape是(batchSplitFactor, M, B * N / batchSplitFactor)。
函数原型
每个算子分为,必须先调用“aclnnTransposeBatchMatMulGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnTransposeBatchMatMul”接口执行计算。
[object Object]
[object Object]
aclnnTransposeBatchMatMulGetWorkSpaceSize
aclnnTransposeBatchMatMul
约束说明
确定性说明:
- [object Object]Atlas 训练系列产品[object Object]、[object Object]Atlas 推理系列产品[object Object]:aclnnTransposeBatchMatMul默认确定性实现。
[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:
- B的取值范围为[1, 65536),N的取值范围为[1, 65536)。
- 当x1的输入shape为(B, M, K)时,K <= 65535;当x1的输入shape为(M, B, K)时,B * K <= 65535。
- permX2仅支持输入[0, 1, 2]。
- 当scale不为空时,B与N的乘积小于65536, 且仅支持输入为FLOAT16和输出为INT8的类型推导。
调用示例
[object Object]