torch_npu.npu_kv_rmsnorm_rope_cache(Tensor kv, Tensor gamma, Tensor cos, Tensor sin, Tensor index, Tensor k_cache, Tensor ckv_cache, *, Tensor? k_rope_scale=None, Tensor? c_kv_scale=None, Tensor? k_rope_offset=None, Tensor? c_kv_offset=None, float epsilon=1e-5, str cache_mode='Norm', bool is_output_kv=False) -> (Tensor, Tensor, Tensor, Tensor)
The supported cache_mode values and the corresponding Tensor shapes are described below:
| Enum value | Mode name | Description |
|---|---|---|
| Norm | KV-cache update mode | k_cache has shape [batch_size, 1, cache_length, rope_size] and ckv_cache has shape [batch_size, 1, cache_length, rms_size]. index has shape [batch_size, seq_len]; each value in index is an offset within the corresponding batch. |
| PA/PA_BNSD | PagedAttention mode | k_cache has shape [block_num, block_size, 1, rope_size] and ckv_cache has shape [block_num, block_size, 1, rms_size]. index has shape [batch_size * seq_len]; each value in index is the offset of one token. |
| PA_NZ | PagedAttention mode with the cache stored in FRACTAL_NZ format | k_cache has shape [block_num, block_size, 1, rope_size] and ckv_cache has shape [block_num, block_size, 1, rms_size]. index has shape [batch_size * seq_len]; each value in index is the offset of one token. The data layout differs between quantization modes. |
| PA_BLK_BNSD | Special PagedAttention mode | k_cache has shape [block_num, block_size, 1, rope_size] and ckv_cache has shape [block_num, block_size, 1, rms_size]. index has shape [batch_size * Ceil(seq_len / block_size)]; each value in index is the starting offset of a block and no longer corresponds one-to-one with tokens. |
| PA_BLK_NZ | Special PagedAttention mode with the cache stored in FRACTAL_NZ format | k_cache has shape [block_num, block_size, 1, rope_size] and ckv_cache has shape [block_num, block_size, 1, rms_size]. index has shape [batch_size * Ceil(seq_len / block_size)]; each value in index is the starting offset of a block and no longer corresponds one-to-one with tokens. The data layout differs between quantization modes. |
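As a reading aid for the table above, the sketch below shows one plausible way the index tensor could be populated in each family of modes. The concrete sizes (batch_size, seq_len, block_size, past_length) and the assumption that batch b writes into block b are illustrative only; they are not part of the operator's contract.

```python
import math
import torch

batch_size, seq_len, block_size = 4, 16, 128
past_length = 32  # tokens already present in the cache (illustrative)

# Norm mode: shape [batch_size, seq_len], one write offset per token within each batch's cache.
index_norm = (past_length + torch.arange(seq_len)).repeat(batch_size, 1)

# PA / PA_BNSD / PA_NZ: shape [batch_size * seq_len], one global offset per token
# (block_id * block_size + slot inside the block); here batch b uses block b.
index_pa = torch.cat([
    b * block_size + past_length + torch.arange(seq_len)
    for b in range(batch_size)
])

# PA_BLK_BNSD / PA_BLK_NZ: shape [batch_size * ceil(seq_len / block_size)],
# one starting offset per block rather than one entry per token.
blocks_per_seq = math.ceil(seq_len / block_size)
index_pa_blk = torch.cat([
    b * block_size + block_size * torch.arange(blocks_per_seq)
    for b in range(batch_size)
])
```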
```python
import torch
import torch_npu


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, kv, gamma, cos, sin, index, k_cache, ckv_cache,
                k_rope_scale=None, c_kv_scale=None, k_rope_offset=None, c_kv_offset=None,
                epsilon=1e-05, cache_mode="Norm", is_output_kv=False):
        k_cache, v_cache, k_rope, c_kv = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv, gamma, cos, sin, index, k_cache, ckv_cache,
            k_rope_scale=k_rope_scale, c_kv_scale=c_kv_scale,
            k_rope_offset=k_rope_offset, c_kv_offset=c_kv_offset,
            epsilon=epsilon, cache_mode=cache_mode, is_output_kv=is_output_kv)
        return k_cache, v_cache, k_rope, c_kv


# Eager-mode (single-operator) invocation; the input tensors are assumed to be
# pre-constructed NPU tensors.
model = Model().npu()
_, _, k_rope, c_kv = model(kv, gamma, cos, sin, index, k_cache, ckv_cache,
                           k_rope_scale, c_kv_scale, None, None, 1e-5,
                           cache_mode, is_output_kv)
```
```python
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig

# Graph-mode invocation via torchair; keep_inference_input_mutations preserves
# the in-place updates to k_cache/ckv_cache inside the compiled graph.
config = CompilerConfig()
config.experimental_config.keep_inference_input_mutations = True
npu_backend = tng.get_npu_backend(compiler_config=config)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, kv, gamma, cos, sin, index, k_cache, ckv_cache,
                k_rope_scale=None, c_kv_scale=None, k_rope_offset=None, c_kv_offset=None,
                epsilon=1e-05, cache_mode="Norm", is_output_kv=False):
        k_cache, v_cache, k_rope, c_kv = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv, gamma, cos, sin, index, k_cache, ckv_cache,
            k_rope_scale=k_rope_scale, c_kv_scale=c_kv_scale,
            k_rope_offset=k_rope_offset, c_kv_offset=c_kv_offset,
            epsilon=epsilon, cache_mode=cache_mode, is_output_kv=is_output_kv)
        return k_cache, v_cache, k_rope, c_kv


model = Model().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
_, _, k_rope, c_kv = model(kv, gamma, cos, sin, index, k_cache, ckv_cache,
                           k_rope_scale, c_kv_scale, None, None, 1e-5,
                           cache_mode, is_output_kv)
```
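Both examples above assume that kv, gamma, cos, sin, index, k_cache and ckv_cache already exist as NPU tensors. A minimal sketch of building such inputs for cache_mode="Norm" and calling the operator directly follows; the kv/gamma/cos/sin shapes (with kv assumed to concatenate an rms_size part and a rope_size part on its last axis), the dtypes and the concrete sizes are assumptions for illustration, while the cache and index shapes follow the table above.

```python
import torch
import torch_npu

batch_size, seq_len, cache_length = 4, 1, 1024
rms_size, rope_size = 512, 64  # assumed sizes, not normative

# Assumed input layout: kv concatenates the rms part and the rope part on the last axis.
kv = torch.randn(batch_size, 1, seq_len, rms_size + rope_size, dtype=torch.float16).npu()
gamma = torch.randn(rms_size, dtype=torch.float16).npu()
cos = torch.randn(batch_size, 1, seq_len, rope_size, dtype=torch.float16).npu()
sin = torch.randn(batch_size, 1, seq_len, rope_size, dtype=torch.float16).npu()

# Cache and index shapes for cache_mode="Norm", taken from the table above.
k_cache = torch.zeros(batch_size, 1, cache_length, rope_size, dtype=torch.float16).npu()
ckv_cache = torch.zeros(batch_size, 1, cache_length, rms_size, dtype=torch.float16).npu()
index = torch.zeros(batch_size, seq_len, dtype=torch.int64).npu()  # write offset within each batch's cache

_, _, k_rope, c_kv = torch_npu.npu_kv_rmsnorm_rope_cache(
    kv, gamma, cos, sin, index, k_cache, ckv_cache,
    cache_mode="Norm", is_output_kv=True)
```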