torch_npu.npu_advance_step_flashattn(Tensor(a!) input_tokens, Tensor sampled_token_ids, Tensor(b!) input_positions, Tensor(c!) seq_lens, Tensor(d!) slot_mapping, Tensor block_tables, int num_seqs, int num_queries, int block_size) -> ()
在NPU上实现vLLM库中advance_step_flashattn的功能，在每个生成步骤中根据sampled_token_ids和block_tables原地更新input_tokens、input_positions、seq_lens和slot_mapping。
此接口将原地更新input_tokens、input_positions、seq_lens和slot_mapping的值（即函数签名中标注为Tensor(a!)～Tensor(d!)的参数），无返回值。
# Usage example for torch_npu.npu_advance_step_flashattn: build random inputs
# on the host, move them to the NPU, and advance one step in place.
import numpy as np
import torch
import torch_npu

num_seqs = 16      # number of sequences in the batch
num_queries = 8    # number of sequences that received a sampled token (<= num_seqs here)
block_size = 8

# Host-side random inputs. Per-sequence arrays have shape (num_seqs,);
# the sampled tokens have shape (num_queries, 1).
input_token = np.random.randint(10, size=(num_seqs,))
sampled_token_id = np.random.randint(10, size=(num_queries, 1))
input_position = np.random.randint(10, size=(num_seqs,))
seq_len = np.random.randint(10, size=(num_seqs,))
slot_mapping = np.random.randint(10, size=(num_seqs,))

input_tokens = torch.tensor(input_token, dtype=torch.int64, device="npu")
sampled_token_ids = torch.tensor(sampled_token_id, dtype=torch.int64, device="npu")
input_positions = torch.tensor(input_position, dtype=torch.int64, device="npu")
seq_lens = torch.tensor(seq_len, dtype=torch.int64, device="npu")
slot_mappings = torch.tensor(slot_mapping, dtype=torch.int64, device="npu")

# One block-table row per sequence. The column count comes from the longest
# sequence; computing it from the host array avoids copying seq_lens back
# from the device and hands NumPy a plain int.
# NOTE(review): presumably sized so the block holding each sequence's next
# position is always in range -- confirm against the operator spec.
num_blocks_per_seq = int(seq_len.max()) // block_size + 1
block_table = np.random.randint(10, size=(num_seqs, num_blocks_per_seq))
block_tables = torch.tensor(block_table, dtype=torch.int64, device="npu")

# Updates input_tokens, input_positions, seq_lens and slot_mappings in place;
# the call returns nothing.
torch_npu.npu_advance_step_flashattn(
    input_tokens,
    sampled_token_ids,
    input_positions,
    seq_lens,
    slot_mappings,
    block_tables,
    num_seqs,
    num_queries,
    block_size,
)