接口功能:完成张量x1与张量x2的矩阵乘计算,仅支持x1为ND格式,x2为NZ格式,只支持x1为3维,x2为5维。Tensor支持转置,转置序列根据传入的数列进行变更。permX1代表张量x1的转置序列,permX2代表张量x2的转置序列,序列值为0的是batch维度,其余两个维度做矩阵乘法。
示例:(x2的NZ转换为ND对应的viewshape视角)
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor等于1时,计算输出out的shape是(M, B, N)。
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale不为None,batchSplitFactor等于1时,计算输出out的shape是(M, 1, B * N)。
- x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor大于1时,计算输出out的shape是(batchSplitFactor, M, B * N / batchSplitFactor)。
每个算子分为,必须先调用“aclnnTransposeBatchMatMulWeightNzGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnTransposeBatchMatMulWeightNz”接口执行计算。
[object Object]
[object Object]
确定性说明:
- aclnnTransposeBatchMatMulWeightNz默认确定性实现。
[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:
- B的取值范围为[1, 65536),N的取值范围为[1, 65536)。
- 当x1的输入shape为(B, M, K)时,K <= 65535;当x1的输入shape为(M, B, K)时,B * K <= 65535。
- x2的NZ格式对应的ND格式中,第二维和第三维都必须被16整除。
- x2的StorageFormat必须为NZ格式。
- x1和x2的dtype必须相同。
- permX2仅支持输入[0, 1, 2]。
- 当scale不为空时,B与N的乘积小于65536, 且仅支持输入为FLOAT16和输出为INT8的类型推导。
Atlas 350 加速卡:
- permX2支持输入[0, 1, 2]、[0, 2, 1]。
- 当scale不为空时,batchSplitFactor只能等于1,且仅支持输入为FLOAT16和输出为INT8的类型推导。
self只支持3维, mat2只支持昇腾私有格式,调用此接口之前,必须完成mat2从ND到昇腾私有格式的转换。
不支持mat2最后两根轴其中一根轴为1,即k=1或者n=1。