attention_forward_varlen

  • Migrating from flash_attn.flash_attn_varlen_func with causal disabled.
    • Original code:
      from flash_attn import flash_attn_varlen_func
      out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False)
    • Optimized code calling attention_forward_varlen:
      from mindiesd import attention_forward_varlen
      out = attention_forward_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, softmax_scale=None, causal=False)
  • Migrating from flash_attn.flash_attn_varlen_func with causal enabled (a combined usage sketch for both cases follows this list).
    • Original code:
      from flash_attn import flash_attn_varlen_func
      out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=True)
    • Optimized code calling attention_forward_varlen:
      from mindiesd import attention_forward_varlen
      out = attention_forward_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, softmax_scale=None, causal=True)
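
The snippet below is a minimal end-to-end sketch of the migrated call for both settings. It assumes a flash_attn-style packed layout (q/k/v of shape (total_tokens, num_heads, head_dim) with int32 cumulative sequence lengths) and an Ascend environment where torch_npu provides the "npu" device; the helper packed_qkv and the concrete shapes are illustrative only, while attention_forward_varlen and its arguments come from the examples above.

  # Minimal usage sketch (assumptions: flash_attn-style packed layout, Ascend NPU via torch_npu).
  import torch
  import torch_npu  # assumption: Ascend environment providing the "npu" device
  from mindiesd import attention_forward_varlen

  def packed_qkv(seq_lens, num_heads=8, head_dim=64, device="npu", dtype=torch.float16):
      # Pack several variable-length sequences into one (total_tokens, num_heads, head_dim)
      # tensor per q/k/v, plus int32 cumulative sequence lengths (assumed layout).
      total = sum(seq_lens)
      q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
      k = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
      v = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
      cu = [0]
      for n in seq_lens:
          cu.append(cu[-1] + n)
      cu_seqlens = torch.tensor(cu, device=device, dtype=torch.int32)
      return q, k, v, cu_seqlens

  seq_lens = [128, 256, 64]              # three variable-length sequences in one packed batch
  q, k, v, cu_seqlens = packed_qkv(seq_lens)

  # causal=False mirrors the first example, causal=True the second;
  # max_seqlen_q / max_seqlen_k are no longer passed after the migration.
  out_full = attention_forward_varlen(q, k, v, cu_seqlens, cu_seqlens,
                                      dropout_p=0.0, softmax_scale=None, causal=False)
  out_causal = attention_forward_varlen(q, k, v, cu_seqlens, cu_seqlens,
                                        dropout_p=0.0, softmax_scale=None, causal=True)

Self-attention is assumed in this sketch, so the same cu_seqlens tensor is passed for both cu_seqlens_q and cu_seqlens_k; for cross-attention the two would be built from the query and key sequence lengths separately.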