RoPE（旋转位置编码，Rotary Position Embedding）
- 原始代码:
class Attention(nn.Module):
    """Attention layer using the original (unoptimized) rotary position embedding path.

    Illustrative snippet: only the RoPE-related statement is shown; the rest
    of the layer is elided with ``...``.
    """

    def __init__(self, xxx):
        # Layer construction omitted in this example (`xxx` is a placeholder
        # for the real constructor arguments).
        ...

    def forward(self, hidden_states, freqs_cis_img):
        # Projection of hidden_states into query/key/value omitted in this
        # example; `query` below is produced by that omitted code.
        ...
        # Apply rotary position embedding to the query. apply_rotary_emb is the
        # helper defined in the original model code.
        query = apply_rotary_emb(query, freqs_cis_img)
- 调用rotary_position_embedding接口优化后的代码:
from mindiesd import rotary_position_embedding


class Attention(nn.Module):
    """Attention layer that applies RoPE via mindiesd's ``rotary_position_embedding``.

    Illustrative snippet: only the RoPE-related statements are shown; the rest
    of the layer is elided with ``...``.
    """

    def __init__(self, xxx):
        # Layer construction omitted in this example (`xxx` is a placeholder
        # for the real constructor arguments).
        ...

    def forward(self, hidden_states, freqs_cis_img):
        # Projection of hidden_states into query/key/value omitted in this
        # example; `query` and `key` below are produced by that omitted code.
        ...
        # freqs_cis_img is unpacked as a (cos, sin) pair of tables.
        cos, sin = freqs_cis_img
        # Put cos/sin on the same device as the tensors being rotated.
        cos, sin = cos.to(query.device), sin.to(query.device)
        # Rotate both query and key with the same cos/sin tables.
        query = rotary_position_embedding(
            query, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True
        )
        key = rotary_position_embedding(
            key, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True
        )
父主题: layer层