昇腾社区首页
中文
注册

GenAttentionMaskOperation

产品支持情况

产品

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

将attentionMask根据每个batch的实际seqlen进行转化,得到结果为一维tensor。

定义

1
2
3
4
5
struct GenAttentionMaskParam {
    int32_t headNum = 1;
    atb::SVector<int32_t> seqLen;
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

headNum

int32_t

1

多头注意力机制的head数。

seqLen

atb::SVector<int32_t>

-

存储unpad场景下每个batch实际seqlen的值。元素个数为batchSize,最大为32。

rsv[8]

uint8_t

{0}

预留参数。

输入输出

参数

维度

数据类型

格式

描述

x

[batchSize, 1, maxSeqLen, maxSeqLen]

float16

ND

输入,用于attentionmask计算的随机矩阵。

output

[nSquareTokens]

float16

ND

输出,attentionmask计算的结果矩阵。

其中nSquareTokens的计算公式为:

约束说明

qSeqLen数组长度不超过32,且要求各元素大于0。