昇腾社区首页
中文
注册
开发者
下载

aclnnQkvRmsNormRopeCache

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:输入qkv融合张量,通过SplitVD拆分q、k、v张量,执行RmsNorm、ApplyRotaryPosEmb、Quant、Scatter融合操作,输出qOut、kCache、vCache、qBeforeQuant(可选)、kBeforeQuant(可选)、vBeforeQuant(可选)。

  • 本接口目前支持的场景如下表:

    [object Object]undefined
  • 计算公式:

    (1) SplitVD:

    下式中,NqN_qNkN_kNvN_v分别表示 q、k、v 分量的注意力头数量,必须满足:

    {Nk=NvNqkv=Nk+Nv+NqDqkv=Dq=Dk=Dv\begin{cases} N_k = N_v \\ N_{qkv} = N_k + N_v + N_q \\ D_{qkv} = D_q = D_k = D_v \end{cases} q=qkv[...,[:Nq]Dqkv]k=qkv[...,[Nq:Nv]Dqkv]v=qkv[...,[Nv:]Dqkv]\begin{aligned} q &= qkv[..., [:N_q] * D_{qkv}] \\ k &= qkv[..., [N_q:-N_v] * D_{qkv}] \\ v &= qkv[..., [-N_v:] * D_{qkv}] \end{aligned}

    (2) RmsNorm:

    此处x和y分别表示RmsNorm的输入张量和输出张量,归一化沿最后一维(feature dimension)进行,该计算规则通用于q、k分量。

    squareX=xxsquareX = x * x meanSquareX=squareX.mean(dim=1,keepdim=True)meanSquareX = squareX.mean(dim = -1, keepdim = True) rms=meanSquareX+epsilonrms = \sqrt{meanSquareX + epsilon} y=(x/rms)gammay = (x / rms) * gamma

    (3) RoPE (Half-and-Half):

    此处的y指代完成RmsNorm计算的输出结果。

    y1=y[,:d/2]y1 = y[\ldots, :d/2] y2=y[,d/2:]y2 = y[\ldots, d/2:] y_RoPE=torch.cat((y2,y1),dim=1)y\_RoPE = torch.cat((-y2, y1), dim = -1) y_embed=(ycos)+y_RoPEsiny\_embed = (y * cos) + y\_RoPE * sin

    (4) Quant:

    无量化:

    kQuant=kRoPEvQuant=vkQuant = kRoPE \\ vQuant = v

    对称量化部分:

    kQuant=kRoPE/kScalevQuant=v/vScalekQuant = kRoPE / kScale \\ vQuant = v / vScale

    非对称量化部分:

    kQuant=kRoPE/kScale+kOffsetvQuant=v/vScale+vOffsetkQuant = kRoPE / kScale + kOffset \\ vQuant = v / vScale + vOffset

函数原型

每个算子分为,必须先调用“aclnnQkvRmsNormRopeCacheGetWorkspaceSize”接口获得入参并根据流程计算所需workspace大小,再调用“aclnnQkvRmsNormRopeCache”接口执行计算。

[object Object]
[object Object]

aclnnQkvRmsNormRopeCacheGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnQkvRmsNormRopeCache

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • 确定性计算:

    • aclnnQkvRmsNormRopeCache默认确定性实现。
  • 输入shape限制:

    • B[object Object]qkv[object Object]为输入qkv的batch_size,S[object Object]qkv[object Object]为输入qkv的sequence length,大小由qkvSize决定。
    • N[object Object]qkv[object Object]为输入qkv的head number。D[object Object]qkv[object Object]为输入qkv的head dim,目前仅支持128。D[object Object]q[object Object]、D[object Object]k[object Object]和D[object Object]k[object Object]分别为q、k、v的head dim,要求D[object Object]qkv[object Object] = D[object Object]q[object Object] = D[object Object]k[object Object] = D[object Object]v[object Object],D[object Object]qkv[object Object]需要满足(D[object Object]qkv[object Object]*qkv数据类型占字节数)可以被32整除。
    • 根据rope规则,D[object Object]k[object Object]和D[object Object]q[object Object]为偶数。若cacheMode为PA_NZ场景下,D[object Object]k[object Object]、D[object Object]q[object Object]需32B对齐;BlockSize需32B对齐。
    • 关于上述32B对齐的情形,对齐值由cache的数据类型决定。以BlockSize为例,若cache的数据类型为int8,则需要满足BlockSize % 32 = 0;若cache的数据类型为float16,则需要满足BlockSize % 16 = 0;若kCache与vCache参数的dtype不一致,BlockSize需同时满足BlockSize % 32 = 0和BlockSize % 16 = 0。
    • BlockNum为写入cache的内存块数,大小由用户输入场景决定,要求BlockNum >= Ceil(S[object Object]qkv[object Object] / BlockSize) * B[object Object]qkv[object Object]。
    • 使用requireMemory表示存放数据所需的空间大小,需满足:requireMemory >= (B[object Object]qkv[object Object] * S[object Object]qkv[object Object] * N[object Object]qkv[object Object] * D[object Object]qkv[object Object] + 2 * D[object Object]qkv[object Object] + 2 * B[object Object]qkv[object Object] * S[object Object]qkv[object Object] * D[object Object]qkv[object Object] + B[object Object]qkv[object Object] * S[object Object]qkv[object Object] * N[object Object]q[object Object] * D[object Object]qkv[object Object] + BlockNum * BlockSize * N[object Object]v[object Object] * D[object Object]qkv[object Object] + BlockNum * BlockSize * N[object Object]k[object Object] * D[object Object]qkv[object Object]) * sizeof(FLOAT16) + B[object Object]qkv[object Object] * S[object Object]qkv[object Object] * sizeof(INT64) + (2 * N[object Object]k[object Object] * D[object Object]qkv[object Object] + 2 * N[object Object]v[object Object]) * sizeof(FLOAT),当计算出requireMemory的大小超过当前AI处理器的GM空间总大小,不支持使用该接口。
  • 其他限制:

    • 对于index,要求index的value值范围为[-1, BlockNum * BlockSize)。value数值不可以重复,index为-1时,代表跳过更新。
    • kScaleOptional, vScaleOptional表示对称量化的缩放因子,因此若传参,则值不能为0。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]