昇腾社区首页
中文
注册

def attention_forward

函数功能

注意力计算模块,集成多种计算模式,包含原始计算模式、多种近似优化,用于搜索最优的计算公式。该函数在MindIE SD仓的路径为:mindiesd/layers/flash_attn/attention_forward.py,其具体使用方法请参见attention_forward

函数原型

def attention_forward(query, key, value, attn_mask=None, scale=None, fused=True, **kwargs):

参数说明

参数名

输入/输出

类型

说明

query

输入

torch.Tensor

注意力计算公式的q输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。

key

输入

torch.Tensor

注意力计算公式的k输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。

value

输入

torch.Tensor

注意力计算公式的v输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。

attn_mask

输入

torch.Tensor

注意力掩码。

scale

输入

torch.Tensor

输入缩放。

fused

输入

bool

是否开启融合操作。

  • True:可选择融合算子。
  • False:使用原始计算公式,即缩放点积注意力(scaled dot-product attention)。

kwargs

输入

-

其他参数,包含以下三个可选项:

  • opt_mode:str类型,支持runtime、static和manual三种模式,默认为runtime。
    • runtime:在运行时动态搜索最佳融合算子,仅第一次搜索会消耗时间。
    • static:通过静态表获取最佳融合算子。
    • manual:手动设置融合算子类型。
  • op_type:str类型,表示融合算子类型,支持prompt_flash_attn、fused_attn_score和ascend_laser_attention。
  • layout:str类型,表示注意力机制布局方式,仅当opt_mode参数设置为manual时生效,支持BNSD、BSND和BSH。

返回值说明

返回搜索后的最优注意力计算公式。

  • 接口的输入shape为(batch, seq_len, num_heads, head_dim),输出shape为(batch, seq_len, num_heads, head_dim)。
  • manual模式下设置算子的layout只会影响内部算子执行的layout,接口的输入shape和输出shape依然为(batch, seq_len, num_heads, head_dim)。