昇腾社区首页
中文
注册
开发者
下载

aclnnApplyRotaryPosEmb

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]
[object Object]Atlas 200I/500 A2 推理产品[object Object] ×
[object Object]Atlas 推理系列产品[object Object]
[object Object]Atlas 训练系列产品[object Object] ×

功能说明

  • 算子功能:推理网络为了提升性能,将query和key两路算子融合成一路。执行旋转位置编码计算,计算结果执行原地更新。
  • 计算公式:
query_q1=query[...,:query.shape[1]//2]query\_q1 = query[..., : query.shape[-1] // 2] query_q2=query[...,query.shape[1]//2:]query\_q2 = query[..., query.shape[-1] // 2 :] query_rotate=torch.cat((query_q2,query_q1),dim=1)query\_rotate = torch.cat((-query\_q2, query\_q1), dim=-1) key_k1=key[...,:key.shape[1]//2]key\_k1 = key[..., : key.shape[-1] // 2] key_k2=key[...,key.shape[1]//2:]key\_k2 = key[..., key.shape[-1] // 2 :] key_rotate=torch.cat((key_k2,key_k1),dim=1)key\_rotate = torch.cat((-key\_k2, key\_k1), dim=-1) q_embed=(querycos)+query_rotatesinq\_embed = (query * cos) + query\_rotate * sin k_embed=(keycos)+key_rotatesink\_embed = (key * cos) + key\_rotate * sin

函数原型

每个算子分为,必须先调用“aclnnApplyRotaryPosEmbGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnApplyRotaryPosEmb”接口执行计算。

[object Object]
[object Object]

aclnnApplyRotaryPosEmbGetWorkspaceSize

  • 参数说明:

    [object Object]
    • [object Object]Atlas 推理系列产品[object Object]:不支持BFLOAT16
  • 返回值:

    aclnnStatus:返回状态码,具体参见

    [object Object]

aclnnApplyRotaryPosEmb

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • queryRef、keyRef、cos、sin输入shape的前2维(B、S)和最后一维(D)必须相等。

  • 输入张量queryRef、keyRef、cos、sin的dtype必须相同。

  • 输入queryRef的shape用(q_b, q_s, q_n, q_d)表示,keyRef shape用(q_b, q_s, k_n, q_d)表示,cos和sin shape用(q_b, q_s, 1, q_d)表示。其中,b表示batch_size,s表示seq_length,n表示head_num,d表示head_dim。

    • 当输入是BFLOAT16时,cast表示为1,castSize为4,DtypeSize为2
    • 当输入是FLOAT16或FLOAT32时,cast表示为0,castSize = DtypeSize(FLOAT16时为2,FLOAT32时为4)

    需要使用的UB空间大小计算方式:ub_required = (q_n + k_n) * 128 * castSize * 2 + 128 * DtypeSize * 4 + (q_n + k_n) * 128 * castSize + (q_n + k_n) * 128 * castSize * 2 + cast * (128 * 4 * 2), 当计算出ub_required的大小超过当前AI处理器的UB空间总大小时,不支持使用该融合算子。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]