昇腾社区首页
EN
注册

torch_npu.npu_mla_prolog

功能描述

  • 算子功能:

    推理场景下,Multi-Head Latent Attention(MLA)前处理的计算。主要计算过程分为四路,首先对输入x乘以WDQ进行下采样和RmsNorm后分成两路,第一路乘以WUQ和WUK经过两次上采样后得到qN;第二路乘以WQR后经过旋转位置编码(ROPE)得到qR;第三路是输入x乘以WDKV进行下采样和RmsNorm后传入Cache中得到kC;第四路是输入x乘以WKR后经过旋转位置编码后传入另一个Cache中得到kR

  • 计算公式:
    • RmsNorm公式:

    • Query计算公式:

    • Query ROPE旋转位置编码:

    • Key计算公式:

    • Key ROPE旋转位置编码:

接口原型

torch_npu.npu_mla_prolog(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor cache_index, Tensor kv_cache, Tensor kr_cache, *, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor)

参数说明

shape格式字段含义:

  • B:Batch表示输入样本批量大小,取值范围为1~65536。
  • S:Seq-Length 表示输入样本序列长度,取值范围为1~16。
  • He:Head-Size 表示隐藏层的大小,取值为7168。
  • Hcq:q低秩矩阵维度,取值为1536。
  • N:Head-Num 表示多头数,取值范围为32、64、128。
  • Hckv:kv低秩矩阵维度,取值为512。
  • D:qk不含位置编码维度,取值为128。
  • Dr:qk位置编码维度,取值为64。
  • Nkv:kv的head数,取值为1。
  • BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度。
  • BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
  • token_x:Tensor类型,对应公式中x。shape支持3维,格式为(B,S,He),dtype支持bfloat16,数据格式支持ND。
  • weight_dq:Tensor类型,表示计算Query的下采样权重矩阵,即公式中WDQ。shape支持2维,格式为(He,Hcq),dtype支持bfloat16,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
  • weight_uq_qr:Tensor类型,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中WUQ和WQR。shape支持2维,格式为(Hcq,N*(D+Dr)),dtype支持bfloat16,数据格式支持FRACTAL_NZ。
  • weight_uk:Tensor类型,表示计算Key的上采样权重,即公式中WUK。shape支持3维,格式为(N,D,Hckv),dtype支持bfloat16,数据格式支持ND。
  • weight_dkv_kr:Tensor类型,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中WDKV和WKR。shape支持2维,格式为(He,Hckv+Dr),dtype支持bfloat16,数据格式支持FRACTAL_NZ。
  • rmsnorm_gamma_cq:Tensor类型,表示计算cQ的RmsNorm公式中的γ参数。shape支持1维,格式为(Hcq),dtype支持bfloat16,数据格式支持ND。
  • rmsnorm_gamma_ckv:Tensor类型,表示计算cKV的RmsNorm公式中的γ参数。shape支持1维,格式为(Hckv),dtype支持bfloat16,数据格式支持ND。
  • rope_sin:Tensor类型,表示用于计算旋转位置编码的正弦参数矩阵。shape支持3维,格式为(B,S,Dr),dtype支持bfloat16,数据格式支持ND。
  • rope_cos:Tensor类型,表示用于计算旋转位置编码的余弦参数矩阵。shape支持3维,格式为(B,S,Dr),dtype支持bfloat16,数据格式支持ND。
  • cache_index:Tensor类型,表示用于存储kv_cache和kr_cache的索引。shape支持2维,格式为(B,S),dtype支持int64,数据格式支持ND。
  • kv_cache:Tensor类型,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16,数据格式支持ND。
  • kr_cache:Tensor类型,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16,数据格式支持ND。
  • dequant_scale_x:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • dequant_scale_w_dq:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • dequant_scale_w_uq_qr:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • dequant_scale_w_dkv_kr:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • quant_scale_ckv:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • quant_scale_ckr:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • smooth_scales_cq:Tensor类型,预留参数,暂未使用,使用默认值即可。
  • rmsnorm_epsilon_cq:Double类型,表示计算cQ的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
  • rmsnorm_epsilon_ckv:Double类型,表示计算cKV的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
  • cache_mode:String类型,表示kvCache的模式,支持"PA_BSND"、"PA_NZ",其用户不特意指定时可传入默认值“PA_BSND”。

输出说明

  • query:Tensor类型,表示Query的输出Tensor,即公式中qN。shape支持4维,格式为(B, S, N, Hckv),dtype支持bfloat16,数据格式支持ND。
  • query_rope:Tensor类型,表示Query位置编码的输出Tensor,即公式中qR。shape支持4维,格式为(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
  • kv_cache_out:Tensor类型,表示Key输出到kvCache中的Tensor,即公式中kC。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16,数据格式支持ND。
  • kr_cache_out:Tensor类型,表示Key的位置编码输出到kvCache中的Tensor,即公式中kR。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16,数据格式支持ND。

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    import torch
    import torch_npu
    import math
    # 生成随机数据, 并发送到npu
    B = 8
    He = 7168
    Hcq = 1536
    Hckv = 512
    N = 32
    D = 128
    Dr = 64
    Skv = 1024
    S = 2
    Nkv = 1
    BlockSize = 128
    BlockNum = 64
    token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
    w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
    w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
    w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
    w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
    w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
    w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
    w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
    rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
    rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
    rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    cache_index = torch.rand(B, S).to(torch.int64).npu()
    kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
    kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
    rmsnorm_epsilon_cq = 1.0e-5
    rmsnorm_epsilon_ckv = 1.0e-5
    cache_mode = "PA_BSND"
    
    # 调用MlaProlog算子
    query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla = torch_npu.npu_mla_prolog(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
    # 执行上述代码的输出out类似如下
    tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ..
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    # 入图方式
    import torch
    import torch_npu
    import math
    import torchair as tng
    
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    
    # 数据生成
    B = 8
    He = 7168
    Hcq = 1536
    Hckv = 512
    N = 32
    D = 128
    Dr = 64
    Skv = 1024
    S = 2
    Nkv = 1
    BlockSize = 128
    BlockNum = 64
    token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
    w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
    w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
    w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
    w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
    w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
    w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
    w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
    rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
    rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
    rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
    cache_index = torch.rand(B, S).to(torch.int64).npu()
    kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
    kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
    rmsnorm_epsilon_cq = 1.0e-5
    rmsnorm_epsilon_ckv = 1.0e-5
    cache_mode = "PA_BSND"
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self):
            return torch_npu.npu_mla_prolog(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
        query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla = torch_npu.npu_mla_prolog(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
        print("single op output:", query_mla)
        print("graph output:", graph_output)
        
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)
    
    graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.bfloat16)