开发者
资源
[object Object][object Object][object Object]undefined
[object Object]
  • API功能:该接口实现了npu_lightning_indexer的反向功能,并融合了Loss的计算。npu_lightning_indexer用于筛选Attention的[object Object][object Object]间最高内在联系的Top-k项,存放在[object Object]中,以减少长序列场景下的Attention计算量,提升训练性能。

  • 计算公式:

    1. Top-k value的计算公式:

      It,:=Wt,:@ReLU(q~t,:@K~:t,:)I_{t,:}=W_{t,:}@ReLU(\tilde{q}_{t,:}@\tilde{K}_{:t,:}^\top)
      • Wt,:W_{t,:}是第tt个token对应的weightsweights
      • q~t,:\tilde{q}_{t,:}QindexQ_{index}矩阵第tt个token对应的GG个头合轴后的结果;
      • K~:t,:\tilde{K}_{:t,:}KindexK_{index}矩阵tt行。
    2. 正向的Softmax对应公式:

      pt,:=Softmax(qt,:@K:t,:/d)p_{t,:} = \text{Softmax}(q_{t,:} @ K_{:t,:}^\top/\sqrt{d})
      • pt,:p_{t,:}是第tt个token对应的Softmax结果;
      • qt,:q_{t,:}QQ矩阵第tt个token对应的GG个query头合轴后的结果;
      • KKKK矩阵tt行。
    3. npu_lightning_indexer会单独训练,对应的loss function为:

      Loss=tDKL(pt,:Softmax(It,:))Loss{=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))

      其中,pt,:p_{t,:}是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:

      DKL(ab)=iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}
    4. 通过求导可得Loss的梯度表达式:

      dIt,:=Softmax(It,:)pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.

      利用链式法则可以进行weights,query和key矩阵的梯度计算:

      dWt,:=dIt,:@(ReLU(St,:))dW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right. dq~t,:=dSt,:@K~:t,:d\mathop{{\tilde{q}}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@\tilde{K}\mathop{{}}\nolimits_{{:t,:}} dK~:t,:=(dSt,:)@q~:t,:d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}=\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}@\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right.

      其中,SSQindexQ_{index}KindexK_{index}矩阵乘的结果。

[object Object]
[object Object]
[object Object]

query([object Object]):必选参数,表示Attention中的query,对应公式中的qtq_{t}。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S1,N1,D)(B, S1, N1, D)(T1,N1,D)(T1, N1, D)

key([object Object]):必选参数,表示Attention中的key,对应公式中的KtK_{t}。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S2,N2,D)(B, S2, N2, D)(T2,N2,D)(T2, N2, D)

query_index([object Object]):必选参数,表示lightning_indexer正向的输入[object Object],对应公式中的q~t\tilde{q}_{t}。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S1,N1index,D)(B, S1, N1index, D)(T1,N1index,D)(T1, N1index, D)

key_index([object Object]):必选参数,表示lightning_indexer正向的输入[object Object],对应公式中的K~t\tilde{K}_{t}。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S2,N2index,D)(B, S2, N2index, D)(T2,N2index,D)(T2, N2index, D)

weights([object Object]):必选参数,表示lightning_indexer的权重系数,对应公式中的WtW_{t}。数据格式支持NDND,数据类型支持[object Object][object Object][object Object]。shape支持(B,S1,N1index)(B, S1, N1index)(T1,N1index)(T1, N1index)

sparse_indices([object Object]):必选参数,表示排序后[object Object][object Object]的token序号。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S1,N2index,topK)(B, S1, N2index, topK)(T1,N2index,topK)(T1, N2index, topK)

softmax_max([object Object]):必选参数,表示Attention softmax结果中的最大值。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,N2,S1,G)(B, N2, S1, G)(N2,T1,G)(N2, T1, G)

softmax_sum([object Object]):必选参数,表示Attention softmax结果的求和。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,N2,S1,G)(B, N2, S1, G)(N2,T1,G)(N2, T1, G)

scale_value([object Object]):必选参数,表示缩放系数,数据类型支持[object Object]

query_rope([object Object]):可选参数,表示MLA结构中的query的rope信息。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S1,N1,Dr)(B, S1, N1, Dr)(T1,N1,Dr)(T1, N1, Dr)

key_rope([object Object]):可选参数,表示MLA结构中的key的rope信息。数据格式支持NDND,数据类型支持[object Object][object Object]。shape支持(B,S2,N2,Dr)(B, S2, N2, Dr)(T2,N2,Dr)(T2, N2, Dr)

actual_seq_qlen([object Object]):可选参数,TND场景时需传入此参数。表示query每个S的累加和长度,数据类型支持[object Object],数据格式支持NDND,默认值为[object Object]

actual_seq_klen([object Object]):可选参数,TND场景时需传入此参数。表示key每个S的累加和长度,数据类型支持[object Object],数据格式支持NDND,默认值为[object Object]

layout([object Object]):可选参数,用于标识输入[object Object]的数据排布格式。当前支持BSNDBSNDTNDTND,默认值为"BSND"。

sparse_mode([object Object]):可选参数,表示sparse的模式,数据类型支持[object Object],默认值为[object Object]

pre_tokens([object Object]):可选参数,用于稀疏计算,表示Attention需要和前几个token计算关联。数据类型支持[object Object],默认值2^63-1。

next_tokens([object Object]):可选参数,用于稀疏计算,表示Attention需要和后几个token计算关联。数据类型支持[object Object],默认值2^63-1。

[object Object]
  • d_query_index([object Object]):对应公式中的dq~td\tilde{q}_{t},表示[object Object]的梯度,数据类型支持[object Object][object Object]
  • d_key_index([object Object]):对应公式中的dK~td\tilde{K}_{t},表示[object Object]的梯度,数据类型支持[object Object][object Object]
  • d_weights([object Object]):对应公式中的dWtdW_{t},表示[object Object]的梯度,数据类型支持[object Object][object Object][object Object]
  • loss([object Object]):对应公式中的dItdI_{t},表示网络正向输出和golden值的差异,数据类型支持[object Object]
[object Object]
  • 参数query、key、query_index、key_index的数据类型应保持一致。
  • 参数weights不为[object Object]时,参数query、key、query_index、key_index、weights的数据类型应保持一致。
  • 规格约束:
[object Object]undefined
[object Object]
  • 单算子模式调用

    [object Object]