昇腾社区首页
中文
注册
开发者
下载

aclnnDenseLightningIndexerGradKLLoss

产品支持情况

[object Object]undefined

功能说明

  • 算子功能:DenseLightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。稠密场景下的LightningIndexerGrad的输入query、key、query_index、key_index不用做稀疏化处理。

  • 计算公式:

    1. Top-k value的计算公式:
    It,:=Wt,:@ReLU(q~t,:@K~:t,:)I_{t,:}=W_{t,:}@ReLU(\tilde{q}_{t,:}@\tilde{K}_{:t,:}^\top)
    • Wt,:W_{t,:}是第tt个token对应的weightsweights
    • q~t,:\tilde{q}_{t,:}q~\tilde{q}矩阵第tt个token对应的GG个query头合轴后的结果;
    • K~:t,:\tilde{K}_{:t,:}ttK~\tilde{K}矩阵。
    1. 正向的Softmax对应公式:
    pt,:=Softmax(qt,:@K:t,:/d)p_{t,:} = \text{Softmax}(q_{t,:} @ K_{:t,:}^\top/\sqrt{d})
    • pt,:p_{t,:}是第tt个token对应的Softmax结果;
    • qt,:q_{t,:}qq矩阵第tt个token对应的GG个query头合轴后的结果;
    • K:t,:{K}_{:t,:}ttKK矩阵。
    1. npu_lightning_indexer会单独训练,对应的loss function为:
    Loss=tDKL(pt,:Softmax(It,:))Loss{=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))

    其中,pt,:p_{t,:}是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:

    DKL(ab)=iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}
    1. 通过求导可得Loss的梯度表达式:
    dIt,:=Softmax(It,:)pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.

    利用链式法则可以进行weights,query和key矩阵的梯度计算:

    dWt,:=dIt,:@(ReLU(St,:))dW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right. dq~t,:=dSt,:@K~:t,:d\mathop{{\tilde{q}}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@\tilde{K}\mathop{{}}\nolimits_{{:t,:}} dK~:t,:=(dSt,:)@q~:t,:d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}=\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}@\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right.

    其中,SSq~\tilde{q}KK矩阵乘的结果。

[object Object]

函数原型

算子执行接口为两段式接口,必须先调用“aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnDenseLightningIndexerGradKLLoss”接口执行计算。

[object Object]
[object Object]

aclnnDenseLightningIndexerGradKLLoss

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnDenseLightningIndexerGradKLLoss

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 公共约束

    • 确定性计算: aclnnDenseLightningIndexerGradKLLoss默认非确定性实现,支持通过alcrtCtxSetSysParamOpt开启确定性。
    • 入参为空的场景处理:
      • query或key或query_index或key_index或weight为空Tensor:当前不支持,会报错。
    [object Object]
  • 规格约束

    [object Object]
  • 典型值

    [object Object]

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]