def rotary_position_embedding
函数功能
旋转位置编码技术,提升DiT模型在处理序列数据时的性能和效率。该函数在MindIE SD仓的路径为:mindiesd/layers/embedding.py,其具体使用方法请参见RoPE。
函数原型
def rotary_position_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotated_mode: str = "rotated_half", head_first: bool = False, fused: bool = True) -> torch.Tensor:
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
x |
输入 |
torch.Tensor |
应用旋转嵌入的q或k张量,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D]。x可表示为 [x_0, x_1,..., x_d/2-1, x_d/2, x_d/2+1,..., x_d-1]。 |
cos |
输入 |
torch.Tensor |
预计算的复指数cos频率张量。shape要求输入为2维或4维,一般为[S,D]或[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。 |
sin |
输入 |
torch.Tensor |
预计算的复指数sin频率张量。shape要求输入为2维或4维,一般为[S,D]或[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。 |
rotated_mode |
输入 |
str |
旋转模式:支持rotated_half和rotated_interleaved两种模式。
|
head_first |
输入 |
bool |
当x的layout中,head_dim在seqlen前面时,设置为True,否则设置为False。 |
fused |
输入 |
bool |
是否开启融合操作。
|
返回值说明
返回经旋转嵌入修改后的q张量和k张量。
父主题: API参考(Python)