def attention_forward
函数功能
注意力计算模块,集成多种计算模式,包含原始计算模式、多种近似优化,用于搜索最优的计算公式。该函数在MindIE SD仓的路径为:mindiesd/layers/flash_attn/attention_forward.py,其具体使用方法请参见attention_forward。
函数原型
def attention_forward(query, key, value, attn_mask=None, scale=None, fused=True, **kwargs):
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
query |
输入 |
torch.Tensor |
注意力计算公式的q输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。 |
key |
输入 |
torch.Tensor |
注意力计算公式的k输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。 |
value |
输入 |
torch.Tensor |
注意力计算公式的v输入,输入格式必须为(batch, seq_len, num_heads, head_dim)。 |
attn_mask |
输入 |
torch.Tensor |
注意力掩码。 |
scale |
输入 |
torch.Tensor |
输入缩放。 |
fused |
输入 |
bool |
是否开启融合操作。
|
kwargs |
输入 |
- |
其他参数,包含以下三个可选项:
|
返回值说明
返回搜索后的最优注意力计算公式。

- 接口的输入shape为(batch, seq_len, num_heads, head_dim),输出shape为(batch, seq_len, num_heads, head_dim)。
- manual模式下设置算子的layout只会影响内部算子执行的layout,接口的输入shape和输出shape依然为(batch, seq_len, num_heads, head_dim)。
父主题: API参考(Python)