diff --git a/configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml b/configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml
index e6d9df736..017587648 100644
--- a/configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml
+++ b/configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml
@@ -25,22 +25,38 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: GPTQ
-    quant_objects: [vision, language] # default is [language]
-    weight:
-        bit: 4
-        symmetric: False
-        granularity: per_group
-        group_size: 128
-        # calib_algo: mse
-        # mse_b_num: 2
-    special:
-        actorder: True
-        static_groups: False
-        percdamp: 0.01
-        blocksize: 128
-        true_sequential: True
-        quant_out: True
+    vision:
+        method: GPTQ
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+            # calib_algo: mse
+            # mse_b_num: 2
+        special:
+            actorder: True
+            static_groups: False
+            percdamp: 0.01
+            blocksize: 128
+            true_sequential: True
+            quant_out: True
+    language:
+        method: GPTQ
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+            # calib_algo: mse
+            # mse_b_num: 2
+        special:
+            actorder: True
+            static_groups: False
+            percdamp: 0.01
+            blocksize: 128
+            true_sequential: True
+            quant_out: True
 save:
     save_fake: False
     save_path: /path/to/save/
diff --git a/configs/quantization/methods/RTN/rtn_w_a_vlm.yml b/configs/quantization/methods/RTN/rtn_w_a_vlm.yml
index 728e21344..7bc9fa480 100644
--- a/configs/quantization/methods/RTN/rtn_w_a_vlm.yml
+++ b/configs/quantization/methods/RTN/rtn_w_a_vlm.yml
@@ -13,17 +13,28 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: RTN
-    quant_objects: [vision, language] # default is [language]
-    weight:
-        bit: 8
-        symmetric: True
-        granularity: per_channel
-        group_size: -1
-    act:
-        bit: 8
-        symmetric: True
-        granularity: per_token
+    vision:
+        method: RTN
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+            group_size: -1
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+    language:
+        method: RTN
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+            group_size: -1
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
 save:
     save_fake: False
     save_path: /path/to/save/
diff --git a/configs/quantization/methods/SmoothQuant/smoothquant_w_a_vlm.yml b/configs/quantization/methods/SmoothQuant/smoothquant_w_a_vlm.yml
index caf962478..5b2b7b217 100644
--- a/configs/quantization/methods/SmoothQuant/smoothquant_w_a_vlm.yml
+++ b/configs/quantization/methods/SmoothQuant/smoothquant_w_a_vlm.yml
@@ -25,18 +25,30 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: SmoothQuant
-    quant_objects: [vision, language]
-    weight:
-        bit: 8
-        symmetric: True
-        granularity: per_channel
-    act:
-        bit: 8
-        symmetric: True
-        granularity: per_token
-    special:
-        alpha: 0.8
+    vision:
+        method: SmoothQuant
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+        special:
+            alpha: 0.8
+    language:
+        method: SmoothQuant
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+        special:
+            alpha: 0.8
 save:
     save_trans: False
     save_fake: False
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 33927f35d..5a2232699 100755
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -949,20 +949,11 @@ def deploy(self, quant_format, keep_device=False):
             self.set_no_quant_layer()
 
         module = module_mapping[quant_format]
-        if self.modality == 'vision':
-            self.model.replace_vision_module_all(
-                module,
-                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-                keep_device=keep_device,
-            )
-        if self.modality == 'language':
-            self.model.replace_language_module_all(
-                module,
-                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-                keep_device=keep_device,
-            )
-        if self.modality == 'video_gen':
-            self.model.replace_video_gen_module_all(
+
+        self.model.set_modality(self.modality)
+        logger.info(f'set modality: {self.modality}')
+        if self.modality in ('vision', 'language', 'video_gen'):
+            self.model.replace_module_all(
                 module,
                 self.get_replacement_params(mode=quant_format, w_only=self.w_only),
                 keep_device=keep_device,
diff --git a/llmc/compression/quantization/llmint8.py b/llmc/compression/quantization/llmint8.py
index 116b8843e..29209f63a 100644
--- a/llmc/compression/quantization/llmint8.py
+++ b/llmc/compression/quantization/llmint8.py
@@ -66,7 +66,7 @@ def deploy(self, quant_format):
         logger.info(f'-- deploy_{quant_format}_model start --')
         logger.info(f'quant_config : {self.quant_config}')
 
-        self.model.replace_language_module_all(
+        self.model.replace_module_all(
             FakeQuantLinear,
             self.get_replacement_params(
                 mode='fake_quant', w_only=self.w_only, name=None
diff --git a/llmc/compression/quantization/tesseraq.py b/llmc/compression/quantization/tesseraq.py
index f81f0d762..8c87f77ac 100644
--- a/llmc/compression/quantization/tesseraq.py
+++ b/llmc/compression/quantization/tesseraq.py
@@ -3,7 +3,6 @@
 import gc
 import math
 import os
-import pdb
 import random
 from contextlib import nullcontext
 from math import inf
@@ -268,7 +267,6 @@ def tesseraq_train(self, block):
 
             if not math.isfinite(loss.item()):
                 logger.info('Loss is NAN, stopping training')
-                pdb.set_trace()
 
             optimizer.zero_grad()
 
diff --git a/llmc/compression/token_reduction/holitom.py b/llmc/compression/token_reduction/holitom.py
index 27f601c50..79b8ca28e 100644
--- a/llmc/compression/token_reduction/holitom.py
+++ b/llmc/compression/token_reduction/holitom.py
@@ -594,7 +594,6 @@ def prepare_inputs_labels_for_multimodal(
         if isinstance(modalities, str):
             modalities = [modalities]
 
-        # import pdb; pdb.set_trace()
         if type(images) is list or images.ndim == 5:
             mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
             image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
@@ -733,7 +732,7 @@ def prepare_inputs_labels_for_multimodal(
                     # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                     # we want to first unflatten it to (2, 2, h, w, hidden_size)
                     # rank0_print("At least we are reaching here")
-                    # import pdb; pdb.set_trace()
+
                     if image_idx in video_idx_in_batch:  # video operations
                         # rank0_print("Video")
                         if mm_newline_position == 'grid':
@@ -1032,7 +1031,6 @@ def prepare_inputs_labels_for_multimodal(
 
             cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
 
-            # import pdb; pdb.set_trace()
             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
             cur_new_labels = torch.cat(cur_new_labels)
 
@@ -1157,7 +1155,7 @@ def prepare_inputs_labels_for_multimodal(
             right_add = random.randint(left_add, self.config.pos_skipping_range)
             position_ids[:, :split_position] += left_add
             position_ids[:, split_position:] += right_add
-        # import pdb; pdb.set_trace()
+
         # rank0_print("Finish preparing")
         return (
             None,
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 6f6a563b2..4d7dda2ae 100755
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -378,40 +378,7 @@ def get_extra_modules(self, block):
     def get_moe_gate(self, block):
         return None
 
-    def replace_vision_module_all(self, module, params_dict, keep_device=False):
-        vision_model_linears = self.get_block_linears(self.vision_model)
-        for name, m in vision_model_linears.items():
-            M = module.new(m, **params_dict)
-
-            name_tmp = name.rsplit('.', 1)
-            if len(name_tmp) == 2:
-                parent_name = name_tmp[0]
-                parent = self.vision_model.get_submodule(parent_name)
-                child_name = name_tmp[1]
-            elif len(name_tmp) == 1:
-                parent = self.vision_model
-                child_name = name_tmp[0]
-
-            setattr(parent, child_name, M)
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        logger.info(f'The Replaced vision_model: {self.vision_model}')
-
-    def replace_language_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            if keep_device:
-                self.replace_module_block(module, self.blocks[block_idx], block_idx, params_dict)
-            else:
-                self.blocks[block_idx].cuda()
-                self.replace_module_block(module, self.blocks[block_idx], block_idx, params_dict)
-                self.blocks[block_idx].cpu()
-        gc.collect()
-        torch.cuda.empty_cache()
-        logger.info(f'The Replaced model: {self.model}')
-
-    def replace_video_gen_module_all(self, module, params_dict, keep_device=False):
+    def replace_module_all(self, module, params_dict, keep_device=False):
         for block_idx in range(len(self.blocks)):
             logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
             if keep_device:
@@ -422,7 +389,6 @@ def replace_video_gen_module_all(self, module, params_dict, keep_device=False):
                 self.blocks[block_idx].cpu()
         gc.collect()
         torch.cuda.empty_cache()
-        logger.info(f'The Replaced model: {self.model}')
 
     def replace_module_block(self, module, block, block_idx, params_dict):
         if module in _LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_:
diff --git a/llmc/models/internvl3_5.py b/llmc/models/internvl3_5.py
index 2b4a28af7..90f5c6256 100644
--- a/llmc/models/internvl3_5.py
+++ b/llmc/models/internvl3_5.py
@@ -157,7 +157,7 @@ def __new__(cls, config, device_map=None, use_cache=False):
         if language_part == 'qwen3':
             from .qwen3 import Qwen3
 
-            class NewClass(InternVL2SharedBehavior, Qwen3):
+            class NewClass(InternVL3_5SharedBehavior, Qwen3):
                 def __init__(self, config, device_map=None, use_cache=False):
                     super().__init__(config, device_map, use_cache)
                     setattr(
@@ -170,7 +170,7 @@ def __init__(self, config, device_map=None, use_cache=False):
         return NewClass(config, device_map, use_cache)
 
 
-class InternVL2SharedBehavior():
+class InternVL3_5SharedBehavior():
     def build_model(self):
         self.eval_name = 'InternVL3_5Eval'
         self.vlm_model_config = AutoConfig.from_pretrained(
diff --git a/llmc/models/qwen2.py b/llmc/models/qwen2.py
index 4d9d9551f..1f4e4b5ec 100644
--- a/llmc/models/qwen2.py
+++ b/llmc/models/qwen2.py
@@ -22,7 +22,6 @@ def find_embed_layers(self):
 
     def find_block_name(self):
         self.block_name_prefix = 'model.layers'
-        self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}
 
     def get_embed_layers(self):
         return [self.embed_tokens]
diff --git a/llmc/models/qwen2_5vl.py b/llmc/models/qwen2_5vl.py
index e41d51a5c..eb19b6dea 100755
--- a/llmc/models/qwen2_5vl.py
+++ b/llmc/models/qwen2_5vl.py
@@ -2,7 +2,6 @@
 from typing import Optional, Union
 
 import torch
-import torch.nn as nn
 from accelerate import Accelerator, DistributedType
 from loguru import logger
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
@@ -15,18 +14,16 @@
         'If you need it, please upgrade transformers.'
     )
 
-try:
-    from qwen_vl_utils import process_vision_info
-except Exception:
-    logger.warning(
-        'Can not import qwen_vl_utils. '
-        'If you need it, please pip install qwen-vl-utils'
-    )
-
 from llmc.utils.registry_factory import MODEL_REGISTRY
 
 from .qwen2vl import Qwen2VL
 
+# settings for qwen2_5_vl:
+# pip install transformers==4.51.3
+# pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git@v0.3.5
+# And you should add the following at line 92 of llmc/eval/eval_vqa.py:
+# import argparse; cli_args = argparse.Namespace(process_with_media=True)
+
 
 @MODEL_REGISTRY
 class Qwen2_5VL(Qwen2VL):
@@ -76,7 +73,6 @@ def build_model(self):
         }
         self.first_turn_question = True
 
-    # todo: check
     def get_subsets_in_block(self, block):
         if self.get_modality() == 'language':
             return super().get_subsets_in_block(block)
@@ -152,14 +148,6 @@ def __init__(
         # Do not use kwargs for now
         assert kwargs == {}, f'Unexpected kwargs: {kwargs}'
 
-        # Validate attention implementation
-        valid_attn_implementations = [None, 'flash_attention_2', 'sdpa', 'eager']
-        if attn_implementation not in valid_attn_implementations:
-            raise ValueError(
-                f'attn_implementation must be one of {valid_attn_implementations}, \
-                    got {attn_implementation}'
-            )
-
         self.use_custom_video_loader = use_custom_video_loader
         self.fps = fps
         # if self.fps and not self.use_custom_video_loader:
@@ -178,16 +166,6 @@ def __init__(
             self._device = torch.device(device)
             self.device_map = device_map if device_map else device
 
-        # Prepare model loading arguments
-        model_kwargs = {
-            'torch_dtype': 'auto',
-            'device_map': self.device_map,
-        }
-
-        # Add attention implementation if specified
-        if attn_implementation is not None:
-            model_kwargs['attn_implementation'] = attn_implementation
-
         self._model = llmc_model.eval().cuda()
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
diff --git a/llmc/models/qwen2vl.py b/llmc/models/qwen2vl.py
index 40cbf38df..32d130a2e 100755
--- a/llmc/models/qwen2vl.py
+++ b/llmc/models/qwen2vl.py
@@ -187,27 +187,6 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Qwen2VL do not support {self.get_modality()} modality.')
 
-    def get_catcher(self, first_block_input):
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-                self.mlp = self.module.mlp
-                self.signature = inspect.signature(module.forward)
-
-            def forward(self, *args, **kwargs):
-                params = list(self.signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i > 0:
-                        kwargs[params[i]] = arg
-                first_block_input['data'].append(args[0])
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-        return Catcher
-
 
 try:
     from lmms_eval.api.model import lmms
diff --git a/tools/quant_analysis.py b/tools/quant_analysis.py
index 1accb3594..b5daae629 100644
--- a/tools/quant_analysis.py
+++ b/tools/quant_analysis.py
@@ -441,7 +441,7 @@ def a_qdq(act, module=None):
     params_dict = {}
     params_dict['w_qdq'] = wquanter.fake_quant_weight_dynamic
     params_dict['a_qdq'] = None if args.w_only else a_qdq
-    t_model.replace_language_module_all(FakeQuantLinear, params_dict)
+    t_model.replace_module_all(FakeQuantLinear, params_dict)
 
     with torch.no_grad():
         for i in tqdm(range(len(model.blocks))):