StridedBatchMatmulOperation

功能

对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。

定义

struct StridedBatchMatmulParam {
    bool                 transposeA = false;
    bool                 transposeB = false;
    int32_t              batch      = 1;
    int32_t              headNum    = 1;
    std::vector<int32_t> m;
    std::vector<int32_t> n;
    std::vector<int32_t> k;
    std::vector<int32_t> lda;
    std::vector<int32_t> ldb;
    std::vector<int32_t> ldc;
    std::vector<int32_t> strideA;
    std::vector<int32_t> strideB;
    std::vector<int32_t> strideC;

    bool operator==(const StridedBatchMatmulParam &other) const
    {
        return this->transposeA == other.transposeA &&
            this->transposeB == other.transposeB &&
            this->batch == other.batch &&
            this->headNum == other.headNum &&
            this->m == other.m &&
            this->n == other.n &&
            this->k == other.k &&
            this->lda == other.lda &&
            this->ldb == other.ldb &&
            this->ldc == other.ldc &&
            this->strideA == other.strideA &&
            this->strideB == other.strideB &&
            this->strideC == other.strideC;
    }
};

成员

成员名称

描述

transposeA

是否转置A矩阵。

transposeB

是否转置B矩阵。

batch

batch个数batchSize。

headNum

多头注意力机制的head数。

m

A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

n

B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

k

C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

lda

表示矩阵A的列数。元素个数为batchSize。

ldb

表示矩阵B的列数。元素个数为batchSize。

ldc

表示矩阵C的列数。元素个数为batchSize。

strideA

矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideB

矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideC

矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。

输入

参数

维度

数据类型

格式

描述

A

1) bmm1、bmm1_grad2、bmm2_grad1:

[nTokens, hiddenSize]

2) bmm1_grad1、bmm2、bmm2_grad2:

[nSquareTokens]

float16

ND

B

1) bmm1、bmm1_grad2、bmm2_grad1:

[nTokens, hiddenSize]

2) bmm1_grad1、bmm2、bmm2_grad2:

[nSquareTokens]

float16

ND

输出

参数

维度

数据类型

格式

描述

output

[outdims]

float16

ND