# Graph-mode invocation
import torch
import torch_npu
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
# Optional dynamo debug switches; set these as environment variables before launching the script:
#   TORCHDYNAMO_VERBOSE=1
#   TORCH_LOGS="+dynamo"
# Enable the debug logging that is supported in graph mode
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
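# Optional: dump the captured graph to a pbtxt file so it can be inspected offline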
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# Data generation
B = 8
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 1024
S = 2
Nkv = 1
BlockSize = 128
BlockNum = 64
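# Assumed meanings of the shape symbols above (added here for readability, not part of the operator definition):
# B = batch size, S = query sequence length, He = input hidden size, Hcq = compressed query hidden size,
# Hckv = compressed KV hidden size, N = number of query heads, D = head dim, Dr = rope dim,
# Skv = KV sequence length, Nkv = number of KV heads, BlockNum/BlockSize = paged KV-cache blocks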
token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
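# npu_format_cast(..., 29) converts the weights to NPU private format 29 (assumed to be FRACTAL_NZ) before the call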
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.rand(B, S).to(torch.int64).npu()
kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
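# cache_mode "PA_BSND" is assumed to select the paged-attention KV cache with a BSND layout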
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        return torch_npu.npu_mla_prolog(
            token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
            rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq,
            rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
def MetaInfershape():
    with torch.no_grad():
        # Graph-mode call through torch.compile with the torchair NPU backend
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()

        # Single-op (eager) call with the same inputs, for comparison with the graph-mode result
        query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla = torch_npu.npu_mla_prolog(
            token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
            rope_sin, rope_cos, cache_index, kv_cache, kr_cache, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq,
            rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode)
        print("single op output:", query_mla)
        print("graph output:", graph_output)

if __name__ == "__main__":
    MetaInfershape()
# Running the code above produces output similar to the following
single op output: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.bfloat16)
graph output: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.bfloat16)