48 changes: 32 additions & 16 deletions configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml
@@ -25,22 +25,38 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: GPTQ
-    quant_objects: [vision, language] # default is [language]
-    weight:
-        bit: 4
-        symmetric: False
-        granularity: per_group
-        group_size: 128
-        # calib_algo: mse
-        # mse_b_num: 2
-    special:
-        actorder: True
-        static_groups: False
-        percdamp: 0.01
-        blocksize: 128
-        true_sequential: True
-        quant_out: True
+    vision:
+        method: GPTQ
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+            # calib_algo: mse
+            # mse_b_num: 2
+        special:
+            actorder: True
+            static_groups: False
+            percdamp: 0.01
+            blocksize: 128
+            true_sequential: True
+            quant_out: True
+    language:
+        method: GPTQ
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+            # calib_algo: mse
+            # mse_b_num: 2
+        special:
+            actorder: True
+            static_groups: False
+            percdamp: 0.01
+            blocksize: 128
+            true_sequential: True
+            quant_out: True
 save:
     save_fake: False
     save_path: /path/to/save/
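Note: the flat `method`/`quant_objects` layout gives way to one sub-config per modality, so the vision and language towers can be quantized with different settings. A minimal sketch of reading such a nested config (the loader below is hypothetical, not llmc's actual parser; requires PyYAML):

```python
# Hypothetical reader for the nested per-modality layout above;
# llmc's real config handling may differ.
import yaml

def load_modality_quant_configs(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Each top-level key under `quant` (vision, language, ...) now carries
    # its own method/weight/special settings instead of one flat config.
    return cfg['quant']

configs = load_modality_quant_configs(
    'configs/quantization/methods/GPTQ/gptq_w_only_vlm.yml')
for modality, sub in configs.items():
    print(modality, sub['method'], sub['weight']['bit'])  # e.g. "vision GPTQ 4"
```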
33 changes: 22 additions & 11 deletions configs/quantization/methods/RTN/rtn_w_a_vlm.yml
@@ -13,17 +13,28 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: RTN
-    quant_objects: [vision, language] # default is [language]
-    weight:
-        bit: 8
-        symmetric: True
-        granularity: per_channel
-        group_size: -1
-    act:
-        bit: 8
-        symmetric: True
-        granularity: per_token
+    vision:
+        method: RTN
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+            group_size: -1
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+    language:
+        method: RTN
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+            group_size: -1
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
 save:
     save_fake: False
     save_path: /path/to/save/
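As configured here this is plain round-to-nearest W8A8: symmetric 8-bit weights with one scale per output channel (`group_size: -1`), plus per-token activation scales at runtime. A textbook sketch of the weight side (illustration only, not llmc's kernel):

```python
# Symmetric per-channel RTN weight quantization (bit=8, granularity=per_channel),
# as configured above. Textbook math for illustration, not llmc's implementation.
import torch

def rtn_quant_per_channel(w: torch.Tensor, bit: int = 8):
    qmax = 2 ** (bit - 1) - 1                         # 127 for symmetric int8
    scale = w.abs().amax(dim=1, keepdim=True) / qmax  # one scale per out-channel
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
    return q.to(torch.int8), scale

w = torch.randn(4096, 4096)
q, scale = rtn_quant_per_channel(w)
w_hat = q.float() * scale                 # fake-quant reconstruction
print((w - w_hat).abs().max())            # rounding error bounded by scale / 2
```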
36 changes: 24 additions & 12 deletions configs/quantization/methods/SmoothQuant/smoothquant_w_a_vlm.yml
@@ -25,18 +25,30 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: SmoothQuant
-    quant_objects: [vision, language]
-    weight:
-        bit: 8
-        symmetric: True
-        granularity: per_channel
-    act:
-        bit: 8
-        symmetric: True
-        granularity: per_token
-    special:
-        alpha: 0.8
+    vision:
+        method: SmoothQuant
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+        special:
+            alpha: 0.8
+    language:
+        method: SmoothQuant
+        weight:
+            bit: 8
+            symmetric: True
+            granularity: per_channel
+        act:
+            bit: 8
+            symmetric: True
+            granularity: per_token
+        special:
+            alpha: 0.8
 save:
     save_trans: False
     save_fake: False
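`alpha: 0.8` sets how much quantization difficulty SmoothQuant migrates from activations into weights: per the paper, each input channel j is rescaled by s_j = max|X_j|^alpha / max|W_j|^(1-alpha). A small sketch of that formula (illustration of the paper's equation, not llmc's code path):

```python
# SmoothQuant smoothing scales with alpha=0.8, as in the config above.
# Paper formula for illustration; llmc derives these from calibration data.
import torch

def smooth_scales(act_absmax: torch.Tensor, w_absmax: torch.Tensor, alpha=0.8):
    # Per input channel j: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
    # Activations are divided by s and weights multiplied by s offline,
    # so X @ W is mathematically unchanged but easier to quantize.
    return act_absmax.pow(alpha) / w_absmax.pow(1 - alpha)

act_absmax = torch.rand(4096) * 20 + 1e-2   # toy calibration statistics
w_absmax = torch.rand(4096) + 1e-2
s = smooth_scales(act_absmax, w_absmax, alpha=0.8)
# x_smoothed = x / s ; w_smoothed = w * s, folded in before 8-bit quantization
```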
19 changes: 5 additions & 14 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -949,20 +949,11 @@ def deploy(self, quant_format, keep_device=False):
             self.set_no_quant_layer()

         module = module_mapping[quant_format]
-        if self.modality == 'vision':
-            self.model.replace_vision_module_all(
-                module,
-                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-                keep_device=keep_device,
-            )
-        if self.modality == 'language':
-            self.model.replace_language_module_all(
-                module,
-                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-                keep_device=keep_device,
-            )
-        if self.modality == 'video_gen':
-            self.model.replace_video_gen_module_all(
+
+        self.model.set_modality(self.modality)
+        logger.info(f'set modality: {self.modality}')
+        if self.modality in ('vision', 'language', 'video_gen'):
+            self.model.replace_module_all(
                 module,
                 self.get_replacement_params(mode=quant_format, w_only=self.w_only),
                 keep_device=keep_device,
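The three modality-specific branches collapse into one call: `set_modality` tells the model which tower is active, then a single generic `replace_module_all` does the walking. A toy sketch of that dispatch shape (assuming `set_modality` switches which block list the walker sees; not the real llmc class):

```python
# Toy sketch of the refactored dispatch: one modality-aware walker replaces
# replace_{vision,language,video_gen}_module_all. The real llmc base class
# also stages blocks on CUDA and builds replacement modules from a factory.
class ToyModel:
    def __init__(self, blocks_by_modality):
        self._blocks = blocks_by_modality   # {'vision': [...], 'language': [...]}
        self._modality = 'language'

    def set_modality(self, modality):
        assert modality in self._blocks, f'unknown modality: {modality}'
        self._modality = modality

    def replace_module_all(self, replace_fn):
        blocks = self._blocks[self._modality]
        self._blocks[self._modality] = [replace_fn(b) for b in blocks]

model = ToyModel({'vision': ['v0', 'v1'], 'language': ['l0']})
model.set_modality('vision')
model.replace_module_all(lambda b: b + '_quant')
print(model._blocks)  # {'vision': ['v0_quant', 'v1_quant'], 'language': ['l0']}
```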
2 changes: 1 addition & 1 deletion llmc/compression/quantization/llmint8.py
@@ -66,7 +66,7 @@ def deploy(self, quant_format):
         logger.info(f'-- deploy_{quant_format}_model start --')
         logger.info(f'quant_config : {self.quant_config}')

-        self.model.replace_language_module_all(
+        self.model.replace_module_all(
             FakeQuantLinear,
             self.get_replacement_params(
                 mode='fake_quant', w_only=self.w_only, name=None
2 changes: 0 additions & 2 deletions llmc/compression/quantization/tesseraq.py
@@ -3,7 +3,6 @@
 import gc
 import math
 import os
-import pdb
 import random
 from contextlib import nullcontext
 from math import inf
@@ -268,7 +267,6 @@ def tesseraq_train(self, block):

             if not math.isfinite(loss.item()):
                 logger.info('Loss is NAN, stopping training')
-                pdb.set_trace()

             optimizer.zero_grad()

6 changes: 2 additions & 4 deletions llmc/compression/token_reduction/holitom.py
@@ -594,7 +594,6 @@ def prepare_inputs_labels_for_multimodal(
         if isinstance(modalities, str):
             modalities = [modalities]

-        # import pdb; pdb.set_trace()
         if type(images) is list or images.ndim == 5:
             mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
             image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
@@ -733,7 +732,7 @@ def prepare_inputs_labels_for_multimodal(
                 # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                 # we want to first unflatten it to (2, 2, h, w, hidden_size)
                 # rank0_print("At least we are reaching here")
-                # import pdb; pdb.set_trace()
+
                 if image_idx in video_idx_in_batch:  # video operations
                     # rank0_print("Video")
                     if mm_newline_position == 'grid':
@@ -1032,7 +1031,6 @@ def prepare_inputs_labels_for_multimodal(

             cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

-            # import pdb; pdb.set_trace()
             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
             cur_new_labels = torch.cat(cur_new_labels)

@@ -1157,7 +1155,7 @@ def prepare_inputs_labels_for_multimodal(
                 right_add = random.randint(left_add, self.config.pos_skipping_range)
                 position_ids[:, :split_position] += left_add
                 position_ids[:, split_position:] += right_add
-                # import pdb; pdb.set_trace()
+
         # rank0_print("Finish preparing")
         return (
             None,
36 changes: 1 addition & 35 deletions llmc/models/base_model.py
@@ -378,40 +378,7 @@ def get_extra_modules(self, block):
     def get_moe_gate(self, block):
         return None

-    def replace_vision_module_all(self, module, params_dict, keep_device=False):
-        vision_model_linears = self.get_block_linears(self.vision_model)
-        for name, m in vision_model_linears.items():
-            M = module.new(m, **params_dict)
-
-            name_tmp = name.rsplit('.', 1)
-            if len(name_tmp) == 2:
-                parent_name = name_tmp[0]
-                parent = self.vision_model.get_submodule(parent_name)
-                child_name = name_tmp[1]
-            elif len(name_tmp) == 1:
-                parent = self.vision_model
-                child_name = name_tmp[0]
-
-            setattr(parent, child_name, M)
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        logger.info(f'The Replaced vision_model: {self.vision_model}')
-
-    def replace_language_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            if keep_device:
-                self.replace_module_block(module, self.blocks[block_idx], block_idx, params_dict)
-            else:
-                self.blocks[block_idx].cuda()
-                self.replace_module_block(module, self.blocks[block_idx], block_idx, params_dict)
-                self.blocks[block_idx].cpu()
-        gc.collect()
-        torch.cuda.empty_cache()
-        logger.info(f'The Replaced model: {self.model}')
-
-    def replace_video_gen_module_all(self, module, params_dict, keep_device=False):
+    def replace_module_all(self, module, params_dict, keep_device=False):
         for block_idx in range(len(self.blocks)):
             logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
             if keep_device:
@@ -422,7 +389,6 @@ def replace_video_gen_module_all(self, module, params_dict, keep_device=False):
                 self.blocks[block_idx].cpu()
         gc.collect()
         torch.cuda.empty_cache()
-        logger.info(f'The Replaced model: {self.model}')

     def replace_module_block(self, module, block, block_idx, params_dict):
         if module in _LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_:
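The deleted vision-specific walker also spelled out the parent-lookup-and-setattr swap that `replace_module_block` still relies on. A standalone sketch of that pattern in plain PyTorch (generic illustration, not llmc's exact helper):

```python
# In-place module swap: resolve the parent from the dotted name, then
# setattr the replacement, mirroring the rsplit('.', 1) logic deleted above.
import torch.nn as nn

def replace_submodule(root: nn.Module, dotted_name: str, new_module: nn.Module):
    parent_name, _, child_name = dotted_name.rpartition('.')
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new_module)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
replace_submodule(model, '2', nn.Linear(8, 4))  # swap the final projection
print(model)
```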
4 changes: 2 additions & 2 deletions llmc/models/internvl3_5.py
@@ -157,7 +157,7 @@ def __new__(cls, config, device_map=None, use_cache=False):
         if language_part == 'qwen3':
             from .qwen3 import Qwen3

-            class NewClass(InternVL2SharedBehavior, Qwen3):
+            class NewClass(InternVL3_5SharedBehavior, Qwen3):
                 def __init__(self, config, device_map=None, use_cache=False):
                     super().__init__(config, device_map, use_cache)
                     setattr(
@@ -170,7 +170,7 @@ def __init__(self, config, device_map=None, use_cache=False):
         return NewClass(config, device_map, use_cache)


-class InternVL2SharedBehavior():
+class InternVL3_5SharedBehavior():
     def build_model(self):
         self.eval_name = 'InternVL3_5Eval'
         self.vlm_model_config = AutoConfig.from_pretrained(
1 change: 0 additions & 1 deletion llmc/models/qwen2.py
@@ -22,7 +22,6 @@ def find_embed_layers(self):

     def find_block_name(self):
         self.block_name_prefix = 'model.layers'
-        self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

     def get_embed_layers(self):
         return [self.embed_tokens]
34 changes: 6 additions & 28 deletions llmc/models/qwen2_5vl.py
@@ -2,7 +2,6 @@
 from typing import Optional, Union

 import torch
-import torch.nn as nn
 from accelerate import Accelerator, DistributedType
 from loguru import logger
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
@@ -15,18 +14,16 @@
     'If you need it, please upgrade transformers.'
 )

-try:
-    from qwen_vl_utils import process_vision_info
-except Exception:
-    logger.warning(
-        'Can not import qwen_vl_utils. '
-        'If you need it, please pip install qwen-vl-utils'
-    )
-
 from llmc.utils.registry_factory import MODEL_REGISTRY

 from .qwen2vl import Qwen2VL

+# settings for qwen2_5_vl:
+# pip install transformers==4.51.3
+# pip install git+https://github.com/EvolvingLMMs-Lab/[email protected]
+# And you should add the following at line 92 of llmc/eval/eval_vqa.py:
+# import argparse; cli_args = argparse.Namespace(process_with_media=True)
+

 @MODEL_REGISTRY
 class Qwen2_5VL(Qwen2VL):
@@ -76,7 +73,6 @@ def build_model(self):
         }
         self.first_turn_question = True

-    # todo: check
     def get_subsets_in_block(self, block):
         if self.get_modality() == 'language':
             return super().get_subsets_in_block(block)
@@ -152,14 +148,6 @@ def __init__(
         # Do not use kwargs for now
         assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

-        # Validate attention implementation
-        valid_attn_implementations = [None, 'flash_attention_2', 'sdpa', 'eager']
-        if attn_implementation not in valid_attn_implementations:
-            raise ValueError(
-                f'attn_implementation must be one of {valid_attn_implementations}, \
-                got {attn_implementation}'
-            )
-
         self.use_custom_video_loader = use_custom_video_loader
         self.fps = fps
         # if self.fps and not self.use_custom_video_loader:
@@ -178,16 +166,6 @@
         self._device = torch.device(device)
         self.device_map = device_map if device_map else device

-        # Prepare model loading arguments
-        model_kwargs = {
-            'torch_dtype': 'auto',
-            'device_map': self.device_map,
-        }
-
-        # Add attention implementation if specified
-        if attn_implementation is not None:
-            model_kwargs['attn_implementation'] = attn_implementation
-
         self._model = llmc_model.eval().cuda()
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
21 changes: 0 additions & 21 deletions llmc/models/qwen2vl.py
@@ -187,27 +187,6 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Qwen2VL do not support {self.get_modality()} modality.')

-    def get_catcher(self, first_block_input):
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-                self.mlp = self.module.mlp
-                self.signature = inspect.signature(module.forward)
-
-            def forward(self, *args, **kwargs):
-                params = list(self.signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i > 0:
-                        kwargs[params[i]] = arg
-                first_block_input['data'].append(args[0])
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-        return Catcher
-

 try:
     from lmms_eval.api.model import lmms