昇腾社区首页
中文
注册

StridedBatchMatmulOperation

产品支持情况

产品

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

对矩阵进行分组,指定每组矩阵之间的步长,实现更加灵活的矩阵乘法操作。

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
struct StridedBatchMatmulParam {
    bool transposeA = false;
    bool transposeB = false;
    int32_t batch      = 1;
    int32_t headNum    = 1;
    std::vector<int32_t> m;
    std::vector<int32_t> n;
    std::vector<int32_t> k;
    std::vector<int32_t> lda;
    std::vector<int32_t> ldb;
    std::vector<int32_t> ldc;
    std::vector<int32_t> strideA;
    std::vector<int32_t> strideB;
    std::vector<int32_t> strideC;
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

transposeA

bool

false

是否转置A矩阵。

transposeB

bool

false

是否转置B矩阵。

batch

int32_t

1

batch个数batchSize。

headNum

int32_t

1

多头注意力机制的head数。

m

std::vector< int32_t >

-

A矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

n

std::vector< int32_t >

-

B矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

k

std::vector< int32_t >

-

C矩阵参与一次矩阵乘指令的shape大小。元素个数为batchSize。

lda

std::vector< int32_t >

-

表示矩阵A的列数。元素个数为batchSize。

ldb

std::vector< int32_t >

-

表示矩阵B的列数。元素个数为batchSize。

ldc

std::vector< int32_t >

-

表示矩阵C的列数。元素个数为batchSize。

strideA

std::vector< int32_t >

-

矩阵A在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideB

std::vector< int32_t >

-

矩阵B在内存中相邻两次计算之间的跨度。元素个数为batchSize。

strideC

std::vector< int32_t >

-

矩阵C在内存中相邻两次计算之间的跨度。元素个数为batchSize。

rsv[8]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

A

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

B

  • bmm1、bmm1_grad2、bmm2_grad1:[nTokens, hiddenSize]
  • bmm1_grad1、bmm2、bmm2_grad2:[nSquareTokens]

float16

ND

输入tensor。

输出

参数

维度

数据类型

格式

描述

output

[outdims]

float16

ND

输出tensor。

约束说明