diff --git a/llmc/compression/token_reduction/divprune.py b/llmc/compression/token_reduction/divprune.py
index 9ca45e86..7a30b770 100644
--- a/llmc/compression/token_reduction/divprune.py
+++ b/llmc/compression/token_reduction/divprune.py
@@ -1,3 +1,4 @@
+import functools
 from functools import wraps
 from types import MethodType
 
@@ -6,6 +7,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper
 
 
 def pairwise_cosine_similarity(matrix):
@@ -84,6 +86,41 @@
     return tuple(args)
 
 
+def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
+    if kwargs['position_ids'].shape[-1] == 1:
+        return args, kwargs
+    inputs_embeds = kwargs['inputs_embeds']
+    attention_mask = kwargs['attention_mask']
+    rate = pruning_paras['reduction_ratio']
+    SYS_TOKEN_LEN = pruning_paras['vision_token_start_index']
+    img_feature_len = pruning_paras['vision_token_length']
+    device = inputs_embeds.device
+
+    visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
+    selected_visual_tokens, cosine_matrix = divprune(
+        visual_tokens, img_feature_len, None, threshold_ratio=1 - rate
+    )
+    selected_visual_tokens += SYS_TOKEN_LEN
+    keep_indexs = torch.cat(
+        (
+            torch.arange(SYS_TOKEN_LEN, device=device),
+            selected_visual_tokens,
+            torch.arange(
+                SYS_TOKEN_LEN + img_feature_len, inputs_embeds.shape[1], device=device
+            ),
+        )
+    )
+    keep_indexs = keep_indexs.sort().values
+
+    kwargs['inputs_embeds'] = inputs_embeds[:, keep_indexs, :]
+    kwargs['position_ids'] = kwargs['position_ids'][:, :, keep_indexs]
+    if attention_mask is not None:
+        kwargs['attention_mask'] = attention_mask[:, keep_indexs]
+    kwargs['cache_position'] = keep_indexs
+
+    return args, kwargs
+
+
 @TOKEN_REDUCTION_REGISTRY.register('DivPrune')
 class DivPrune(TokenReductionModule):
     def __init__(self, config, model, blocks):
@@ -114,6 +151,14 @@ def wrapper(self, *args, **kwargs):
                 return divprune_post_hook(*outs, pruning_paras=pruning_paras)
             return wrapper
 
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+
         if self.model.__class__.__name__ == 'Llava':
 
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
@@ -123,3 +168,15 @@ def wrapper(self, *args, **kwargs):
                     llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
+            self.model.language_model.register_forward_pre_hook(
+                functools.partial(
+                    prune_qwenv25vl_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
diff --git a/llmc/compression/token_reduction/random.py b/llmc/compression/token_reduction/random.py
index e889df78..6de1dfbc 100644
--- a/llmc/compression/token_reduction/random.py
+++ b/llmc/compression/token_reduction/random.py
@@ -228,6 +228,14 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
             layer_attention = layer_outputs[1]
             pruning_paras['attn_scores'] = layer_attention
 
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+
         if self.special_config['vision_token_length'] is None:
             if self.model.__class__.__name__ == 'Llava':
                 self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
@@ -236,6 +244,10 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
                         self.pruning_paras
                     ), self.model.vlm_model
                 )
+            elif self.model.__class__.__name__ == 'Qwen2_5VL':
+                self.model.embed_tokens.register_forward_pre_hook(
+                    functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+                )
 
         if self.special_config['metric'] == 'random':
             self.blocks[self.pruning_loc].register_forward_pre_hook(
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index aae8f722..04a0287e 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -17,6 +17,7 @@
 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_960 = []
 sparse_token_list_640 = []
 sparse_token_list_320 = []
 sparse_token_list_160 = []
@@ -329,8 +330,9 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye
 
             if attention_mask is not None:
                 attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings = (new_pe0, new_pe1)
 
             pruning_paras['v_token_num'] = v_token_num
@@ -352,6 +354,75 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
             return args, kwargs
+
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+            pruning_paras['pre_prompt_length_list'] = [token_indices[0].item()]
+
+        def get_attn_logits_for_qwen25vl(
+            module,
+            args, kwargs, layer_outs,
+            pruning_paras, layer_idx
+        ):
+            if kwargs['position_ids'].shape[-1] == 1:
+                return layer_outs
+
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                apply_multimodal_rotary_pos_emb, repeat_kv)
+
+            hidden_states = kwargs['hidden_states']
+            position_embeddings = kwargs['position_embeddings']
+            past_key_value = layer_outs[2]
+            attention_mask = kwargs['attention_mask']
+
+            t_token_idx = pruning_paras['t_token_idx']
+            v_token_start = pruning_paras['v_token_start']
+            v_token_num = pruning_paras['v_token_num']
+
+            bsz, q_len, _ = hidden_states.size()
+
+            query_states = module.q_proj(hidden_states)
+            key_states = module.k_proj(hidden_states)
+            value_states = module.v_proj(hidden_states)
+
+            query_states = query_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+
+            cos, sin = position_embeddings
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states, key_states, cos, sin, module.rope_scaling['mrope_section']
+            )
+
+            if past_key_value is not None:
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
+
+            key_states = repeat_kv(key_states, module.num_key_value_groups)
+            value_states = repeat_kv(value_states, module.num_key_value_groups)
+
+            t_token_idx = t_token_idx[1] + v_token_start + v_token_num
+            L, S = query_states.size(-2), key_states.size(-2)
+            scale_factor = 1 / math.sqrt(query_states.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
+            if module.is_causal:
+                assert attention_mask is None
+                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+                attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
+                attn_bias.to(query_states.dtype)
+
+            attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
+            attn_logits += attn_bias.to(query_states.device)
+            attn_logits = torch.softmax(attn_logits, dim=-1)
+
+            pruning_paras['attn_logits'] = attn_logits
+
+            return layer_outs
+
         if self.model.__class__.__name__ == 'LlavaHf':
             self.model.embed_tokens.register_forward_pre_hook(
                 functools.partial(input_hook, pruning_paras=self.pruning_paras)
@@ -364,11 +435,17 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                     llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
 
         if self.model.__class__.__name__ == 'LlavaHf':
             llama_model = self.model.model
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            llama_model = self.model.language_model
         llama_model.register_forward_pre_hook(
             functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
             with_kwargs=True
@@ -405,6 +482,23 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                     ),
                     with_kwargs=True
                 )
+            elif self.model.__class__.__name__ == 'Qwen2_5VL':
+                self.blocks[block_idx].register_forward_pre_hook(
+                    functools.partial(
+                        update_kwargs_hook,
+                        pruning_paras=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
+                self.blocks[block_idx].self_attn.register_forward_hook(
+                    functools.partial(
+                        get_attn_logits_for_qwen25vl,
+                        pruning_paras=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
             self.blocks[block_idx].register_forward_hook(
                 functools.partial(
                     decoder_attn_hook,
@@ -425,7 +519,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
-    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
+    global sparse_token_list_960, sparse_token_list_640, sparse_token_list_320, sparse_token_list_160  # noqa
     global prune_flag, merge_flag, sparse_token_dict
 
     if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
@@ -437,6 +531,7 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_960 = [0.3125]
         sparse_token_list_640 = [0.1979]
         sparse_token_list_320 = [0.0833]
         sparse_token_list_160 = [0.0261]
@@ -444,6 +539,7 @@ def update_list():
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_960 = [0.3333]
        sparse_token_list_640 = [0.2222]
         sparse_token_list_320 = [0.1111]
         sparse_token_list_160 = [0.0555]
@@ -460,6 +556,7 @@ def update_list():
         192: sparse_token_list_192,
         128: sparse_token_list_128,
         64: sparse_token_list_64,
+        960: sparse_token_list_960,
         640: sparse_token_list_640,
         320: sparse_token_list_320,
         160: sparse_token_list_160
diff --git a/llmc/compression/token_reduction/vispruner.py b/llmc/compression/token_reduction/vispruner.py
index afe63fe1..f052a485 100644
--- a/llmc/compression/token_reduction/vispruner.py
+++ b/llmc/compression/token_reduction/vispruner.py
@@ -1,13 +1,16 @@
 import functools
+import math
 from functools import wraps
 from types import MethodType
+from typing import Optional, Tuple
 
 import torch
+import torch.nn as nn
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
-from .utils import get_anyres_image_grid_shape, unpad_image
+from .utils import get_anyres_image_grid_shape, prefill_wrapper, unpad_image
 
 
 @TOKEN_REDUCTION_REGISTRY.register('VisPruner')
@@ -238,31 +241,187 @@ def prune_hook(module, inputs, outputs, pruning_paras, model_config):
             )
             return image_features
 
-        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-            change_images_hook(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras
-            ),
-            self.model.vlm_model
-        )
+        if self.model.__class__.__name__ != 'Qwen2_5VL':
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                change_images_hook(
+                    self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                    self.pruning_paras
+                ),
+                self.model.vlm_model
+            )
+
+            self.model.vision_model.vision_tower.register_forward_pre_hook(
+                update_output_attentions_hook,
+                with_kwargs=True
+            )
+
+            self.model.vision_model.vision_tower.register_forward_hook(
+                functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
+            )
+
+            self.model.vision_projector.register_forward_pre_hook(
+                functools.partial(get_index_masks_hook, pruning_paras=self.pruning_paras),
+            )
+
+            self.model.vision_projector.register_forward_hook(
+                functools.partial(
+                    prune_hook,
+                    pruning_paras=self.pruning_paras,
+                    model_config=self.model.vlm_model_config
+                ),
+            )
+
+        def get_metric(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                return fn(self, *args, pruning_paras=pruning_paras, **kwargs)
+            return wrapper
 
-        self.model.vision_model.vision_tower.register_forward_pre_hook(
-            update_output_attentions_hook,
-            with_kwargs=True
-        )
+        def merger_hook(module, inputs, kwargs, layer_outs, pruning_paras):
+            with torch.no_grad():
+                attn_mean = pruning_paras['attn_logits'].mean(dim=0)  # 16 1120, 1120 -> 1120, 1120
+                window_index, _ = module.get_window_index(kwargs['grid_thw'])
+                reverse_indices = torch.argsort(window_index)  # [280]
+                attn_mean = attn_mean.sum(dim=0)  # 1120, 1120 -> 1120
+                attn_mean = attn_mean.view(attn_mean.shape[0] // 4, -1).mean(dim=-1)  # 1120 -> 280
+                attn_mean = attn_mean[reverse_indices]
+                pruning_paras['attn_mean'] = attn_mean
+            return layer_outs
+
+        @prefill_wrapper
+        def get_input_ids_hook(module, input_args, pruning_paras):
+            pruning_paras['input_ids'] = input_args[0]
+            return input_args
+
+        def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
+            # only support bs=1
+            if kwargs['position_ids'].shape[-1] == 1:
+                return args, kwargs
+            inputs_embeds = kwargs['inputs_embeds']
+
+            img_mask = (pruning_paras['input_ids'] == pruning_paras['vision_token_index'])[0]
+            img_idx = torch.nonzero(img_mask, as_tuple=True)[0]  # [280]
+            image_features = inputs_embeds[:, img_idx, :]
 
-        self.model.vision_model.vision_tower.register_forward_hook(
-            functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
-        )
+            B, N, C = image_features.shape
+            device = image_features.device
+            visual_token_num = round(N * (1 - self.special_config['prune_ratio']))  # T
+            important_ratio = self.pruning_paras['important_ratio']  # r
+            important_token_num = int(visual_token_num * important_ratio)  # T_imp = T * r
+            if (N - important_token_num) % 2 != 0:
+                important_token_num += 1
+            diverse_token_num = visual_token_num - important_token_num  # T_div = T * (1 - r)
 
-        self.model.vision_projector.register_forward_pre_hook(
-            functools.partial(get_index_masks_hook, pruning_paras=self.pruning_paras),
-        )
+            # [VisPruner] Select important tokens using attention scores
+            image_attentions = pruning_paras['attn_mean'].unsqueeze(0)  # (B, N)
+            token_indices = image_attentions.argsort(dim=-1, descending=True)  # (B, N)
+            important_indices = token_indices[:, :important_token_num]  # (B, T_imp)
+            residual_indices = token_indices[:, important_token_num:]  # (B, N - T_imp)
 
-        self.model.vision_projector.register_forward_hook(
-            functools.partial(
-                prune_hook,
-                pruning_paras=self.pruning_paras,
-                model_config=self.model.vlm_model_config
-            ),
-        )
+            # [VisPruner] Remove duplicate tokens by iterative matching and pruning
+            image_normalized = image_features / image_features.norm(dim=-1, keepdim=True)
+            while diverse_token_num > 0:
+                R = residual_indices.shape[1]
+                r = min(8, R - diverse_token_num)
+                if r <= 0:
+                    break
+
+                residual_tokens = image_normalized[
+                    torch.arange(B).unsqueeze(-1).expand(-1, R),
+                    residual_indices
+                ]  # (B, R, C)
+                a, b = residual_tokens[..., ::2, :], residual_tokens[..., 1::2, :]  # (B, R // 2, C)
+                scores = a @ b.transpose(-1, -2)  # (B, R // 2, R // 2)
+                scores = scores.max(dim=-1).values  # (B, R // 2)
+
+                distinct_indices = scores.argsort(dim=-1, descending=True)[:, r:]  # (B, R // 2 - r)
+                residual_indices = torch.cat([
+                    residual_indices[..., ::2][
+                        torch.arange(B).unsqueeze(-1).expand(-1, R // 2 - r),
+                        distinct_indices
+                    ],
+                    residual_indices[..., 1::2]
+                ], dim=-1)  # (B, R - r)
+
+            if diverse_token_num > 0:
+                selected_indices = torch.cat([important_indices, residual_indices], dim=-1)
+            else:
+                selected_indices = important_indices  # (B, T)
+            index_masks = torch.zeros(B, N, dtype=torch.bool, device=device)
+            index_masks.scatter_(1, selected_indices, True)
+            if img_idx.numel() > 0:
+                first, last = img_idx[0].item(), img_idx[-1].item()
+                img_mask[first: last + 1] = ~index_masks[0]
+            img_mask = ~img_mask
+
+            kwargs['position_ids'] = kwargs['position_ids'][:, :, img_mask]
+            kwargs['attention_mask'] = kwargs['attention_mask'][:, img_mask]
+            kwargs['inputs_embeds'] = inputs_embeds[:, img_mask]
+
+            return args, kwargs
+
+        if self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.blocks[-1].attn.forward = MethodType(
+                get_metric(Qwen2_5_VLVisionAttention_forward, self.pruning_paras),
+                self.blocks[-1].attn
+            )
+            self.model.vision_model.register_forward_hook(
+                functools.partial(
+                    merger_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(get_input_ids_hook, pruning_paras=self.pruning_paras)
+            )
+            self.model.language_model.register_forward_pre_hook(
+                functools.partial(
+                    prune_qwenv25vl_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
+
+
+def Qwen2_5_VLVisionAttention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    pruning_paras,
+    cu_seqlens: torch.Tensor,
+    rotary_pos_emb: Optional[torch.Tensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> torch.Tensor:
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+        apply_rotary_pos_emb_vision
+    head_dim = self.qkv.in_features // self.num_heads
+    seq_length = hidden_states.shape[0]
+    q, k, v = self.qkv(hidden_states).reshape(
+        seq_length, 3, self.num_heads, -1
+    ).permute(1, 0, 2, 3).unbind(0)
+    if position_embeddings is None:
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+    else:
+        cos, sin = position_embeddings
+    q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+    attention_mask = torch.full(
+        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+    )
+    for i in range(1, len(cu_seqlens)):
+        attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
+
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
+    attn_weights = attn_weights + attention_mask
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weights, v)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    attn_output = self.proj(attn_output)
+    pruning_paras['attn_logits'] = attn_weights
+    return attn_output