def attention_forward_varlen
函数功能
不等长场景的注意力计算模块,其具体使用方法请参见attention_forward_varlen。
函数原型
def attention_forward_varlen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: list[torch.Tensor],
cu_seqlens_k: list[torch.Tensor],
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[int] = None,
softcap: Optional[float] = None,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: Optional[bool] = None,
return_attn_probs: Optional[bool] = None,
block_table: Optional[torch.Tensor] = None,
):
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
|---|---|---|---|
q |
输入 |
torch.Tensor |
查询(Query)张量。输入格式必须为(total_q, num_heads, head_dim)。total_q 是 batch 中所有query序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。 |
k |
输入 |
torch.Tensor |
键(Key)张量。输入格式必须为(total_k, num_heads, head_dim)。total_k 是 batch 中所有key序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。 |
v |
输入 |
torch.Tensor |
值(Value)张量。输入格式必须为(total_v, num_heads, head_dim)。total_v 是 batch 中所有value序列长度的总和(即“packed”格式,无填充)。num_heads 是注意力头数,head_dim 是每个头的维度。 |
cu_seqlens_q |
输入 |
list[torch.Tensor] |
查询序列的累积长度,用于将 packed 的 q 张量分割成独立序列。 |
cu_seqlens_k |
输入 |
list[torch.Tensor] |
键/值序列的累积长度。用于索引 k 和 v。 |
max_seqlen_q |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
max_seqlen_k |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
dropout_p |
输入 |
float |
表示数据需要忽略的概率,默认值为0.0,推理阶段建议保持默认值即可。 |
softmax_scale |
输入 |
float |
表示对QKT的缩放系数,若为None,则会根据head_dim ** -0.5进行缩放 |
causal |
输入 |
bool |
|
window_size |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
softcap |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
alibi_slopes |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
deterministic |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
return_attn_probs |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
block_table |
输入 |
- |
该参数仅支持与flash_attn_varlen_func保持一致,npu无需配置该参数,也不支持该参数。 |
返回值说明
返回不等长场景的注意力机制的输出张量。