旋转位置编码后进行concat操作。
rope计算公式:
kernel计算流程如下图所示:
1 2 3 |
struct RopeQConcatParam { uint8_t rsv[16] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
rsv |
uint8_t[] |
{0} |
[0] |
否 |
预留字段。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
q |
[ntokens, hidden_size_q] |
float16/bf16 |
ND |
2维。 |
cos |
[ntokens, head_dim] |
与q一致 |
ND |
2维。 |
sin |
[ntokens, head_dim] |
与q一致 |
ND |
2维。 |
ConcatInput |
[ntokens, head_num, concat_size] |
与q一致 |
ND |
3维。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
ropeQConcat |
[ntokens, head_num, head_dim+concat_size] |
与q一致 |
ND |
3维。 |