attention_forward
- Migrating from torch.nn.functional.scaled_dot_product_attention
- Original code:
```python
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
```
- Optimized code calling attention_forward:
```python
from mindiesd import attention_forward

# q, k, v shape is (batch, seq_len, num_heads, head_dim)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

# the input shape of attention_forward = (batch, seq_len, num_heads, head_dim)
# the output of attention_forward = (batch, seq_len, num_heads, head_dim)
hidden_states = attention_forward(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
```
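The essential difference is the tensor layout: scaled_dot_product_attention consumes (batch, num_heads, seq_len, head_dim), while attention_forward consumes and returns (batch, seq_len, num_heads, head_dim), so the transpose(1, 2) calls before and after the kernel disappear. The following self-contained sketch checks that the two layouts yield identical results; bsnd_attention is a hypothetical pure-PyTorch stand-in for attention_forward's layout contract, not part of mindiesd:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
x = torch.randn(batch_size, seq_len, num_heads * head_dim)

# Original path: transpose into (batch, num_heads, seq_len, head_dim) for
# SDPA, then transpose back afterwards.
q = x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, q, q)
out_ref = out.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

# Migrated path: stay in (batch, seq_len, num_heads, head_dim) throughout.
# bsnd_attention is a hypothetical stand-in mimicking attention_forward.
def bsnd_attention(query, key, value, attn_mask=None):
    o = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
        attn_mask=attn_mask,
    )
    return o.transpose(1, 2)

q_bsnd = x.view(batch_size, -1, num_heads, head_dim)
out_new = bsnd_attention(q_bsnd, q_bsnd, q_bsnd).reshape(
    batch_size, -1, num_heads * head_dim
)

torch.testing.assert_close(out_ref, out_new)
```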
- Migrating from flash_attention.flash_attn_func
- Original code:
```python
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
out = flash_attention.flash_attn_func(q, k, v)
```
- Optimized code calling attention_forward:
```python
from mindiesd import attention_forward

q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype)
out = attention_forward(q, k, v)
```
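Because flash_attn_func already uses the (batch, seq_len, num_heads, head_dim) layout, this migration is a drop-in function swap with no reshapes. A common pattern during migration is to keep a fallback for environments where mindiesd is unavailable; in the sketch below, the dispatch helper attn_bsnd is hypothetical (not part of either library) and falls back to PyTorch SDPA with the required transposes:

```python
import torch
import torch.nn.functional as F

try:
    from mindiesd import attention_forward
    _HAS_MINDIESD = True
except ImportError:
    _HAS_MINDIESD = False

def attn_bsnd(q, k, v, attn_mask=None):
    """Attention over tensors laid out as (batch, seq_len, num_heads, head_dim)."""
    if _HAS_MINDIESD:
        return attention_forward(q, k, v, attn_mask=attn_mask)
    # Fallback: PyTorch SDPA expects (batch, num_heads, seq_len, head_dim),
    # so transpose into and out of that layout.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=attn_mask,
    )
    return out.transpose(1, 2)
```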

- Note that the attention_forward interface expects inputs of shape (batch, seq_len, num_heads, head_dim) and returns output of the same shape, (batch, seq_len, num_heads, head_dim).
- attention_forward provides forward inference only and does not compute backward gradients. When migrating, remove any dropout and make sure the input tensors have requires_grad set to False; see the sketch after this list.
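A minimal end-to-end sketch of the inference-only contract; the shapes, dtype, and the use of CPU-side torch.randn here are assumptions for illustration, so adjust them to your device and model:

```python
import torch
from mindiesd import attention_forward

batch, seq_len, heads, head_dim = 1, 128, 8, 64
q = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16)
k = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16)
v = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16)

# Inputs must not carry gradients; detach() returns tensors with
# requires_grad=False.
q, k, v = (t.detach() for t in (q, k, v))

# Forward inference only: no dropout argument, no autograd tracking.
with torch.no_grad():
    out = attention_forward(q, k, v)

# Output keeps the (batch, seq_len, num_heads, head_dim) layout.
assert out.shape == (batch, seq_len, heads, head_dim)
```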