旋转位置编码处理。
struct RopeParam { int32_t rotaryCoeff = 4; int32_t cosFormat = 0; bool operator == (const RopeParam &other) const { return this->rotaryCoeff == other.rotaryCoeff && this->cosFormat == other.cosFormat; } };
成员名称 |
描述 |
---|---|
rotaryCoeff |
rope,旋转系数,对半旋转是2,支持配置2、4或headDim / 2。 |
cosFormat |
训练用参数,支持配置0或1 |
参数 |
维度 |
数据类型 |
格式 |
---|---|---|---|
q |
[ntokens, hiddenSizeQ] |
float16/bf16 |
ND |
k |
[ntokens, hiddenSizeK] |
float16/bf16 |
ND |
cos |
[ntokens, headDim] / [ntokens, headDim / 2] |
float16/float/bf16 |
ND
|
sin |
[ntokens, headDim] / [ntokens, headDim/ 2] |
float16/float/bf16 |
ND
|
seqLen |
[batch] |
uint32/int32 |
ND |
参数 |
维度 |
数据类型 |
格式 |
---|---|---|---|
ropeQ |
[ntokens, hiddenSizeQ] |
float16/bf16 |
ND |
ropeK |
[ntokens, hiddenSizeK] |
float16/bf16 |
ND |