昇腾社区首页
中文
注册

RoPE

  • 原始代码:
    # Original (pre-optimization) attention snippet; elided parts are marked "Omitted".
    class Attention(nn.Module):
        def __init__(self, xxx): 
            # Omitted
        def forward(self, hidden_states, freqs_cis_img):
            # Omitted
            # Apply rotary position embedding to query; apply_rotary_emb is the
            # method from the original code.
            # NOTE(review): only query is rotated in this snippet, while the optimized
            # version also rotates key — presumably key is handled in the omitted code; confirm.
            query = apply_rotary_emb(query, freqs_cis_img)
  • 调用rotary_position_embedding接口优化后的代码:
    from mindiesd import rotary_position_embedding
    
    class Attention(nn.Module):
        def __init__(self, xxx): 
            # 省略
        def forward(self, hidden_states, freqs_cis_img):
            # 省略
            cos, sin = freqs_cis_img
            cos, sin = cos.to(x.device), sin.to(x.device)
            query = rotary_position_embedding(query, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True)
            key = rotary_position_embedding(key, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True)