昇腾社区首页
中文
注册

StridedBatchMatmulOperation

功能

对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。

定义

struct StridedBatchMatmulParam {
    bool transposeA = false;
    bool transposeB = false;
    int32_t batch      = 1;
    int32_t headNum    = 1;
    std::vector<int32_t> m;
    std::vector<int32_t> n;
    std::vector<int32_t> k;
    std::vector<int32_t> lda;
    std::vector<int32_t> ldb;
    std::vector<int32_t> ldc;
    std::vector<int32_t> strideA;
    std::vector<int32_t> strideB;
    std::vector<int32_t> strideC;
};

参数列表

成员名称

类型

默认值

描述

transposeA

bool

false

是否转置A矩阵。

transposeB

bool

false

是否转置B矩阵。

batch

int32_t

1

batch个数batchSize。

headNum

int32_t

1

多头注意力机制的head数。

m

std::vector< int32_t >

-

A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

n

std::vector< int32_t >

-

B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

k

std::vector< int32_t >

-

C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

lda

std::vector< int32_t >

-

表示矩阵A的列数。元素个数为batchSize。

ldb

std::vector< int32_t >

-

表示矩阵B的列数。元素个数为batchSize。

ldc

std::vector< int32_t >

-

表示矩阵C的列数。元素个数为batchSize。

strideA

std::vector< int32_t >

-

矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideB

std::vector< int32_t >

-

矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideC

std::vector< int32_t >

-

矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。

输入

参数

维度

数据类型

格式

描述

A

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

B

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

输出

参数

维度

数据类型

格式

描述

output

[outdims]

float16

ND

输出tensor。

规格约束

当前只支持Atlas 800I A2 推理产品/Atlas A2 训练系列产品