昇腾社区首页
中文
注册

RopeQConcatOperation(代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品 / Atlas A3 训练系列产品

x

Atlas A2 训练系列产品 / Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

旋转位置编码后进行concat操作。

计算公式

rope计算公式

计算图

kernel计算流程如下图所示

定义

1
2
3
struct RopeQConcatParam {
    uint8_t rsv[16] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

rsv

uint8_t[]

{0}

[0]

预留字段。

输入

参数

维度

数据类型

格式

描述

q

[ntokens, hidden_size_q]

float16/bf16

ND

2维。

cos

[ntokens, head_dim]

与q一致

ND

2维。

sin

[ntokens, head_dim]

与q一致

ND

2维。

ConcatInput

[ntokens, head_num, concat_size]

与q一致

ND

3维。

输出

参数

维度

数据类型

格式

描述

ropeQConcat

[ntokens, head_num, head_dim+concat_size]

与q一致

ND

3维。

约束说明

  • hidden_size_q = head_dim * head_num。
  • head_dim*sizeof(dtype) 需要32Byte对齐,即head_dim需要是16的整数倍且需小于等于64。
  • concat_size*sizeof(dtype) 需要32Byte对齐,即concat_size需要是16的整数倍。
  • 保证head_dim * 26 + concat_size * 2<maxUbSize(196352) ,对head_num无限制。