Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions llmc/compression/token_reduction/divprune.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from functools import wraps
from types import MethodType

Expand All @@ -6,6 +7,7 @@
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper


def pairwise_cosine_similarity(matrix):
Expand Down Expand Up @@ -84,6 +86,41 @@ def divprune_post_hook(*args, pruning_paras=None):
return tuple(args)


def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
if kwargs['position_ids'].shape[-1] == 1:
return args, kwargs
inputs_embeds = kwargs['inputs_embeds']
attention_mask = kwargs['attention_mask']
rate = pruning_paras['reduction_ratio']
SYS_TOKEN_LEN = pruning_paras['vision_token_start_index']
img_feature_len = pruning_paras['vision_token_length']
device = inputs_embeds.device

visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
selected_visual_tokens, cosine_matrix = divprune(
visual_tokens, img_feature_len, None, threshold_ratio=1 - rate
)
selected_visual_tokens += SYS_TOKEN_LEN
keep_indexs = torch.cat(
(
torch.arange(SYS_TOKEN_LEN, device=device),
selected_visual_tokens,
torch.arange(
SYS_TOKEN_LEN + img_feature_len, inputs_embeds.shape[1], device=device
),
)
)
keep_indexs = keep_indexs.sort().values

kwargs['inputs_embeds'] = inputs_embeds[:, keep_indexs, :]
kwargs['position_ids'] = kwargs['position_ids'][:, :, keep_indexs]
if attention_mask is not None:
kwargs['attention_mask'] = attention_mask[:, keep_indexs]
kwargs['cache_position'] = keep_indexs

return args, kwargs


@TOKEN_REDUCTION_REGISTRY.register('DivPrune')
class DivPrune(TokenReductionModule):
def __init__(self, config, model, blocks):
Expand Down Expand Up @@ -114,6 +151,14 @@ def wrapper(self, *args, **kwargs):
return divprune_post_hook(*outs, pruning_paras=pruning_paras)
return wrapper

@prefill_wrapper
def vtoken_length_hook(module, args, pruning_paras):
input_ids = args[0]
token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]

if self.model.__class__.__name__ == 'Llava':

self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
Expand All @@ -123,3 +168,15 @@ def wrapper(self, *args, **kwargs):
llava_next=self.special_config['vision_token_length'] is None
), self.model.vlm_model
)
elif self.model.__class__.__name__ == 'Qwen2_5VL':

self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)
self.model.language_model.register_forward_pre_hook(
functools.partial(
prune_qwenv25vl_hook,
pruning_paras=self.pruning_paras,
),
with_kwargs=True
)
12 changes: 12 additions & 0 deletions llmc/compression/token_reduction/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
layer_attention = layer_outputs[1]
pruning_paras['attn_scores'] = layer_attention

@prefill_wrapper
def vtoken_length_hook(module, args, pruning_paras):
input_ids = args[0]
token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]

if self.special_config['vision_token_length'] is None:
if self.model.__class__.__name__ == 'Llava':
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
Expand All @@ -236,6 +244,10 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
self.pruning_paras
), self.model.vlm_model
)
elif self.model.__class__.__name__ == 'Qwen2_5VL':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

if self.special_config['metric'] == 'random':
self.blocks[self.pruning_loc].register_forward_pre_hook(
Expand Down
103 changes: 100 additions & 3 deletions llmc/compression/token_reduction/sparsevlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
sparse_token_list_192 = []
sparse_token_list_128 = []
sparse_token_list_64 = []
sparse_token_list_960 = []
sparse_token_list_640 = []
sparse_token_list_320 = []
sparse_token_list_160 = []
Expand Down Expand Up @@ -329,8 +330,9 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye

if attention_mask is not None:
attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
index_dim = 1 if position_embeddings[0].dim() == 3 else 2
new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
position_embeddings = (new_pe0, new_pe1)

pruning_paras['v_token_num'] = v_token_num
Expand All @@ -352,6 +354,75 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):

return args, kwargs

@prefill_wrapper
def vtoken_length_hook(module, args, pruning_paras):
input_ids = args[0]
token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]
pruning_paras['pre_prompt_length_list'] = [token_indices[0].item()]

def get_attn_logits_for_qwen25vl(
module,
args, kwargs, layer_outs,
pruning_paras, layer_idx
):
if kwargs['position_ids'].shape[-1] == 1:
return layer_outs

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
apply_multimodal_rotary_pos_emb, repeat_kv)

hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
past_key_value = layer_outs[2]
attention_mask = kwargs['attention_mask']

t_token_idx = pruning_paras['t_token_idx']
v_token_start = pruning_paras['v_token_start']
v_token_num = pruning_paras['v_token_num']

bsz, q_len, _ = hidden_states.size()

query_states = module.q_proj(hidden_states)
key_states = module.k_proj(hidden_states)
value_states = module.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, module.rope_scaling['mrope_section']
)

if past_key_value is not None:
key_states = past_key_value.key_cache[layer_idx]
value_states = past_key_value.value_cache[layer_idx]

key_states = repeat_kv(key_states, module.num_key_value_groups)
value_states = repeat_kv(value_states, module.num_key_value_groups)

t_token_idx = t_token_idx[1] + v_token_start + v_token_num
L, S = query_states.size(-2), key_states.size(-2)
scale_factor = 1 / math.sqrt(query_states.size(-1))
attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
if module.is_causal:
assert attention_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
attn_bias.to(query_states.dtype)

attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
attn_logits += attn_bias.to(query_states.device)
attn_logits = torch.softmax(attn_logits, dim=-1)

pruning_paras['attn_logits'] = attn_logits

return layer_outs

if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_paras=self.pruning_paras)
Expand All @@ -364,11 +435,17 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
llava_next=self.special_config['vision_token_length'] is None
), self.model.vlm_model
)
elif self.model.__class__.__name__ == 'Qwen2_5VL':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

if self.model.__class__.__name__ == 'LlavaHf':
llama_model = self.model.model
elif self.model.__class__.__name__ == 'Llava':
llama_model = self.model.model.model
elif self.model.__class__.__name__ == 'Qwen2_5VL':
llama_model = self.model.language_model
llama_model.register_forward_pre_hook(
functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
with_kwargs=True
Expand Down Expand Up @@ -405,6 +482,23 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
),
with_kwargs=True
)
elif self.model.__class__.__name__ == 'Qwen2_5VL':
self.blocks[block_idx].register_forward_pre_hook(
functools.partial(
update_kwargs_hook,
pruning_paras=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
self.blocks[block_idx].self_attn.register_forward_hook(
functools.partial(
get_attn_logits_for_qwen25vl,
pruning_paras=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
self.blocks[block_idx].register_forward_hook(
functools.partial(
decoder_attn_hook,
Expand All @@ -425,7 +519,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):

def update_list():
global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
global sparse_token_list_960, sparse_token_list_640, sparse_token_list_320, sparse_token_list_160 # noqa
global prune_flag, merge_flag, sparse_token_dict

if layer_dict == {2: 0, 6: 1, 15: 2}: # 2*576 4*300 10*200 16*110
Expand All @@ -437,13 +531,15 @@ def update_list():
sparse_token_list_192 = [180]
sparse_token_list_128 = [114]
sparse_token_list_64 = [48]
sparse_token_list_960 = [0.3125]
sparse_token_list_640 = [0.1979]
sparse_token_list_320 = [0.0833]
sparse_token_list_160 = [0.0261]
elif prune_flag:
sparse_token_list_192 = [192]
sparse_token_list_128 = [128]
sparse_token_list_64 = [64]
sparse_token_list_960 = [0.3333]
sparse_token_list_640 = [0.2222]
sparse_token_list_320 = [0.1111]
sparse_token_list_160 = [0.0555]
Expand All @@ -460,6 +556,7 @@ def update_list():
192: sparse_token_list_192,
128: sparse_token_list_128,
64: sparse_token_list_64,
960: sparse_token_list_960,
640: sparse_token_list_640,
320: sparse_token_list_320,
160: sparse_token_list_160
Expand Down
Loading