BNSD维度输入
功能说明
一般的,传入SelfAttention算子的q,k,v的维度为[batch, seqLen, headNum, head_dim],即[b, s, n, d],或者是它合轴后的变种。在某些场景下,传入[b, n, s, d]性能更好。
开启方式
参数“inputLayout”置为TYPE_BNSD。
输入参数如下:
- “calctype”不为PA_ENCODER时,在
Atlas A2 训练系列产品 /Atlas 800I A2 推理产品 和Atlas A3 推理系列产品 /Atlas A3 训练系列产品 上:输入tensor
维度
数据类型
格式
描述
query
[batch, headNum, seqLen, headSize]
float16/bf16
ND
query矩阵。
cacheK
[layer, batch, headNum, seqLen, headSize]
float16/bf16
ND
- NPU:存储之前所有的k,本次执行时将key刷新到cacheK上 。
- CPU:输入为已经准备好的cacheK,输入时根据分成batch个tensor作为std::vector<tensor>传入,此时layer维度要求为1。
cacheV
[layer, batch, headNum, seqLen, headSize]
float16/bf16
ND
- NPU:存储之前所有的v,本次执行时将value刷新到cacheV上。
- CPU:输入为已经准备好的cacheV,输入时根据分成batch个tensor作为std::vector<tensor>传入,此时layer维度要求为1。
- “calctype”不为PA_ENCODER时,在
Atlas 推理系列产品 上:输入tensor
维度
数据类型
格式
描述
query
[batch, headNum, seqLen, headSize]
float16/bf16
ND
query矩阵。
cacheK
[layer, batch*headNum, headSize / 16, kvMaxSeq, 16]
float16/bf16
NZ
- NPU:存储之前所有的k,本次执行时将key刷新到cacheK上 。
- CPU:输入为已经准备好的cacheK,输入时根据分成batch个tensor作为std::vector<tensor>传入,此时layer维度要求为1。
cacheV
[layer, batch*headNum, headSize / 16, kvMaxSeq, 16]
float16/bf16
NZ
- NPU:存储之前所有的v,本次执行时将value刷新到cacheV上。
- CPU:输入为已经准备好的cacheV,输入时根据分成batch个tensor作为std::vector<tensor>传入,此时layer维度要求为1。
- “calctype”为PA_ENCODER时:
参数
维度
数据类型
格式
描述
query
[batch, headNum, qSeqLen, headSize]
bf16
ND
query矩阵。
cacheK
[batch, kvHeadNum, kvSeqLen, headSize]
bf16
ND
key矩阵。
cacheV
[batch, kvHeadNum, kvSeqLen, headSize]
bf16
ND
value矩阵。
SeqLen
[batch] / [2, batch]
int32/uint32
ND
若shape为[batch] ,代表每个batch的序列长度,query,cacheK,cacheV相同。
若shape为[2,batch],SeqLen[0]代表query的序列长度,SeqLen[1]代表cacheK,cacheV的序列长度。
attnOut
[batch, headNum, qSeqLen, headSize]
bf16
ND
输出tensor。
约束说明
- BNSD只有开启kv-bypass功能,即参数“kvcacheCfg”置为K_BYPASS_V_BYPASS或“calctype”置为PA_ENCODER时才可用。
- 使用BNSD维度输入且“calctype”不为PA_ENCODER时,“maskType”不能为MASK_TYPE_UNDEFINED。“calctype”为PA_ENCODER时,“maskType”只能为MASK_TYPE_UNDEFINED。
- seqlen的dimNum只有“calctype”为PA_ENCODER且开启BNSD之后才能为2。
- “calctype”为PA_ENCODER时不支持
Atlas 推理系列产品 。 - query、key、value的HeadSize必须相同且<=256。
- “calctype”为PA_ENCODER时使用BNSD维度输入功能,不支持高精度。