昇腾社区首页
中文
注册

StridedBatchMatmulOperation

功能

对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。当前只支持Atlas 800I A2推理产品

定义

struct StridedBatchMatmulParam {
    bool transposeA = false;
    bool transposeB = false;
    int32_t batch      = 1;
    int32_t headNum    = 1;
    std::vector<int32_t> m;
    std::vector<int32_t> n;
    std::vector<int32_t> k;
    std::vector<int32_t> lda;
    std::vector<int32_t> ldb;
    std::vector<int32_t> ldc;
    std::vector<int32_t> strideA;
    std::vector<int32_t> strideB;
    std::vector<int32_t> strideC;
};

成员

成员名称

描述

transposeA

是否转置A矩阵。

transposeB

是否转置B矩阵。

batch

batch个数batchSize。

headNum

多头注意力机制的head数。

m

A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

n

B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

k

C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

lda

表示矩阵A的列数。元素个数为batchSize。

ldb

表示矩阵B的列数。元素个数为batchSize。

ldc

表示矩阵C的列数。元素个数为batchSize。

strideA

矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideB

矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideC

矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。

输入

参数

维度

数据类型

格式

描述

A

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

B

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

输出

参数

维度

数据类型

格式

描述

output

[outdims]

float16

ND

输出tensor。