RopeQConcatOperation

功能

旋转位置编码后进行concat操作。

计算公式

rope计算公式

计算图

kernel计算流程如下图所示

定义

1
2
3
struct RopeQConcatParam {
    uint8_t rsv[16] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

rsv

uint8_t[]

{0}

[0]

预留字段。

输入

参数

维度

数据类型

格式

描述

q

[ntokens, hidden_size_q]

float16/bf16

ND

2维。

cos

[ntokens, head_dim]

与q一致

ND

2维。

sin

[ntokens, head_dim]

与q一致

ND

2维。

ConcatInput

[ntokens, head_num, concat_size]

与q一致

ND

3维。

输出

参数

维度

数据类型

格式

描述

ropeQConcat

[ntokens, head_num, head_dim+concat_size]

与q一致

ND

3维。

规格约束