torch_npu.npu_mla_prolog

功能描述

接口原型

torch_npu.npu_mla_prolog(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor cache_index, Tensor kv_cache, Tensor kr_cache, *, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor)

参数说明

shape格式字段含义:

  • B:Batch表示输入样本批量大小,取值范围为1~65536。
  • S:Seq-Length 表示输入样本序列长度,取值范围为1~16。
  • He:Head-Size 表示隐藏层的大小,取值为7168。
  • Hcq:q低秩矩阵维度,取值为1536。
  • N:Head-Num 表示多头数,取值范围为32、64、128。
  • Hckv:kv低秩矩阵维度,取值为512。
  • D:qk不含位置编码维度,取值为128。
  • Dr:qk位置编码维度,取值为64。
  • Nkv:kv的head数,取值为1。
  • BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度。
  • BlockSize:PagedAttention场景下的块大小,取值范围为16、128。

输出说明

约束说明

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

调用示例