昇腾社区首页
中文
注册
开发者
下载

def attention_forward_varlen

函数功能

不等长场景的注意力计算模块,其具体使用方法请参见attention_forward_varlen

函数原型

def attention_forward_varlen(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: list[torch.Tensor],
        cu_seqlens_k: list[torch.Tensor],
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Optional[int] = None,
        softcap: Optional[float] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: Optional[bool] = None,
        return_attn_probs: Optional[bool] = None,
        block_table: Optional[torch.Tensor] = None,
):

参数说明

参数名

输入/输出

类型

说明

q

输入

torch.Tensor

查询(Query)张量。输入格式必须为(total_q, num_heads, head_dim)。total_q 是 batch 中所有query序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。

k

输入

torch.Tensor

键(Key)张量。输入格式必须为(total_k, num_heads, head_dim)。total_k 是 batch 中所有key序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。

v

输入

torch.Tensor

值(Value)张量。输入格式必须为(total_v, num_heads, head_dim)。total_v 是 batch 中所有value序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。

cu_seqlens_q

输入

list[torch.Tensor]

查询序列的累积长度,用于将 packed 的 q 张量分割成独立序列。

cu_seqlens_k

输入

list[torch.Tensor]

键/值序列的累积长度。用于索引 k 和 v。

max_seqlen_q

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

max_seqlen_k

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

dropout_p

输入

float

表示数据需要忽略的概率,默认值为0.0,推理阶段建议保持默认值即可。

softmax_scale

输入

float

表示对QKT的缩放系数,若为None,则会根据head_dim ** -0.5进行缩放

causal

输入

bool

  • causal=true时,算子会传入下三角形式的atten mask;
  • causal=false时,算子不会传入atten mask。

window_size

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

softcap

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

alibi_slopes

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

deterministic

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

return_attn_probs

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

block_table

输入

-

该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。

返回值说明

返回不等长场景的注意力机制的输出张量。