昇腾社区首页
中文
注册

torch_npu.npu_kv_rmsnorm_rope_cache

功能描述

  • 算子功能:融合了MLA(Multi-head Latent Attention)结构中RMSNorm归一化计算与RoPE(Rotary Position Embedding)位置编码以及更新KVCache的ScatterUpdate操作。
  • 计算公式
    • 输入张量kv拆分:拆分为两部分,其中B为批次大小,T为序列长度。

    • RMS归一化:对rms_in,应用RMS归一化。

      • γ∈R^512是可学习的缩放参数。
      • Ed[·]表示沿最后一个维度(维度d=512)的均值。
      • ε为小常数(如0.00001),防止除以零。
      • ⊙表示逐元素相乘。
    • 旋转位置编码(RoPE)
      1. 重塑与转置:将rope_in塑并转置以准备旋转

      2. 旋转操作:应用旋转位置编码

        • cos和sin为预计算的旋转角度参数。
        • RotateHalf(k)将k的后半部分元素移至前半部分并取反,后半部分用前半部分的值。具体来说,对于维度d=64:

接口原型

torch_npu.npu_kv_rmsnorm_rope_cache(Tensor kv, Tensor gamma, Tensor cos, Tensor sin, Tensor index, Tensor k_cache, Tensor ckv_cache, *, Tensor? k_rope_scale=None, Tensor? c_kv_scale=None, Tensor? k_rope_offset=None, Tensor? c_kv_offset=None, float epsilon=1e-5, str cache_mode='Norm', bool is_output_kv=False) -> (Tensor, Tensor, Tensor, Tensor)

参数说明

Tensor中shape使用的变量说明:

  • batch_size:batch的大小。
  • seq_len:sequence的长度。
  • hidden_size:表示MLA输入的向量长度,取值仅支持576。
  • rms_size:表示RMSNorm分支的向量长度,取值仅支持512。
  • rope_size:表示RoPE分支的向量长度,取值仅支持64。
  • cache_length:Norm模式下有效,表示KVCache支持的最大长度。
  • block_num:PagedAttention模式下有效,表示Block的个数。
  • block_size:PagedAttention模式下有效,表示Block的大小。
  • kv:Tensor类型,表示输入的特征张量。数据类型支持bfloat16、float16,数据格式为BNSD,要求为4D的Tensor,形状为[batch_size, 1, seq_len, hidden_size] ,其中hidden_size=rms_size(RMS)+rope_size(RoPE)。
  • gamma:Tensor类型,表示RMS归一化的缩放参数。数据类型支持bfloat16、float16,数据格式为ND,要求为1D的Tensor,形状为[rms_size]。
  • cos:Tensor类型,表示RoPE旋转位置编码的余弦分量。数据类型支持bfloat16、float16,数据格式为ND,要求为4D的Tensor,形状为[batch_size, 1, seq_len, rope_size]。
  • sin:Tensor类型,表示RoPE旋转位置编码的正弦分量。数据类型支持bfloat16、float16,数据格式为ND,要求为4D的Tensor,形状为[batch_size, 1, seq_len, rope_size]。
  • index:Tensor类型,表示缓存索引张量,用于定位k_cache和ckv_cache的写入位置。数据类型支持int64,数据格式为ND。shape取决于cache_mode。
  • k_cache:Tensor类型,用于存储量化/非量化的键向量。数据类型支持bfloat16、float16、int8, 数据格式为ND。shape取决于cache_mode。
  • ckv_cache:Tensor类型,用于存储量化/非量化的压缩后的kv向。数据类型支持bfloat16、float16、int8, 数据格式为ND。shape取决于cache_mode。
  • k_rope_scale:Tensor类型,可选,默认值None,表示k旋转位置编码的量化缩放因子。数据类型支持float32,数据格式为ND,要求为1D的Tensor,形状为[rope_size]。量化模式下必填。
  • c_kv_scale:Tensor类型,可选, 默认值None,表示压缩后kv的量化缩放因子。数据类型支持float32,数据格式为ND,要求为1D的Tensor,形状为[rms_size]。量化模式下必填。
  • k_rope_offset:Tensor类型,可选,默认值None, 表示k旋转位置编码量化偏移量。数据类型支持float32,数据格式为ND,要求为1D的Tensor,形状为[rope_size]。量化模式下必填。
  • c_kv_offset:Tensor类型,可选,默认值None,表示压缩后kv的量化偏移量。数据类型支持float32,数据格式为ND,要求为1D的Tensor,形状为[rms_size]。量化模式下必填。
  • epsilon:float类型,可选,默认值1e-5,表示RMS归一化中的极小值,防止除以零。
  • cache_mode:string类型,可选,默认值'Norm',表示缓存模式,支持的模式如下:

    枚举值

    模式名

    说明

    Norm

    KV-Cache更新模式

    k_cache形状为[batch_size, 1, cache_length, rope_size],ckv_cache形状为[batch_size, 1, cache_length, rms_size]。index形状为[batch_size, seq_len],index里的值表示每个Batch下的偏移。

    PA/PA_BNSD

    PagedAttention模式

    k_cache形状为[block_num, block_size, 1, rope_size],ckv_cache形状为[block_num, block_size, 1, rms_size]。index形状为[batch_size*seq_len],index里的值表示每个token的偏移。

    PA_NZ

    Cache数据格式为FRACTAL_NZ的PagedAttention模式

    k_cache形状为 [block_num, block_size, 1, rope_size],ckv_cache形状为[block_num, block_size, 1, rms_size]。index形状为[batch_size * seq_len],index里的值表示每个token的偏移。

    不同量化模式下数据排布不同:

    • 非量化模式下:k_cache数据排布为[block_num, rope_size//16, block_size, 1, 16],ckv_cache数据排布为[block_num, rms_size//16, block_size, 1, 16]
    • 量化模式下:k_cache数据排布为[block_num, rope_size//32, block_size, 1, 32],ckv_cache数据排布为[block_num, rms_size//32, block_size, 1, 32]

    PA_BLK_BNSD

    特殊的PagedAttention模式

    k_cache形状为[block_num, block_size, 1, rope_size],ckv_cache形状为[block_num, block_size, 1, rms_size]。index形状为[batch_size*Ceil(seq_len/block_size)],index里的值表示每个block的起始偏移,不再和token一一对应。

    PA_BLK_NZ

    Cache数据格式为FRACTAL_NZ的特殊的PagedAttention模式

    k_cache形状为 [block_num, block_size, 1, rope_size],ckv_cache形状为[block_num, block_size, 1, rms_size]。index形状为[batch_size * Ceil(seq_len / block_size)],index里的值表示每个block的起始偏移,不再和token一一对应。

    不同量化模式下数据排布不同:

    • 非量化模式下:k_cache数据排布为[block_num, rope_size//16, block_size, 1, 16],ckv_cache数据排布为[block_num, rms_size//16, block_size, 1, 16]
    • 量化模式下:k_cache数据排布为[block_num, rope_size//32, block_size, 1, 32],ckv_cache数据排布为[block_num, rms_size//32, block_size, 1, 32]
  • is_output_kv:bool类型,可选,表示是否输出处理后的k_embed_out和y_out(未量化的原始值),默认值False不输出,仅cache_mode在(PA/PA_BNSD/PA_NZ/PA_BLK_BNSD/PA_BLK_NZ)模式下有效。

输出说明

  • k_cache:Tensor类型,和输入k_cache的数据类型、维度、数据格式完全一致(本质in-place更新)。
  • ckv_cache:Tensor类型,和输入ckv_cache的数据类型、维度、数据格式完全一致(本质in-place更新)。
  • k_embed_out:Tensor类型,仅当is_output_kv=True时输出,表示RoPE处理后的值。要求为4D的Tensor,形状为[batch_size, 1, seq_len, 64],数据类型和格式同输入kv一致。
  • y_out:Tensor类型,仅当is_output_kv=True时输出,表示RMSNorm处理后的值。要求为4D的Tensor,形状为[batch_size, 1, seq_len, 512],数据类型和格式同输入kv一致。

约束说明

  • 该接口支持推理场景下使用。
  • 该接口支持图模式(PyTorch 2.1版本)。
  • 量化模式:当k_rope_scale和c_kv_scale非空时,k_cache和ckv_cache的dtype为int8,缓存形状的最后一个维度需要为32(Cache数据格式为FRACTAL_NZ模式)k_rope_scale和c_kv_scale必须同时非空k_rope_offset和c_kv_offset必须同时为None为非空
  • 非量化模式:当k_rope_scale和c_kv_scale为空时,k_cache和ckv_cache的dtype为bfloat16或float16。
  • 索引映射:所有cache_mode缓存模式下,index的值不可以重复,如果传入的index值存在重复,算子的行为是未定义的且不可预知的。
    • Norm:index的值表示每个Batch下的偏移
    • PA/PA_BNSD/PA_NZ:index的值表示全局的偏移。
    • PA_BLK_BNSD/PA_BLK_NZ:index的值表示每个页的全局偏移;这个场景假设cache更新是连续的,不支持非连续更新的cache。
  • Shape关联规则:不同的cache_mode缓存模式有不同的Shape规则。
    • Norm:k_cache形状为[batch_size, 1, cache_length, rope_size],ckv_cache形状为[batch_size, 1, cache_length, rms_size],index形状为[batch_size, seq_len], cache_length>=seq_len。
    • 非Norm模式(PagedAttention相关模式):要求block_num>=Ceil(seq_len/block_size)*batch_size。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import torch
    import torch_npu
    
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, kv, gamma, cos, sin, index, k_cache, ckv_cache,
                    k_rope_scale=None, c_kv_scale=None, k_rope_offset=None, c_kv_offset=None,
                    epsilon=1e-05, cache_mode="Norm", is_output_kv=False):
            k_cache, v_cache, k_rope, c_kv = torch_npu.npu_kv_rmsnorm_rope_cache(kv, gamma, cos, sin, index,
                                                                                 k_cache, ckv_cache,
                                                                                 k_rope_scale=k_rope_scale,
                                                                                 c_kv_scale=c_kv_scale,
                                                                                 k_rope_offset=k_rope_offset,
                                                                                 c_kv_offset=c_kv_offset,
                                                                                 epsilon=epsilon,
                                                                                 cache_mode=cache_mode,
                                                                                 is_output_kv=is_output_kv)
            return k_cache, v_cache, k_rope, c_kv
    
    model = Model().npu()
    _, _, k_rope, c_kv = model(kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale, None, None, 1e-5, cache_mode, is_output_kv)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    import torch
    import torch_npu
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    config.experimental_config.keep_inference_input_mutations = True
    npu_backend = tng.get_npu_backend(compiler_config=config)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, kv, gamma, cos, sin, index, k_cache, ckv_cache,
                    k_rope_scale=None, c_kv_scale=None, k_rope_offset=None, c_kv_offset=None,
                    epsilon=1e-05, cache_mode="Norm", is_output_kv=False):
            k_cache, v_cache, k_rope, c_kv = torch_npu.npu_kv_rmsnorm_rope_cache(kv, gamma, cos, sin, index,
                                                                                 k_cache, ckv_cache,
                                                                                 k_rope_scale=k_rope_scale,
                                                                                 c_kv_scale=c_kv_scale,
                                                                                 k_rope_offset=k_rope_offset,
                                                                                 c_kv_offset=c_kv_offset,
                                                                                 epsilon=epsilon,
                                                                                 cache_mode=cache_mode,
                                                                                 is_output_kv=is_output_kv)
            return k_cache, v_cache, k_rope, c_kv
    
    model = Model().npu()
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    _, _, k_rope, c_kv = model(kv, gamma, cos, sin, index, k_cache, ckv_cache, k_rope_scale, c_kv_scale, None, None, 1e-5, cache_mode, is_output_kv)