attention_forward

  • Migrating from torch.nn.functional.scaled_dot_product_attention (a numerical comparison sketch appears at the end of this section)
    • Original code:
      query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
      key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
      value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
      # the output of sdp = (batch, num_heads, seq_len, head_dim)
      hidden_states = F.scaled_dot_product_attention(
          query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
      )
      hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    • Optimized code after migrating to attention_forward:
      from mindiesd import attention_forward
      # q,k,v shape is batch, seq_len, num_heads, head_dim
      query = query.view(batch_size, -1, attn.heads, head_dim)
      key = key.view(batch_size, -1, attn.heads, head_dim)
      value = value.view(batch_size, -1, attn.heads, head_dim)
      # the input shape of attention_forward = (batch, seq_len, num_heads, head_dim)
      # the output of attention_forward = (batch, seq_len, num_heads, head_dim)
      hidden_states = attention_forward(query, key, value, attn_mask=attention_mask)
      hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
  • Migrating from flash_attention.flash_attn_func (a drop-in wrapper sketch appears at the end of this section)
    • Original code:
      q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
      k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
      v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
      out = flash_attention.flash_attn_func(q, k, v)
    • Optimized code after migrating to attention_forward:
      from mindiesd import attention_forward
      q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
      k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
      v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
      out = attention_forward(q, k, v)
  • Note that the attention_forward interface takes inputs of shape (batch, seq_len, num_heads, head_dim) and returns an output of the same shape (batch, seq_len, num_heads, head_dim), so the transposes required around scaled_dot_product_attention can be dropped.
  • The attention_forward interface provides forward inference only and does not compute backward gradients. When migrating, remove dropout and set requires_grad of the input tensors to False (see the inference-only sketch at the end of this section).
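  • Comparison sketch for the scaled_dot_product_attention migration. This is a minimal, hedged example rather than part of the official documentation: it assumes an Ascend environment where torch, torch_npu and mindiesd are installed and the "npu" device is available; the tensor sizes are illustrative only.
    import torch
    import torch.nn.functional as F
    from mindiesd import attention_forward

    batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
    device, dtype = "npu", torch.float16  # assumes torch_npu provides the "npu" device

    # attention_forward expects the BSND layout: (batch, seq_len, num_heads, head_dim)
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Reference path: scaled_dot_product_attention works on (batch, num_heads, seq_len, head_dim),
    # so the inputs are transposed in and the output is transposed back.
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=None, dropout_p=0.0, is_causal=False,
    ).transpose(1, 2)

    # Migrated path: no transposes, input and output are already BSND.
    out = attention_forward(q, k, v)

    # Half-precision kernels may differ slightly, so compare with a loose tolerance.
    print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))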
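  • Drop-in wrapper sketch for the flash_attn_func migration. The wrapper name and its keyword handling below are illustrative assumptions, not part of mindiesd; the idea is that existing call sites of flash_attention.flash_attn_func can stay unchanged while the computation is routed to attention_forward.
    from mindiesd import attention_forward

    def flash_attn_func(q, k, v, dropout_p=0.0, causal=False):
        """Hypothetical shim: both sides use the (batch, seq_len, num_heads, head_dim) layout."""
        if dropout_p != 0.0:
            # attention_forward is inference-only, so dropout must be removed by the caller.
            raise ValueError("attention_forward does not support dropout; pass dropout_p=0.0")
        if causal:
            # Causal behaviour would have to be expressed through an explicit attn_mask instead.
            raise NotImplementedError("construct an explicit attn_mask for causal attention")
        return attention_forward(q, k, v)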
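  • Inference-only sketch for the gradient and dropout note above. It assumes the same Ascend environment as the comparison sketch; shapes are illustrative.
    import torch
    from mindiesd import attention_forward

    # Inputs are explicitly marked as not requiring gradients, as the migration note asks.
    q = torch.randn(1, 64, 8, 64, device="npu", dtype=torch.float16).requires_grad_(False)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # torch.no_grad() makes the forward-only intent explicit and avoids building an autograd graph;
    # no dropout argument is passed because the interface provides forward inference only.
    with torch.no_grad():
        out = attention_forward(q, k, v)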