
Commit 250318a

[fix]qwen2vl support fa3 (#993)
1 parent a4bc0d6 commit 250318a

3 files changed: +42 −2 lines

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 6 additions & 2 deletions
@@ -24,8 +24,7 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
         ]
         return cls._shared_page_table_buffer
 
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
@@ -93,3 +92,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             )
         )
         return
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        self._init_flash_attention_state(model, input_ids)
+        return
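Note: the change above splits the original init_some_extra_state into a reusable _init_flash_attention_state helper plus a thin wrapper, so subclasses can run the shared FlashAttention bookkeeping after their own position-embedding setup. A minimal, hypothetical sketch of that hook pattern follows; the class names and the "..." placeholders are illustrative, not the real lightllm classes.

class BaseFlashAttentionState:
    def _init_flash_attention_state(self, model, input_ids):
        # shared FlashAttention bookkeeping (cu_seqlens, page table, ...)
        self.cu_seqlens_q = ...
        self.cu_seqlens_k = ...

    def init_some_extra_state(self, model, input_ids):
        # default path: generic setup, then the FlashAttention helper
        self._init_flash_attention_state(model, input_ids)


class VLFlashAttentionState(BaseFlashAttentionState):
    def init_some_extra_state(self, model, input_ids):
        # VL models first build their own (mrope-style) position sin/cos ...
        self.position_sin = ...
        self.position_cos = ...
        # ... and then reuse the shared FlashAttention helper unchanged
        self._init_flash_attention_state(model, input_ids)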
lightllm/models/qwen2_vl/flashattention_infer_struct.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+from lightllm.common.basemodel.batch_objs import ModelInput
+
+
+class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        InferStateInfo.init_some_extra_state(self, model, input_ids)
+        if self.is_prefill:
+            self.max_seq_len = self.max_kv_seq_len
+            self.q_max_seq_len = self.max_q_seq_len
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
+        else:
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+
+        # init flash attention state
+        self._init_flash_attention_state(model, input_ids)
+        return
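The position_sin / position_cos lines gather per-token rows from the model's cached rotary tables and insert a broadcast axis with unsqueeze(1). A toy, self-contained illustration of that gather-and-broadcast pattern, with made-up shapes rather than lightllm's real ones:

import torch

max_pos, half_dim = 16, 4
sin_cached = torch.randn(1, max_pos, half_dim)   # stands in for model._sin_cached
position_ids = torch.tensor([0, 1, 2, 5])        # one cached position per token

# advanced indexing picks one row per token, unsqueeze adds a broadcast axis
position_sin = sin_cached[:, position_ids, :].unsqueeze(1)
print(position_sin.shape)                        # torch.Size([1, 1, 4, 4])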

lightllm/models/qwen2_vl/model.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,7 @@
 from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from typing import List, Optional, Union
 from transformers.utils import TensorType, logging
+from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo
 from lightllm.common.build_utils import repair_config
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
@@ -20,6 +21,7 @@
 import torch
 from PIL import Image
 from .vision_process import smart_resize
+from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
 from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 import os
@@ -103,6 +105,10 @@ def __init__(self, kvargs):
         super().__init__(kvargs)
         return
 
+    def _init_inferstate_cls(self):
+        if get_env_start_args().enable_fa3:
+            self.infer_state_class = Qwen2VLFlashAttentionStateInfo
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
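The new _init_inferstate_cls override swaps in Qwen2VLFlashAttentionStateInfo only when the enable_fa3 start argument is set. A hypothetical sketch of how such a hook is typically consumed by a base model during construction; everything except the enable_fa3 / infer_state_class names is illustrative, not lightllm's actual base-model code.

class DefaultInferState:
    pass


class FA3InferState:
    pass


class BaseModel:
    def __init__(self, enable_fa3: bool):
        self.enable_fa3 = enable_fa3               # stands in for get_env_start_args().enable_fa3
        self.infer_state_class = DefaultInferState
        self._init_inferstate_cls()                # subclass hook, called during construction

    def _init_inferstate_cls(self):
        pass                                       # default: keep DefaultInferState


class VLModel(BaseModel):
    def _init_inferstate_cls(self):
        if self.enable_fa3:
            self.infer_state_class = FA3InferState


print(VLModel(enable_fa3=True).infer_state_class.__name__)   # FA3InferState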
