attention_forward_varlen

  • Migrating from flash_attn.flash_attn_varlen_func with causal disabled (an input-preparation sketch follows this list).
    • Original code:
      out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False)

    • Optimized code calling attention_forward_varlen:
      from mindiesd import attention_forward_varlen
      out = attention_forward_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, softmax_scale=None, causal=False)

  • Migrating from flash_attn.flash_attn_varlen_func with causal enabled (a drop-in wrapper sketch follows this list).
    • Original code:
      out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=True)

    • Optimized code calling attention_forward_varlen:
      from mindiesd import attention_forward_varlen
      out = attention_forward_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, softmax_scale=None, causal=True)
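Compared with flash_attn_varlen_func, the optimized calls above drop the max_seqlen_q / max_seqlen_k arguments; only the packed tensors and the cumulative sequence-length offsets are passed. The sketch below shows one way the inputs for the non-causal call could be prepared. It is a minimal sketch that assumes attention_forward_varlen keeps flash_attn's varlen packed layout, i.e. q/k/v of shape (total_tokens, num_heads, head_dim) and int32 cu_seqlens of shape (batch + 1,); the sequence lengths, head counts, and the torch_npu device placement are illustrative values, not prescribed by this page.

      import torch
      import torch_npu  # assumed: provides the .npu() device placement used below
      from mindiesd import attention_forward_varlen

      # Example per-sample sequence lengths; three sequences packed without padding.
      seq_lens = [512, 1024, 384]
      num_heads, head_dim = 16, 64
      total_tokens = sum(seq_lens)

      # Packed q/k/v: a single token axis for the whole batch (assumed layout).
      q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16).npu()
      k = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16).npu()
      v = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16).npu()

      # Cumulative offsets [0, 512, 1536, 1920] marking where each sequence starts and ends.
      offsets = [0]
      for n in seq_lens:
          offsets.append(offsets[-1] + n)
      cu_seqlens = torch.tensor(offsets, dtype=torch.int32).npu()

      # Self-attention: the same offsets serve as cu_seqlens_q and cu_seqlens_k.
      # Note that no max_seqlen_q / max_seqlen_k arguments are needed.
      out = attention_forward_varlen(
          q, k, v,
          cu_seqlens, cu_seqlens,
          dropout_p=0.0,
          softmax_scale=None,
          causal=False,
      )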
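For the causal case only the causal flag changes. If many call sites still use the flash_attn signature, a thin compatibility shim such as the hypothetical one below can keep them unchanged while routing to attention_forward_varlen. The wrapper name and its handling of the now-unused max_seqlen arguments are assumptions for illustration and are not part of mindiesd.

      from mindiesd import attention_forward_varlen

      def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                 max_seqlen_q, max_seqlen_k,
                                 dropout_p=0.0, softmax_scale=None, causal=False):
          # Keeps the original flash_attn signature so existing call sites stay unchanged;
          # max_seqlen_q / max_seqlen_k are accepted but intentionally unused,
          # because attention_forward_varlen does not take them.
          return attention_forward_varlen(
              q, k, v, cu_seqlens_q, cu_seqlens_k,
              dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal,
          )

      # An existing causal call site then works as before:
      # out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
      #                              max_seqlen_q, max_seqlen_k,
      #                              dropout_p=0.0, softmax_scale=None, causal=True)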