RopeOperation

功能

旋转位置编码处理。

约束

定义

struct RopeParam {
    int32_t rotaryCoeff = 4;
    int32_t cosFormat = 0;
    bool operator == (const RopeParam &other) const
    {
        return this->rotaryCoeff == other.rotaryCoeff && this->cosFormat == other.cosFormat;
    }
};

成员

成员名称

描述

rotaryCoeff

rope,旋转系数,对半旋转是2,支持配置2、4或headDim / 2。

cosFormat

训练用参数,支持配置0或1

输入

参数

维度

数据类型

格式

q

[ntokens, hiddenSizeQ]

float16/bf16

ND

k

[ntokens, hiddenSizeK]

float16/bf16

ND

cos

[ntokens, headDim] / [ntokens, headDim / 2]

float16/float/bf16

ND

  • 当cos的第二个维度与参数rotaryCoeff相等时,其值为headDim/2。
  • 当cos的第二个维度与参数rotaryCoeff不相等时,其值为headDim。

sin

[ntokens, headDim] / [ntokens, headDim/ 2]

float16/float/bf16

ND

  • 当sin的第二个维度与参数rotaryCoeff相等时,其值为headDim/2。
  • 当sin的第二个维度与参数rotaryCoeff不相等时,其值为headDim。

seqLen

[batch]

uint32/int32

ND

输出

参数

维度

数据类型

格式

ropeQ

[ntokens, hiddenSizeQ]

float16/bf16

ND

ropeK

[ntokens, hiddenSizeK]

float16/bf16

ND