本OP对标torch库中的matmul算子,功能略有不同。部分对标的功能参考torch.matmul文档中的示例:
# batched matrix x batched matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size() # torch.Size([10, 3, 5]) # batched matrix x broadcasted matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4, 5) torch.matmul(tensor1, tensor2).size() # torch.Size([10, 3, 5])