API功能:该接口实现了npu_lightning_indexer的反向功能,并融合了Loss的计算。npu_lightning_indexer用于筛选Attention的
[object Object]与[object Object]间最高内在联系的Top-k项,存放在[object Object]中,以减少长序列场景下的Attention计算量,提升训练性能。计算公式:
Top-k value的计算公式:
- 是第个token对应的;
- 是矩阵第个token对应的个头合轴后的结果;
- 为矩阵行。
正向的Softmax对应公式:
- 是第个token对应的Softmax结果;
- 是矩阵第个token对应的个query头合轴后的结果;
- 为矩阵行。
npu_lightning_indexer会单独训练,对应的loss function为:
其中,是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。为KL散度,其表达式为:
通过求导可得Loss的梯度表达式:
利用链式法则可以进行weights,query和key矩阵的梯度计算:
其中,为和矩阵乘的结果。
query([object Object]):必选参数,表示Attention中的query,对应公式中的。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
key([object Object]):必选参数,表示Attention中的key,对应公式中的。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
query_index([object Object]):必选参数,表示lightning_indexer正向的输入[object Object],对应公式中的。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
key_index([object Object]):必选参数,表示lightning_indexer正向的输入[object Object],对应公式中的。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
weights([object Object]):必选参数,表示lightning_indexer的权重系数,对应公式中的。数据格式支持,数据类型支持[object Object]、[object Object]、[object Object]。shape支持、。
sparse_indices([object Object]):必选参数,表示排序后[object Object]和[object Object]的token序号。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
softmax_max([object Object]):必选参数,表示Attention softmax结果中的最大值。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
softmax_sum([object Object]):必选参数,表示Attention softmax结果的求和。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
scale_value([object Object]):必选参数,表示缩放系数,数据类型支持[object Object]。
query_rope([object Object]):可选参数,表示MLA结构中的query的rope信息。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
key_rope([object Object]):可选参数,表示MLA结构中的key的rope信息。数据格式支持,数据类型支持[object Object]、[object Object]。shape支持、。
actual_seq_qlen([object Object]):可选参数,TND场景时需传入此参数。表示query每个S的累加和长度,数据类型支持[object Object],数据格式支持,默认值为[object Object]。
actual_seq_klen([object Object]):可选参数,TND场景时需传入此参数。表示key每个S的累加和长度,数据类型支持[object Object],数据格式支持,默认值为[object Object]。
layout([object Object]):可选参数,用于标识输入[object Object]的数据排布格式。当前支持、,默认值为"BSND"。
sparse_mode([object Object]):可选参数,表示sparse的模式,数据类型支持[object Object],默认值为[object Object]。
pre_tokens([object Object]):可选参数,用于稀疏计算,表示Attention需要和前几个token计算关联。数据类型支持[object Object],默认值2^63-1。
next_tokens([object Object]):可选参数,用于稀疏计算,表示Attention需要和后几个token计算关联。数据类型支持[object Object],默认值2^63-1。
- d_query_index(
[object Object]):对应公式中的,表示[object Object]的梯度,数据类型支持[object Object]、[object Object]。 - d_key_index(
[object Object]):对应公式中的,表示[object Object]的梯度,数据类型支持[object Object]、[object Object]。 - d_weights(
[object Object]):对应公式中的,表示[object Object]的梯度,数据类型支持[object Object]、[object Object]、[object Object]。 - loss(
[object Object]):对应公式中的,表示网络正向输出和golden值的差异,数据类型支持[object Object]。
- 参数query、key、query_index、key_index的数据类型应保持一致。
- 参数weights不为
[object Object]时,参数query、key、query_index、key_index、weights的数据类型应保持一致。 - 规格约束:
单算子模式调用
[object Object]