昇腾社区首页
中文
注册

aclnnApplyRotaryPosEmbV2

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]
[object Object]Atlas 200I/500 A2 推理产品[object Object] ×
[object Object]Atlas 推理系列产品 [object Object]
[object Object]Atlas 训练系列产品[object Object] ×

功能说明

  • 算子功能:推理网络为了提升性能,将query和key两路算子融合成一路。执行旋转位置编码计算,计算结果执行原地更新。 本接口针对undefined做了如下功能变更,请根据实际情况选择合适的接口:

    • 新增rotaryMode参数,用于控制不同的旋转编码方式
  • 计算公式

(1)rotaryMode为"half": $$ query_q1 = query[..., : query.shape[-1] // 2] $$

query_q2=query[...,query.shape[1]//2:]query\_q2 = query[..., query.shape[-1] // 2 :] query_rotate=torch.cat((query_q2,query_q1),dim=1)query\_rotate = torch.cat((-query\_q2, query\_q1), dim=-1) key_k1=key[...,:key.shape[1]//2]key\_k1 = key[..., : key.shape[-1] // 2] key_k2=key[...,key.shape[1]//2:]key\_k2 = key[..., key.shape[-1] // 2 :] key_rotate=torch.cat((key_k2,key_k1),dim=1)key\_rotate = torch.cat((-key\_k2, key\_k1), dim=-1) q_embed=(querycos)+query_rotatesinq\_embed = (query * cos) + query\_rotate * sin k_embed=(keycos)+key_rotatesink\_embed = (key * cos) + key\_rotate * sin

(2)rotaryMode为"quarter": $$ query_q1 = query[..., : query.shape[-1] // 4] $$

query_q2=query[...,query.shape[1]//4:query.shape[1]//2]query\_q2 = query[..., query.shape[-1] // 4 : query.shape[-1] // 2] query_q3=query[...,query.shape[1]//2:query.shape[1]//43]query\_q3 = query[..., query.shape[-1] // 2 : query.shape[-1] // 4 * 3] query_q4=query[...,query.shape[1]//43:]query\_q4 = query[..., query.shape[-1] // 4 * 3 :] query_rotate=torch.cat((query_q2,query_q1,query_q4,query_q3),dim=1)query\_rotate = torch.cat((-query\_q2, query\_q1, -query\_q4, query\_q3), dim=-1) key_q1=key[...,:key.shape[1]//4]key\_q1 = key[..., : key.shape[-1] // 4] key_q2=key[...,key.shape[1]//4:key.shape[1]//2]key\_q2 = key[..., key.shape[-1] // 4 : key.shape[-1] // 2] key_q3=key[...,key.shape[1]//2:key.shape[1]//43]key\_q3 = key[..., key.shape[-1] // 2 : key.shape[-1] // 4 * 3] key_q4=key[...,key.shape[1]//43:]key\_q4 = key[..., key.shape[-1] // 4 * 3 :] key_rotate=torch.cat((key_q2,key_q1,key_q4,key_q3),dim=1)key\_rotate = torch.cat((-key\_q2, key\_q1, -key\_q4, key\_q3), dim=-1) q_embed=(querycos)+query_rotatesinq\_embed = (query * cos) + query\_rotate * sin k_embed=(keycos)+key_rotatesink\_embed = (key * cos) + key\_rotate * sin

(3)rotaryMode为"interleave": $$ query_q1 = query[..., ::2].view(-1, 1) $$

query_q2=query[...,1::2].view(1,1)query\_q2 = query[..., 1::2].view(-1, 1) query_rotate=torch.cat((query_q2,query_q1),dim=1).view(query.shape[0],query.shape[1],query.shape[2],query.shape[3])query\_rotate = torch.cat((-query\_q2, query\_q1), dim=-1).view(query.shape[0], query.shape[1], query.shape[2], query.shape[3]) key_q1=key[...,::2].view(1,1)key\_q1 = key[..., ::2].view(-1, 1) key_q2=key[...,1::2].view(1,1)key\_q2 = key[..., 1::2].view(-1, 1) key_rotate=torch.cat((key_q2,key_q1),dim=1).view(key.shape[0],key.shape[1],key.shape[2],key.shape[3])key\_rotate = torch.cat((-key\_q2, key\_q1), dim=-1).view(key.shape[0], key.shape[1], key.shape[2], key.shape[3]) q_embed=(querycos)+query_rotatesinq\_embed = (query * cos) + query\_rotate * sin k_embed=(keycos)+key_rotatesink\_embed = (key * cos) + key\_rotate * sin

函数原型

每个算子分为undefined,必须先调用“aclnnApplyRotaryPosEmbV2GetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnApplyRotaryPosEmbV2”接口执行计算。

  • aclnnStatus aclnnApplyRotaryPosEmbV2GetWorkspaceSize(aclTensor *queryRef, aclTensor *keyRef, const aclTensor *cos, const aclTensor *sin, int64_t layout, char *rotaryMode, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnApplyRotaryPosEmbV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

aclnnApplyRotaryPosEmbV2GetWorkspaceSize

  • 参数说明:

    • queryRef(aclTensor*,计算输入):表示要执行旋转位置编码的第一个张量,公式中的query,Device侧的aclTensor。支持undefinedundefined支持ND,维度为4维。计算结果原地更新。
      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT16、FLOAT32。
      • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持BFLOAT16、FLOAT16、FLOAT32。
    • keyRef(aclTensor*,计算输入):表示要执行旋转位置编码的第二个张量,公式中的key,Device侧的aclTensor。支持undefinedundefined支持ND,维度为4维。
      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT16、FLOAT32。
      • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持BFLOAT16、FLOAT16、FLOAT32。
    • cos(aclTensor*,计算输入):表示参与计算的位置编码张量,公式中的cos,Device侧的aclTensor。支持undefinedundefined支持ND,维度为4维。
      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT16、FLOAT32。
      • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持BFLOAT16、FLOAT16、FLOAT32。
    • sin(aclTensor*,计算输入):表示参与计算的位置编码张量,公式中的sin,Device侧的aclTensor,支持undefinedundefined支持ND,维度为4维。
      • [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT16、FLOAT32。
      • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持BFLOAT16、FLOAT16、FLOAT32。
    • layout(int64_t,计算输入):表示输入Tensor的布局格式,数据类型支持int64。
      • [object Object]Atlas 推理系列产品[object Object]、[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:支持1,代表格式为BSND的4维Tensor。
    • rotaryMode(char*, 计算输入):公式中的旋转模式,数据类型支持char.
      • [object Object]Atlas 推理系列产品[object Object]、[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:支持"half"模式。
    • workspaceSize(uint64_t*,出参):返回需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值:

    aclnnStatus:返回状态码,具体参见undefined

    [object Object]

aclnnApplyRotaryPosEmbV2

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnApplyRotaryPosEmbV2GetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的Stream。
  • 返回值:

    aclnnStatus:返回状态码,具体参见undefined

约束说明

  • [object Object]Atlas 推理系列产品[object Object]、[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:
    • 输入张量queryRef、keyRef、cos、sin只支持4维的shape,layout只支持1

    • 输入张量queryRef、keyRef、cos、sin的dtype必须相同,且4个输入shape的前2维和最后一维必须相等,cos和sin的shape第3维必须等于1,输入shape最后一维必须等于128

    • 输入queryRef的shape用(q_b, q_s, q_n, q_d)表示,keyRef shape用(q_b, q_s, k_n, q_d)表示,cos和sin shape用(q_b, q_s, 1, q_d)表示。其中,b表示batch_size,s表示seq_length,n表示head_num,d表示head_dim。

      • 当输入是BFLOAT16时,cast表示为1,castSize为4,DtypeSize为2
      • 当输入是FLOAT16或FLOAT32时,cast表示为0,castSize = DtypeSize(FLOAT16时为2,FLOAT32时为4)

      需要使用的UB空间大小计算方式:ub_required = (q_n + k_n) * 128 * castSize * 2 + 128 * DtypeSize * 4 + (q_n + k_n) * 128 * castSize + (q_n + k_n) * 128 * castSize * 2 + cast * (128 * 4 * 2), 当计算出ub_required的大小超过当前AI处理器的UB空间总大小时,不支持使用该融合算子

    • 不支持空tensor场景

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考undefined

[object Object]