SelfAttentionOperation
功能
完成Attention计算功能。

约束
- tokenOffset ≥ seqLen。
- 对于输入cacheV,当参数isEncoder为true时,只支持数据类型为int32,当参数isEncoder为false时,只支持数据类型为float16。
- attentionMask的维度:
- 支持[batch, maxSeqLen, maxSeqLen],此时需要保证batch的维度和其他输入batch维度的一致性。
- 支持[maxSeqLen, maxSeqLen],此时所有batch的mask相同。
- 支持[batch, 1, maxSeqLen],此时所有batch的mask不同,mask为向量。
定义
struct SelfAttentionParam { int32_t headDim = 0; int32_t headNum = 0; float qScale = 1; float qkScale = 1; bool isFusion = true; bool withCache = true; bool batchRunStatusEnable = false; int32_t kvHeadNum = 0; bool isEncoder = false; enum CoderType : int { UNDEFINED = 0, ENCODER, // encoder for flashAttention DECODER // decoder for flashAttention }; CoderType coderType = UNDEFINED bool isSupportAlibi = false; };
成员
成员名称 |
描述 |
---|---|
headDim |
头维度。 Atlas 推理系列产品(配置Ascend 310P AI处理器)仅支持配置为128。 |
headNum |
多头数量。 headNum需大于或等于0。 |
qScale |
q缩放系数。 |
qkScale |
qk缩放系数。 |
isFusion |
是否使用融合算子。 |
withCache |
仅当isFusion为false时有用,intensor有pastKVCache。 |
batchRunStatusEnable |
是否动态batch。 |
kvHeadNum |
该值需要用户根据使用的模型实际情况传入。
|
isEncoder |
编解码策略。
|
coderType |
使用flashAttention时的编解码策略。
|
isSupportAlibi |
输入的attentionMask是否融合了alibi。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
q |
[nTokens, qHiddenSize] |
float16 |
ND |
- |
k |
[nTokens, hiddenSize] |
float16 |
ND |
当前的key。 |
v |
[nTokens, hiddenSize] |
float16 |
ND |
当前的value。 |
cacheK |
[layer, batch, maxSeqLen, hiddenSize] |
float16 |
ND/NZ |
之前所有的key, 待与当前key拼接。 |
cacheV |
[layer, batch, maxSeqLen, hiddenSize] |
float16/int32 |
ND/NZ |
之前所有value, 待与当前value拼接。 |
attentionMask |
[maxSeqLen, maxSeqLen] [batch, maxSeqLen, maxSeqLen] [batch, 1, maxSeqLen] [batch, headNum, maxSeqLen, maxSeqLen] |
float16 |
ND/NZ |
支持场景:
|
tokenOffset |
[batch] |
uint32/int32 |
ND |
计算完成后的token偏移。 |
seqLen |
[batch] |
uint32/int32 |
ND |
|
layerId |
[1] |
uint32/int32 |
ND |
当前处于第几个layer。 |
batchStatus |
[batch] |
uint32/int32 |
ND |
batchRunStatusEnable = true,即开启动态batch时,控制具体需要运算的batch。 |
输出
参数 |
维度 |
数据类型 |
格式 |
---|---|---|---|
output |
[ntokens, qHiddenSize] |
float16 |
ND |