attention_forward_varlen
- Migrated from flash_attn.flash_attn_varlen_func, with causal disabled.
- Original code:

  ```python
  out = flash_attn_varlen_func(
      q, k, v,
      cu_seqlens_q, cu_seqlens_k,
      max_seqlen_q, max_seqlen_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False)
  ```
- Optimized code calling attention_forward_varlen:

  ```python
  from mindiesd import attention_forward_varlen

  out = attention_forward_varlen(
      q, k, v,
      cu_seqlens_q, cu_seqlens_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False)
  ```
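  For context, a minimal sketch of how the inputs might be constructed. It assumes the packed varlen layout used by flash_attn, i.e. (total_tokens, num_heads, head_dim) tensors with int32 cumulative sequence offsets; the shapes, dtypes, and device handling here are illustrative assumptions, not part of the mindiesd documentation.

  ```python
  import torch
  from mindiesd import attention_forward_varlen

  # Two variable-length sequences packed into one tensor (assumed layout).
  seqlens = [3, 5]
  total_tokens, num_heads, head_dim = sum(seqlens), 8, 64

  # In practice these tensors live on the Ascend NPU device; they are
  # created on CPU here only to show the expected shapes.
  q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)
  k = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)
  v = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16)

  # Cumulative sequence offsets: [0, 3, 8] for lengths [3, 5].
  cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

  out = attention_forward_varlen(
      q, k, v,
      cu_seqlens, cu_seqlens,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False)
  ```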
- Migrated from flash_attn.flash_attn_varlen_func, with causal enabled.
- Original code:

  ```python
  out = flash_attn_varlen_func(
      q, k, v,
      cu_seqlens_q, cu_seqlens_k,
      max_seqlen_q, max_seqlen_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=True)
  ```
- Optimized code calling attention_forward_varlen:

  ```python
  from mindiesd import attention_forward_varlen

  out = attention_forward_varlen(
      q, k, v,
      cu_seqlens_q, cu_seqlens_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=True)
  ```
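  Because the two cases differ only in the causal flag, the migration can also be wrapped in a thin shim that preserves the flash_attn_varlen_func call signature across a codebase. This is a hypothetical sketch, assuming attention_forward_varlen derives the maximum sequence lengths internally so the max_seqlen_q/max_seqlen_k arguments can simply be dropped; the shim name is illustrative.

  ```python
  from mindiesd import attention_forward_varlen

  # Hypothetical drop-in shim: existing flash_attn call sites keep their
  # argument list, and the unused max_seqlen_* values are discarded
  # (assumed to be computed internally by attention_forward_varlen).
  def flash_attn_varlen_func_shim(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                  max_seqlen_q, max_seqlen_k,
                                  dropout_p=0.0, softmax_scale=None,
                                  causal=False):
      return attention_forward_varlen(
          q, k, v,
          cu_seqlens_q, cu_seqlens_k,
          dropout_p=dropout_p,
          softmax_scale=softmax_scale,
          causal=causal)
  ```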