昇腾社区首页
中文
注册
开发者
下载

aclnnSparseLightningIndexerGradKLLoss

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:SparselightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,存放在SparseIndices中,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。

  • 计算公式: 用于取Top-k的value的计算公式可以表示为:

    It,:=Wt,:@ReLU(qt,:@(K:t,:)T)I_{t,:}=W_{t,:}@ReLU(q_{t,:}@(K_{:t,:})^T)

    其中,WW是第tt个token对应的weights,qq是第tt个token对应的GG个query头合轴后的矩阵,KKttKK矩阵。

    LightningIndexer会单独训练,对应的loss function为:

    L(I)=tDKL(pt,:Softmax(It,:))L(I){=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))

    其中,pp是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:

    DKL(ab)=iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}

    通过求导可得Loss的梯度表达式:

    dIt,:=Softmax(It,:)pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.

    利用链式法则可以进行weights,query和key矩阵的梯度计算:

    dWt,:=dIt,:@(ReLU(St,:))TdW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{{T}}\right. \right. \right. \right. dqt,:=dSt,:@K:t,:d\mathop{{q}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@K\mathop{{}}\nolimits_{{:t,:}} dK:t,:=(dSt,:)T@q:t,:dK\mathop{{}}\nolimits_{{:t,:}}= \left( dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{{T}}@q\mathop{{}}\nolimits_{{:t,:}}\right. \right.

    其中,S为QK矩阵softmax的结果。

[object Object]

函数原型

每个算子分为,必须先调用“aclnnSparseLightningIndexerGradKLLossGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnSparseLightningIndexerGradKLLoss”接口执行计算。

[object Object]
[object Object]

aclnnSparseLightningIndexerGradKLLossGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnSparseLightningIndexerGradKLLoss

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnSparseLightningIndexerKLLoss默认非确定性实现,不支持通过aclrtCtxSetSysParamOpt开启确定性。
  • 公共约束
    • 入参为空的场景处理:
      • query为空Tensor:直接返回。
      • 公共约束里入参为空的场景和FAG保持一致。
    [object Object]
  • 规格约束[object Object]
  • 典型值[object Object]

调用示例

调用示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]