Sliding window attention(SWA)功能

功能

Mistral 7B,Attention部分在GQA(Group Query Attention)的基础上,叠加了SWA(Sliding Window Attention)的优化,可以进一步提高推理速度,并降低显存。

当前FA算子在进行推理时采用的是取所有KV进行计算,可以应用SWA提升推理速度,越远距离的信息,对当前位置的重要性越低,对于距离超过窗口大小的两个token不参与注意力分数计算。

通过设计mask实现n>m或m-n>windowLen(n为kvId, m为qId)时,注意力分数跳过计算。如图1,若windowSize = 3, “The”将不会计算与“the”之间的注意力分数。

图1 Sliding Window Attention的mask功能示意

Decoder支持仅长度为windowLen的最新KV历史参与计算。

kvCache优化:比如窗口大小=4,则当第5个token需要缓存,直接替换掉第1个token,这样就可以保持kv缓存有一个最大值(为窗口大小),而不会无限增长。

开启方式

特殊约束

SWA mask样例生成参考

当windowSize >= seqlen时,mask和不开启SWA时一样,为上三角,否则mask生成可参考如下: