@@ -74,6 +74,7 @@ def update_verify_buffers_to_fill_after_draft(
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
+        self.enable_torch_compile = False
         self.forward_metadata = None
         self.device = model_runner.device
         self.page_size = model_runner.page_size
@@ -576,112 +577,151 @@ def forward_decode_graph(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        if not self.use_mla:
-            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                layer.layer_id
-            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                layer.layer_id
-            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-            if self.forward_metadata.seq_lens_cpu_int is None:
-                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
-            else:
-                actual_seq_len_kv = (
-                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
-                )
+        if not self.use_mla and self.enable_torch_compile:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
             num_tokens = query.shape[0]
-            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                query,
-                k_cache,
-                v_cache,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                input_layout="BSH",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-            )
-            output = torch.empty(
-                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                dtype=q.dtype,
-                device=q.device,
-            )
-            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-            torch_npu.npu_fused_infer_attention_score.out(
-                query,
-                k_cache,
-                v_cache,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                input_layout="BSH",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                workspace=workspace,
-                out=[output, softmax_lse],
-            )
-            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            k_rope_cache = k_rope.view(
-                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
-            )
-            c_kv_cache = c_kv.view(
-                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            attn_output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
             )
 
-            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
-            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
             if self.forward_metadata.seq_lens_cpu_int is None:
-                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
-            else:
-                actual_seq_len_kv = (
-                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                actual_seq_len_kv = torch.from_numpy(
+                    np.array(self.forward_metadata.seq_lens_cpu_list).astype(np.int32)
                 )
+            else:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int
 
-            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                q_nope,
-                c_kv_cache,
-                c_kv_cache,
-                query_rope=q_rope,
-                key_rope=k_rope_cache,
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
                 num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
                 block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                input_layout="BNSD",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                sparse_mode=0,
+                context_lens=actual_seq_len_kv,
+                out=attn_output,
             )
-            output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
-            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if not self.use_mla:
+                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+                query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+                if self.forward_metadata.seq_lens_cpu_int is None:
+                    actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+                else:
+                    actual_seq_len_kv = (
+                        self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                    )
+                num_tokens = query.shape[0]
+                workspace = (
+                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        query,
+                        k_cache,
+                        v_cache,
+                        block_table=self.forward_metadata.block_tables,
+                        block_size=self.page_size,
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        input_layout="BSH",
+                        scale=layer.scaling,
+                        actual_seq_lengths_kv=actual_seq_len_kv,
+                    )
+                )
+                output = torch.empty(
+                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query,
+                    k_cache,
+                    v_cache,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSH",
+                    scale=layer.scaling,
+                    actual_seq_lengths_kv=actual_seq_len_kv,
+                    workspace=workspace,
+                    out=[output, softmax_lse],
+                )
+                return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+            else:
+                c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(
+                    layer.layer_id
+                )
+                k_rope_cache = k_rope.view(
+                    -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+                )
+                c_kv_cache = c_kv.view(
+                    -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+                )
 
-            torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
-                c_kv_cache,
-                c_kv_cache,
-                query_rope=q_rope,
-                key_rope=k_rope_cache,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                input_layout="BNSD",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                sparse_mode=0,
-                workspace=workspace,
-                out=[output, softmax_lse],
-            )
-            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+                q_nope = q.view(
+                    -1, layer.tp_q_head_num, 1, self.kv_lora_rank
+                ).contiguous()
+                q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+                if self.forward_metadata.seq_lens_cpu_int is None:
+                    actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+                else:
+                    actual_seq_len_kv = (
+                        self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                    )
+
+                workspace = (
+                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        q_nope,
+                        c_kv_cache,
+                        c_kv_cache,
+                        query_rope=q_rope,
+                        key_rope=k_rope_cache,
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        block_table=self.forward_metadata.block_tables,
+                        block_size=self.page_size,
+                        input_layout="BNSD",
+                        scale=layer.scaling,
+                        actual_seq_lengths_kv=actual_seq_len_kv,
+                        antiquant_mode=0,
+                        antiquant_scale=None,
+                        sparse_mode=0,
+                    )
+                )
+                output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
+                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+                torch_npu.npu_fused_infer_attention_score.out(
+                    q_nope,
+                    c_kv_cache,
+                    c_kv_cache,
+                    query_rope=q_rope,
+                    key_rope=k_rope_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    input_layout="BNSD",
+                    scale=layer.scaling,
+                    actual_seq_lengths_kv=actual_seq_len_kv,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    sparse_mode=0,
+                    workspace=workspace,
+                    out=[output, softmax_lse],
+                )
+                return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
 
     def forward_decode(
         self,
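
For reference, a minimal sketch (not code from this PR) of the sequence-length handling the two decode branches above select between. The attribute names `seq_lens_cpu_int` and `seq_lens_cpu_list` come from the diff; the `metadata` argument and the helper itself are placeholders for illustration, and the presumed motivation (keeping the lengths as a tensor so the compiled graph avoids a per-step host round trip) is an assumption, not something the diff states.

```python
# Illustrative sketch only: shows the two KV-length representations used above.
import numpy as np
import torch


def decode_kv_lens(metadata, for_torch_compile: bool):
    """Return KV sequence lengths in the form each decode branch expects."""
    if for_torch_compile:
        # New branch: an int32 tensor, built from the CPU list only if the
        # precomputed tensor is absent.
        if metadata.seq_lens_cpu_int is None:
            return torch.from_numpy(
                np.array(metadata.seq_lens_cpu_list).astype(np.int32)
            )
        return metadata.seq_lens_cpu_int
    # Existing branch: a plain Python list of ints for the fused kernel.
    if metadata.seq_lens_cpu_int is None:
        return metadata.seq_lens_cpu_list
    return metadata.seq_lens_cpu_int.cpu().int().tolist()
```

In the diff above, the tensor form feeds `context_lens=` of `torch_npu._npu_paged_attention`, while the list form feeds `actual_seq_lengths_kv=` of `npu_fused_infer_attention_score`.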