torch_npu.npu_kv_rmsnorm_rope_cache

功能描述

接口原型

torch_npu.npu_kv_rmsnorm_rope_cache(Tensor kv, Tensor gamma, Tensor cos, Tensor sin, Tensor index, Tensor k_cache, Tensor ckv_cache, *, Tensor? k_rope_scale=None, Tensor? c_kv_scale=None, Tensor? k_rope_offset=None, Tensor? c_kv_offset=None, float epsilon=1e-5, str cache_mode='Norm', bool is_output_kv=False) -> (Tensor, Tensor, Tensor, Tensor)

参数说明

Tensor中shape使用的变量说明:

  • batch_size:batch的大小。
  • seq_len:sequence的长度。
  • hidden_size:表示MLA输入的向量长度,取值仅支持576。
  • rms_size:表示RMSNorm分支的向量长度,取值仅支持512。
  • rope_size:表示RoPE分支的向量长度,取值仅支持64。
  • cache_length:Norm模式下有效,表示KVCache支持的最大长度。
  • block_num:PagedAttention模式下有效,表示Block的个数。
  • block_size:PagedAttention模式下有效,表示Block的大小。

输出说明

约束说明

支持的型号

调用示例