昇腾社区首页
中文
注册

def rotary_position_embedding

函数功能

旋转位置编码技术,提升DiT模型在处理序列数据时的性能和效率。该函数在MindIE SD仓的路径为:mindiesd/layers/embedding.py,其具体使用方法请参见RoPE

函数原型

def rotary_position_embedding(x: torch.Tensor,
                              cos: torch.Tensor,
                              sin: torch.Tensor,
                              rotated_mode: str = "rotated_half",
                              head_first: bool = False,
                              fused: bool = True) -> torch.Tensor:

参数说明

参数名

输入/输出

类型

说明

x

输入

torch.Tensor

应用旋转嵌入的q或k张量,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D]。x可表示为 [x_0, x_1,..., x_d/2-1, x_d/2, x_d/2+1,..., x_d-1]。

cos

输入

torch.Tensor

预计算的复指数cos频率张量。shape要求输入为2维或4维,一般为[S,D]或[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。

sin

输入

torch.Tensor

预计算的复指数sin频率张量。shape要求输入为2维或4维,一般为[S,D]或[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。

rotated_mode

输入

str

旋转模式:支持rotated_half和rotated_interleaved两种模式。

  • rotated_half:对半旋转,将x旋转为[-x_d/2, -x_d/2+1,..., -x_d-1, x_0, x_1,..., x_d/2-1]。
  • rotated_interleaved:相邻旋转,将x旋转为[-x_1, x_0, -x_3, x_2,..., -x_d-1, x_d-2]。

head_first

输入

bool

当x的layout中,head_dim在seqlen前面时,设置为True,否则设置为False。

fused

输入

bool

是否开启融合操作。

  • True:选择高性能的RoPE融合算子。
  • False:使用原始计算公式。

返回值说明

返回经旋转嵌入修改后的q张量和k张量。