aclnnSparseLightningIndexerGradKLLoss
产品支持情况
功能说明
接口功能:SparselightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,存放在SparseIndices中,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。
计算公式: 用于取Top-k的value的计算公式可以表示为:
其中,是第个token对应的weights,是第个token对应的个query头合轴后的矩阵,为行矩阵。
LightningIndexer会单独训练,对应的loss function为:
其中,是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。为KL散度,其表达式为:
通过求导可得Loss的梯度表达式:
利用链式法则可以进行weights,query和key矩阵的梯度计算:
其中,S为QK矩阵softmax的结果。
函数原型
每个算子分为,必须先调用“aclnnSparseLightningIndexerGradKLLossGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnSparseLightningIndexerGradKLLoss”接口执行计算。
[object Object]
[object Object]
aclnnSparseLightningIndexerGradKLLossGetWorkspaceSize
aclnnSparseLightningIndexerGradKLLoss
约束说明
- 确定性计算:
- aclnnSparseLightningIndexerKLLoss默认非确定性实现,不支持通过aclrtCtxSetSysParamOpt开启确定性。
- 公共约束
- 入参为空的场景处理:
- query为空Tensor:直接返回。
- 公共约束里入参为空的场景和FAG保持一致。
- 入参为空的场景处理:
- 规格约束[object Object]
- 典型值[object Object]
调用示例
[object Object]