diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3993f4670..c50d37330 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -46,6 +46,7 @@ NVIDIA Model Optimizer Changelog - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. +- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. **Deprecations** diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 5db36a972..4d2239076 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -346,6 +346,8 @@ with torch.inference_mode(): python hf_ptq.py --pyt_ckpt_path --qformat fp8 --export_path --trust_remote_code ``` +> *For exporting fake-quantized models for vLLM serving (e.g., for research or kernels not yet supported in real-quant), use the `--vllm_fakequant_export` flag. 
See [vllm_serve/README.md](../vllm_serve/README.md) for details.* + ### Hugging Face framework [Script](./scripts/huggingface_example.sh) Alternatively, the framework script `huggingface_example.sh` also supports quantize and export: diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5620ddf6a..b5f84b736 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -53,6 +53,7 @@ from modelopt.recipe import ModelOptPTQRecipe, load_recipe from modelopt.torch.export import ( export_hf_checkpoint, + export_hf_vllm_fq_checkpoint, export_speculative_decoding, export_tensorrt_llm_checkpoint, get_model_type, @@ -681,16 +682,21 @@ def export_quantized( # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization - mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) + if args.vllm_fakequant_export: + export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + else: + mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( + full_model, args.pyt_ckpt_path + ) - if mtp_layer_prefixes: - full_model._mtp_layer_prefixes = mtp_layer_prefixes + if mtp_layer_prefixes: + full_model._mtp_layer_prefixes = mtp_layer_prefixes - export_hf_checkpoint( - full_model, - export_dir=export_path, - extra_state_dict=mtp_state_dict, - ) + export_hf_checkpoint( + full_model, + export_dir=export_path, + extra_state_dict=mtp_state_dict, + ) # Restore default padding and export the tokenizer as well. if tokenizer is not None: @@ -1218,6 +1224,13 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." 
), ) + parser.add_argument( + "--vllm_fakequant_export", + default=False, + action="store_true", + help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " + "for use with vllm_serve_fakequant.py).", + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index ff0c4eea3..01ee3c44b 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -23,9 +23,11 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, |-----------------|--------------------------------------------------|---------------------| | QUANT_DATASET | Dataset name for calibration | cnn_dailymail | | QUANT_CALIB_SIZE| Number of samples used for calibration | 512 | -| QUANT_CFG | Quantization format | NVFP4_DEFAULT_CFG | -| KV_QUANT_CFG | Quantization format for KV Cache | None | -| AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None | +| QUANT_CFG | Quantization config | None | +| KV_QUANT_CFG | KV-cache quantization config | None | +| QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None | +| MODELOPT_STATE_PATH | Optional path to exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None | +| CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -56,21 +58,45 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=, ## Load QAT/PTQ model and serve in vLLM (WIP) -Overwrite the calibrated amax value with prepared values from either QAT/PTQ. +Step 1: export the model with bf16 weights and quantizer state. To export the model: -Step 1: export the model with bf16 weights and amax values. 
To export the model: +- For **HF** models, use `examples/llm_ptq/hf_ptq.py` with `--vllm_fakequant_export`: -- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function. -- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function. +```bash +python ../llm_ptq/hf_ptq.py \ + --pyt_ckpt_path \ + --qformat nvfp4 \ + --calib_size 512 \ + --export_path \ + --vllm_fakequant_export \ + --trust_remote_code +``` + + This creates `/vllm_fq_modelopt_state.pth` (ModelOpt quantizer state for vLLM fake-quant reload) and saves the HF-exported model under `` (config/tokenizer/weights). + + Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/main.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload. + +- For **MCore** models, export the model with flag `--export-vllm-fq` as described in [Megatron-LM README](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-nvfp4-quantization-qauntization-aware-training-and-model-export). This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. + +Step 2: use the exported artifacts when serving: + +- **HF export**: pass the exported `vllm_fq_modelopt_state.pth` via `MODELOPT_STATE_PATH` + +```bash +# HF +MODELOPT_STATE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +``` -Step 2: configure from exported model using AMAX_FILE_PATH environment variable in step 1. 
For example: +- **MCore export**: pass the exported `quantizer_state.pth` via `QUANT_FILE_PATH` and set `QUANT_CFG` to match the MCore quantization recipe ```bash -AMAX_FILE_PATH= QUANT_CFG= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +# MCore +QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 ``` ## Known Problems -1. AWQ is not yet supported in vLLM. -2. QAT checkpoint export doesn't have KV Cache quantization enabled. KV Cache fake quantization works for PTQ. -3. Mixed precision checkpoint doesn't work currently. +1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). +2. AWQ reload is not supported yet +3. KV cache quantization export and reload is not supported in MCore yet. +4. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs. diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe66..ec2b1f403 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -13,318 +13,108 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import dataclasses + import os -import re -import warnings -from collections import defaultdict -from contextlib import contextmanager from typing import Any import torch -from tqdm import tqdm from transformers import AutoTokenizer -from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker +from vllm_ptq_utils import calibrate_fun, get_quant_config +from vllm_reload_utils import ( + convert_dict_to_vllm, + convert_modelopt_state_to_vllm, + load_state_dict_from_path, + restore_from_modelopt_state_vllm, +) import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.plugins.vllm import ( + disable_compilation, + post_restore_vllm_parallel_linears, +) from modelopt.torch.utils.dataset_utils import get_dataset_dataloader - -def convert_amax_hf2vllm( - hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False -) -> dict[str, torch.Tensor]: - """ - Convert amax values from HuggingFace format to vLLM format. 
- - This function merges: - - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) - - gate_proj, up_proj amax values into gate_up_proj (taking max) - - Args: - hf_state_dict: HuggingFace state dict containing amax values - - Returns: - vLLM format state dict with merged amax values - """ - vllm_state_dict = {} - - # Group keys by their base pattern (without the specific projection name) - merge_groups = defaultdict(list) - - for key, value in hf_state_dict.items(): - if "_amax" not in key: - # Copy non-amax keys as-is - vllm_state_dict[key] = value - continue - - # Check if this is a q/k/v projection that needs merging - qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) - if qkv_match: - base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert gate/up projection - # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and - # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax - expert_gate_up_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_gate_up_match: - base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is a non-expert gate/up projection that needs merging - gate_up_match = ( - "mixer" not in key - and "experts" not in key - and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) - ) - if gate_up_match: - base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert down_proj - # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax - 
expert_down_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_down_match: - base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) - merge_groups[base_pattern].append((key, value)) - continue - - # Copy other amax keys as-is (like o_proj, down_proj) - vllm_state_dict[key] = value - - # Merge grouped amax values by taking the maximum - for merged_key, key_value_pairs in merge_groups.items(): - if len(key_value_pairs) > 1: - # Take the maximum across all values for this merged key - values = [value for _, value in key_value_pairs] - merged_value = torch.stack(values).max(dim=0)[0] - vllm_state_dict[merged_key] = merged_value - print(f"Merged {len(key_value_pairs)} keys into {merged_key}") - for orig_key, _ in key_value_pairs: - print(f" - {orig_key}") - else: - # Single key, just rename it - _, value = key_value_pairs[0] - vllm_state_dict[merged_key] = value - - return vllm_state_dict - - -@contextmanager -def disable_compilation(model): - do_not_compile = True - if hasattr(model, "model"): - do_not_compile = model.model.do_not_compile - model.model.do_not_compile = True - elif hasattr(model, "language_model"): - do_not_compile = model.language_model.model.do_not_compile - model.language_model.model.do_not_compile = True - else: - raise ValueError("Model does not have a model or language_model attribute") - - try: - yield - finally: - if hasattr(model, "model"): - model.model.do_not_compile = do_not_compile - elif hasattr(model, "language_model"): - model.language_model.model.do_not_compile = do_not_compile - - quant_config: dict[str, Any] = { "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"), "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), - "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "quant_file_path": 
os.environ.get("QUANT_FILE_PATH", None), + "modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None), + "calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 1)), } -def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]: - """Update KV cache quantization config for MLA models. - - MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate - `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the - config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`. - """ - try: - from vllm.attention.layer import MLAAttention - except ImportError: - return kv_quant_cfg - - if not any(isinstance(m, MLAAttention) for m in model.modules()): - return kv_quant_cfg - - if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): - kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config - kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config - print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config") - - return kv_quant_cfg - - -def _create_new_data_cls(data_cls, **kwargs): - """vLLM's low-level API changes frequently. 
This function creates a class with parameters - compatible with the different vLLM versions.""" - valid_params = {field.name for field in dataclasses.fields(data_cls)} - filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} - return data_cls(**filtered_kwargs) - - def _fakequant_run_prolog_worker(self) -> None: + trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true" tokenizer = AutoTokenizer.from_pretrained( self.model_runner.model_config.tokenizer, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) if tokenizer.pad_token != "" or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if quant_config["amax_file_path"]: - print("Will load amax, so only do a single sample calibration") - quant_config["calib_size"] = 1 - - calib_dataloader = get_dataset_dataloader( - dataset_name=quant_config["dataset"], - tokenizer=tokenizer, - batch_size=1, - num_samples=quant_config["calib_size"], - device=self.device, - ) - - def calibrate_loop(model: Any = None) -> None: - for batch_idx, batch in tqdm(enumerate(calib_dataloader)): - input_ids = batch["input_ids"][0] - - # Convert tensor to list of integers for vLLM compatibility - if torch.is_tensor(input_ids): - input_ids_list = input_ids.cpu().tolist() - else: - input_ids_list = list(input_ids) - - num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) - empty_block_ids = tuple([] for _ in range(num_groups)) - - req_id = f"req-{batch_idx}" - # Pass all possible parameters - the helper will filter based on vLLM version - new_req = _create_new_data_cls( - NewRequestData, - req_id=req_id, - prompt_token_ids=input_ids_list, - # Old API parameters - mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated - # New API parameter - mm_features=[], - sampling_params=SamplingParams(max_tokens=1), - 
pooling_params=None, - block_ids=empty_block_ids, - num_computed_tokens=0, - lora_request=None, - ) - - scheduler_output = _create_new_data_cls( - SchedulerOutput, - scheduled_new_reqs=[new_req], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_id: len(input_ids_list)}, - total_num_scheduled_tokens=len(input_ids_list), - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[0] * num_groups, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - kv_connector_metadata=None, - # Old API parameters - structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated - grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated - ) - output = self.execute_model(scheduler_output) - if hasattr(self, "sample_tokens"): - if output is None: # TODO: make this default when vllm <= 0.11 is outdated - self.sample_tokens(None) - - quant_cfg = {} if quant_config["quant_cfg"] is None else getattr(mtq, quant_config["quant_cfg"]) - quant_kv_cfg = ( - {} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"]) - ) - model = self.model_runner.model if hasattr(model, "unwrap"): model = model.unwrap() - - # Check if model has MLA and update KV config accordingly - if quant_kv_cfg: - quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) - - if quant_kv_cfg: - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, quant_kv_cfg["quant_cfg"] + if quant_config["modelopt_state_path"]: + print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + modelopt_state = torch.load( + quant_config["modelopt_state_path"], weights_only=True, map_location="cpu" ) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + map_fun = ( + self.model_runner.model.hf_to_vllm_mapper.apply_dict + if 
hasattr(self.model_runner.model, "hf_to_vllm_mapper") + else None + ) + # convert modelopt state to vllm format + modelopt_state = convert_modelopt_state_to_vllm(modelopt_state, map_fun=map_fun) + # restore model from modelopt state + restore_from_modelopt_state_vllm(model, modelopt_state) + + if modelopt_weights is not None: + # convert quantizer state values to vllm format + modelopt_weights = convert_dict_to_vllm(modelopt_weights, map_fun=map_fun) + mtq.utils.set_quantizer_state_dict(model, modelopt_weights) + # set_quantizer_state_dict does not invoke modelopt_post_restore (unlike restore_quantizer_state). + post_restore_vllm_parallel_linears(model) - with disable_compilation(model): - print("quantizing model...") - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - amax_file_path = quant_config["amax_file_path"] - if amax_file_path: - print(f"Loading amax values from {amax_file_path}") - saved_amax_dict = torch.load(amax_file_path) - # convert amax keys to vLLM format - if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): - saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) - saved_amax_dict = { - key.replace("quantizer_amax", "quantizer._amax"): value - for key, value in saved_amax_dict.items() - if key.endswith("quantizer_amax") - } - saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True) + else: + if quant_config["quant_file_path"]: + print("Will load quant, so only do a single sample calibration") + quant_config["calib_size"] = 1 + + calib_dataloader = get_dataset_dataloader( + dataset_name=quant_config["dataset"], + tokenizer=tokenizer, + batch_size=quant_config["calib_batch_size"], + num_samples=quant_config["calib_size"], + device=self.device, + ) - current_state_dict = model.state_dict() - # Count amax keys in checkpoint and model - checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] - model_amax_keys = [key for key in current_state_dict if 
key.endswith("_amax")] - for key in checkpoint_amax_keys: - if key not in model_amax_keys: - print(f"Key {key} not found in model state dict, but exists in checkpoint") - for key in model_amax_keys: - if key not in checkpoint_amax_keys: - raise ValueError( - f"Key {key} not found in checkpoint state dict, but exists in model" - ) + calibrate_loop = calibrate_fun(calib_dataloader, self) - checkpoint_amax_count = len(checkpoint_amax_keys) - model_amax_count = len(model_amax_keys) + quant_cfg = get_quant_config(quant_config, model) - # Ensure counts match - if checkpoint_amax_count != model_amax_count: - warnings.warn( - f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} " - f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP." - ) + # quantize model + with disable_compilation(model): + print("Quantizing model...") + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # Update amax values - for key, value in saved_amax_dict.items(): - if key in current_state_dict: - current_state_dict[key] = value.to(current_state_dict[key].device) + quantizer_file_path = quant_config["quant_file_path"] + if quantizer_file_path: + # Get amax and other quantizer state from the quantizer file + # this can be used with Megatron-LM exported model using export_mcore_gpt_to_hf_vllm_fq + current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model) + model.load_state_dict(current_state_dict) - model.load_state_dict(current_state_dict) - torch.distributed.barrier() + # Only barrier if distributed is actually initialized (avoids deadlocks). 
+ if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) @@ -345,6 +135,10 @@ def determine_available_memory(self) -> int: return super().determine_available_memory() def compile_or_warm_up_model(self) -> None: - if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + if ( + quant_config["quant_cfg"] + or quant_config["kv_quant_cfg"] + or quant_config["modelopt_state_path"] + ): _fakequant_run_prolog_worker(self) super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_ptq_utils.py b/examples/vllm_serve/vllm_ptq_utils.py new file mode 100644 index 000000000..d6c055709 --- /dev/null +++ b/examples/vllm_serve/vllm_ptq_utils.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from collections.abc import Callable +from typing import Any + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from vllm.sampling_params import SamplingParams +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput + +import modelopt.torch.quantization as mtq + + +def _create_new_data_cls(data_cls, **kwargs): + """vLLM's low-level API changes frequently. 
This function creates a class with parameters + compatible with the different vLLM versions.""" + valid_params = {field.name for field in dataclasses.fields(data_cls)} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return data_cls(**filtered_kwargs) + + +def calibrate_fun(calib_dataloader: DataLoader, self: Any) -> Callable[[Any], None]: + def calibrate_loop(model: Any) -> None: + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): + input_ids_batch = batch["input_ids"] + + # Convert to list of flat token id lists (one per sequence in batch) + if torch.is_tensor(input_ids_batch): + input_ids_batch = input_ids_batch.cpu() + # Handle both [batch_size, seq_len] and [seq_len] + if input_ids_batch.dim() == 1: + input_ids_batch = input_ids_batch.unsqueeze(0) + input_ids_list_batch = [seq.tolist() for seq in input_ids_batch] + else: + input_ids_list_batch = [ + list(seq) if not isinstance(seq, list) else seq for seq in input_ids_batch + ] + if input_ids_list_batch and isinstance(input_ids_list_batch[0], int): + input_ids_list_batch = [input_ids_list_batch] + + num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) + empty_block_ids = tuple([] for _ in range(num_groups)) + + scheduled_new_reqs = [] + num_scheduled_tokens = {} + total_tokens = 0 + for seq_idx, input_ids_list in enumerate(input_ids_list_batch): + req_id = f"req-{batch_idx}-{seq_idx}" + new_req = _create_new_data_cls( + NewRequestData, + req_id=req_id, + prompt_token_ids=input_ids_list, + mm_kwargs=[], + mm_hashes=[], + mm_positions=[], + mm_features=[], + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + block_ids=empty_block_ids, + num_computed_tokens=0, + lora_request=None, + ) + scheduled_new_reqs.append(new_req) + num_scheduled_tokens[req_id] = len(input_ids_list) + total_tokens += len(input_ids_list) + + scheduler_output = _create_new_data_cls( + SchedulerOutput, + scheduled_new_reqs=scheduled_new_reqs, + 
scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_tokens, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0] * num_groups, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + kv_connector_metadata=None, + structured_output_request_ids={}, + grammar_bitmask=None, + ) + output = self.execute_model(scheduler_output) + if hasattr(self, "sample_tokens"): + if output is None: # TODO: make this default when vllm <= 0.11 is outdated + self.sample_tokens(None) + + return calibrate_loop + + +def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]: + """Update KV cache quantization config for MLA models. + + MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate + `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the + config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`. + """ + try: + from vllm.attention.layer import MLAAttention + except ImportError: + return kv_quant_cfg + + if not any(isinstance(m, MLAAttention) for m in model.modules()): + return kv_quant_cfg + + if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): + kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config + kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config + print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config") + + return kv_quant_cfg + + +def get_quant_config(quant_config: dict[str, Any], model: Any) -> dict[str, Any]: + quant_cfg = getattr(mtq, quant_config["quant_cfg"]) if quant_config["quant_cfg"] else {} + quant_kv_cfg = ( + getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else {} + ) + + # Check if model has MLA and update KV config accordingly + if quant_kv_cfg: + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + + if quant_kv_cfg: + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + 
quant_cfg, quant_kv_cfg["quant_cfg"] + ) + + return quant_cfg diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py new file mode 100644 index 000000000..b67c92ae6 --- /dev/null +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +import torch +from vllm.distributed.parallel_state import get_tp_group + +from modelopt.torch.opt.conversion import ( + ModelLikeModule, + ModeloptStateManager, + _check_init_modellike, +) +from modelopt.torch.quantization.conversion import ( + convert_to_quantized_model, + restore_quantizer_state, +) +from modelopt.torch.quantization.utils import is_quantized + + +def _values_equal(v1: Any, v2: Any) -> bool: + """Compare values, handling dicts with tensors.""" + if isinstance(v1, dict) and isinstance(v2, dict): + if v1.keys() != v2.keys(): + return False + return all( + torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] + for k in v1 + ) + elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + return torch.equal(v1, v2) + return v1 == v2 + + +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: + """ + 
Transform a single key from HuggingFace format to vLLM format. + + Returns: + Tuple of (action, new_key_or_group, value) where action is one of: + - "copy": Copy value to new_key directly + - "group": Add to merge group identified by new_key + - "skip": Skip this key entirely + """ + if "quantizer" not in key: + return ("copy", key, value) + + # Skip softmax_quantizer and lm_head quantizers(not needed in vLLM) + if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key): + return ("skip", None, None) + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj\.([^.]+_quantizer)(\..+)?$", key) + if qkv_match: + suffix = qkv_match.group(4) or "" + group_key = qkv_match.group(1) + "qkv_proj." + qkv_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert gate/up projection + # if "mixer" not in key: + expert_gate_up_match = re.search( + r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_gate_up_match: + suffix = expert_gate_up_match.group(4) or "" + group_key = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is a non-expert gate/up projection that needs merging + if "mixer" not in key and "experts" not in key: + gate_up_match = re.search(r"(.*\.)(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key) + if gate_up_match: + suffix = gate_up_match.group(4) or "" + group_key = gate_up_match.group(1) + "gate_up_proj." 
+ gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert down_proj + # if "mixer" not in key: + expert_down_match = re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer)(\..+)?$", key) + if expert_down_match: + suffix = expert_down_match.group(3) or "" + group_key = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) + suffix + return ("group", group_key, value) + + # Transform bmm_quantizer keys: self_attn.q/k/v_bmm_quantizer -> self_attn.attn.q/k/v_bmm_quantizer + bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) or re.search( + r"(.*\.mixer)\.([qkv]_bmm_quantizer.*)$", key + ) + if bmm_match: + new_key = bmm_match.group(1) + ".attn." + bmm_match.group(2) + return ("copy", new_key, value) + + # Copy other quantizer keys as-is (like o_proj, down_proj) + return ("copy", key, value) + + +def _group_keys_for_vllm( + state_dict: dict[str, Any], +) -> tuple[dict[str, Any], defaultdict[str, list[tuple[str, Any]]]]: + """ + Process state dict and group keys that need merging. + + Returns: + Tuple of (direct_copy_dict, merge_groups) + """ + vllm_state_dict = {} + merge_groups = defaultdict(list) + + for key, value in state_dict.items(): + action, new_key, new_value = _convert_key_for_vllm(key, value) + if new_key is None or new_value is None: + assert action == "skip", ( + f"Expected action to be 'skip' for key {key}, value {value}, got {action}" + ) + continue + if action == "copy": + vllm_state_dict[new_key] = new_value + elif action == "group": + merge_groups[new_key].append((key, new_value)) + # action == "skip" does nothing + + return vllm_state_dict, merge_groups + + +def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: + """ + Merge values by taking max for amax, concatenating for others. + Used for quantizer state weights (tensor values). 
+ """ + if not key_value_pairs: + raise ValueError(f"Cannot merge '{merged_key}': key_value_pairs is empty") + values = [value for _, value in key_value_pairs] + + # Check if values are dicts (OrderedDict) containing tensors + if isinstance(values[0], dict): + merged_value = {} + for dict_key in values[0]: + tensors = [v[dict_key] for v in values] + if "_amax" in dict_key: + merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] + elif "_pre_quant_scale" in dict_key: + # _pre_quant_scale is per-input-channel: identical across q/k/v projections + # since they share the same input. Do not concatenate; take the first value. + merged_value[dict_key] = tensors[0] + else: + merged_value[dict_key] = torch.cat(tensors, dim=0) + return merged_value + else: + # Values are tensors directly + if "_amax" in merged_key: + merged_value = torch.stack(values).max(dim=0)[0] + else: + merged_value = torch.cat(values, dim=0) + return merged_value + + +def _merge_values_require_identical(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: + """ + Merge values by requiring all values to be identical. + Used for quantizer state (config/metadata). + """ + keys = [k for k, _ in key_value_pairs] + values = [v for _, v in key_value_pairs] + first_value = values[0] + + # If all quantizers are disabled, their shape-specific fields (e.g. _amax_shape_for_export) + # will differ across q/k/v projections even though the config is logically the same. + # Since disabled quantizers are not used, skip the equality check. 
+ if all(isinstance(v, dict) and v.get("_disabled") for v in values): + return first_value + + for i, val in enumerate(values[1:], start=1): + if not _values_equal(val, first_value): + raise ValueError( + f"Cannot merge keys into '{merged_key}': values differ.\n" + f" '{keys[0]}' has value: {first_value}\n" + f" '{keys[i]}' has value: {val}" + ) + return first_value + + +def convert_dict_to_vllm( + state_dict: dict[str, Any], + max_or_concat: bool = True, + map_fun: Callable[[dict[str, Any]], dict[str, Any]] | None = None, +) -> dict[str, Any]: + """ + Common implementation for converting quantizer state from HF to vLLM format. + + Args: + state_dict: Input state dict + max_or_concat: Whether to merge grouped values by taking max/concatenate or require identical + map_fun: Function to map the state dict to vLLM format + """ + vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) + + merge_fn = _merge_values_by_max_or_concat if max_or_concat else _merge_values_require_identical + + # Merge grouped values + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + merged_value = merge_fn(merged_key, key_value_pairs) + vllm_state_dict[merged_key] = merged_value + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + if map_fun is None: + return vllm_state_dict + # Quantizer module-path keys (e.g. "layers.0.mlp.gate_proj.input_quantizer") must NOT + # go through map_fun (hf_to_vllm_mapper.apply_dict), which maps weight tensor paths and + # drops any key it doesn't recognise — including all quantizer keys. Split them out, + # apply map_fun only to non-quantizer keys, then merge back. 
+ quantizer_keys = {k: v for k, v in vllm_state_dict.items() if "_quantizer" in k} + non_quantizer_keys = {k: v for k, v in vllm_state_dict.items() if "_quantizer" not in k} + mapped = map_fun(non_quantizer_keys) if non_quantizer_keys else {} + return {**mapped, **quantizer_keys} + + +def convert_modelopt_state_to_vllm( + modelopt_state: dict[str, Any], + map_fun: Callable[[dict[str, Any]], dict[str, Any]] | None = None, +) -> dict[str, Any]: + """ + Convert modelopt state from HuggingFace format to vLLM compatible format. + + This function converts the quantizer state from HuggingFace format to vLLM compatible format. + + Note: modifies modelopt_state in place (pops keys). Callers that need the + original dict should pass a copy. + + Args: + modelopt_state: HuggingFace modelopt state dict (modified in place) + map_fun: Optional function to remap non-quantizer keys to vLLM names + + Returns: + vLLM compatible modelopt state dict + """ + modelopt_state_dict = modelopt_state.pop("modelopt_state_dict", []) + for idx, current_mode in enumerate(modelopt_state_dict): + current_mode_metadata = current_mode[1].pop("metadata", {}) + current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) + if current_mode_quant_state: + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm( + current_mode_quant_state, max_or_concat=False, map_fun=map_fun + ) + else: + current_mode_metadata.pop("quantizer_state", None) + current_mode[1]["metadata"] = current_mode_metadata + modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) + modelopt_state["modelopt_state_dict"] = modelopt_state_dict + return modelopt_state + + +def filter_modelopt_state_quantizer_state_for_model( + modelopt_state: dict[str, Any], model: torch.nn.Module +) -> None: + """ + Align quantizer_state in modelopt_state metadata with the model. + + - Removes keys not in the model (handles TP sharding - each rank has a subset). 
+ - Removes keys only when the quantizer is disabled (in the model). + - Adds keys for quantizers in the model but not in metadata (e.g. disabled/excluded). + Modifies modelopt_state in place. Call after convert_to_quantized_model so the model has + quantizers. + + Args: + modelopt_state: Modelopt state dict (modified in place) + model: Model with quantizers (must already be converted) + """ + from modelopt.torch.quantization.conversion import quantizer_state + from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer + from modelopt.torch.utils import get_unwrapped_name + + model_qstate = quantizer_state(model) + model_keys = set(model_qstate.keys()) + # Build name -> is_enabled for quantizers in the model + disabled_keys = set() + for name, module in model.named_modules(): + if isinstance(module, (TensorQuantizer, SequentialQuantizer)): + unwrapped_name = get_unwrapped_name(name, model) + if not getattr(module, "is_enabled", True): + disabled_keys.add(unwrapped_name) + + for mode_entry in modelopt_state.get("modelopt_state_dict", []): + metadata = mode_entry[1].get("metadata", {}) + if "quantizer_state" in metadata: + saved = metadata["quantizer_state"] + + # Keep keys that exist in the model. Remove disabled quantizers UNLESS they + # have registered buffers (e.g. _pre_quant_scale from AWQ/smoothquant on a + # disabled input_quantizer). Those buffers must reach _reset_pytorch_state_from_metadata + # so they get registered before set_quantizer_state_dict loads the values. + def _has_buffers(state: dict) -> bool: + return bool(state.get("_pytorch_state_metadata", {}).get("buffers")) + + filtered = { + k: v + for k, v in saved.items() + if k in model_keys and (k not in disabled_keys or _has_buffers(v)) + } + # Add state for quantizers in model but not in metadata (e.g. 
disabled/excluded) + for k in model_keys - filtered.keys(): + filtered[k] = model_qstate[k] + metadata["quantizer_state"] = filtered + + +def restore_from_modelopt_state_vllm( + model: torch.nn.Module, modelopt_state: dict[str, Any] +) -> torch.nn.Module: + """ + vLLM-specific restore that filters quantizer_state to match the model before restore. + + Handles TP sharding (each rank has a subset of quantizers) and excluded disabled quantizers + by running convert first, filtering metadata to model keys, then restoring. Uses the same + restore logic as restore_from_modelopt_state but with filtering for quantize modes. + """ + model = model if isinstance(model, torch.nn.Module) else ModelLikeModule(model) + manager = ModeloptStateManager(model=model, init_state=True) + manager.load_state_dict( + modelopt_state["modelopt_state_dict"], modelopt_state["modelopt_version"] + ) + + for i, (m, config, metadata) in enumerate(manager.modes_with_states()): + if i == 0: + model = _check_init_modellike(model, m) + # For quantize modes: convert first (if not already), filter metadata to model keys, then restore state. + # This handles TP (model has subset of quantizers) and excluded disabled quantizers. + if "quantizer_state" in metadata: + if not is_quantized(model): + convert_to_quantized_model(model, config) + filter_modelopt_state_quantizer_state_for_model( + {"modelopt_state_dict": manager._state}, model + ) + # Re-fetch metadata after filtering (manager._state was modified in place) + metadata = manager._state[i][1]["metadata"] + model = restore_quantizer_state(model, config, metadata) + else: + model = m.restore(model, config, metadata) + + if not manager.has_state and isinstance(model, ModelLikeModule): + model = model.init_modellike() + assert not isinstance(model, ModelLikeModule), "Model must be a regular Module now!" 
+ return model + + +def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): + """Shard quantizer tensors for tensor parallelism by matching expected shapes.""" + tp_group = get_tp_group() + tp_rank = tp_group.rank_in_group + tp_world_size = tp_group.world_size + + result = {} + for key, value in saved_qstate_dict.items(): + if key in current_state_dict: + expected = current_state_dict[key] + if not hasattr(value, "shape") or not hasattr(expected, "shape"): + result[key] = value + continue + expected_shape = expected.shape + value_shape = value.shape + if value_shape != expected_shape: + # Verify compatible rank before indexing + if len(value_shape) != len(expected_shape): + raise ValueError( + f"Cannot infer TP shard dim for {key}: rank mismatch " + f"(checkpoint rank={len(value_shape)}, expected rank={len(expected_shape)})" + ) + # Find the dimension that was tensor-parallel sharded. + # We expect exactly one dimension to satisfy: + # checkpoint_dim == expected_dim * tp_world_size + shard_dims = [ + d + for d in range(len(expected_shape)) + if value_shape[d] == expected_shape[d] * tp_world_size + ] + if len(shard_dims) != 1: + raise ValueError( + f"Cannot infer TP shard dim for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value_shape)}" + ) + + shard_dim = shard_dims[0] + shard_size = expected_shape[shard_dim] + start = tp_rank * shard_size + end = start + shard_size + if end > value_shape[shard_dim]: + raise ValueError( + f"TP shard out of bounds for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value_shape)}" + ) + value = value.narrow(shard_dim, start, shard_size).contiguous() + result[key] = value + + return result + + +def load_state_dict_from_path( + fakequant_runner: Any, quantizer_file_path: str, model: Any +) -> dict[str, Any]: + fakequant_runner.model_runner._dummy_run(1) + print(f"Loading quantizer values from {quantizer_file_path}") + # Load on CPU to avoid failures when the 
checkpoint was saved from a different + # GPU mapping + saved_quant_dict = torch.load(quantizer_file_path, weights_only=True, map_location="cpu") + # convert quant keys to vLLM format + if hasattr(fakequant_runner.model_runner.model, "hf_to_vllm_mapper"): + saved_quant_dict = fakequant_runner.model_runner.model.hf_to_vllm_mapper.apply_dict( + saved_quant_dict + ) + saved_quant_dict = { + key.replace("quantizer_", "quantizer._"): value + for key, value in saved_quant_dict.items() + if "quantizer_" in key + } + saved_quant_dict = convert_dict_to_vllm(saved_quant_dict) + + current_state_dict = model.state_dict() + # Count quant keys in checkpoint and model + checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] + model_quant_keys = [key for key in current_state_dict if "quantizer" in key] + for key in checkpoint_quant_keys: + if key not in model_quant_keys: + print(f"Key {key} not found in model state dict, but exists in checkpoint") + for key in model_quant_keys: + if key not in checkpoint_quant_keys: + raise ValueError(f"Key {key} not found in checkpoint state dict, but exists in model") + + checkpoint_quant_count = len(checkpoint_quant_keys) + model_quant_count = len(model_quant_keys) + + # Ensure counts match + if checkpoint_quant_count != model_quant_count: + warnings.warn( + f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} " + f"quant keys but model has {model_quant_count} quantizer state keys. " + f"This can happen if the model is using PP." 
+ ) + + # Update quant values + saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict) + for key, value in saved_quant_dict.items(): + if key in current_state_dict: + current_state_dict[key] = value.to(current_state_dict[key].device) + return current_state_dict diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be..e71707dca 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -74,8 +74,11 @@ "QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", - "AMAX_FILE_PATH", + "QUANT_FILE_PATH", "KV_QUANT_CFG", + "MODELOPT_STATE_PATH", + "CALIB_BATCH_SIZE", + "TRUST_REMOTE_CODE", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 54987b40c..5d6655d1f 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -19,8 +19,11 @@ import torch import torch.nn as nn -from modelopt.torch.export.layer_utils import is_quantlinear +import modelopt.torch.opt as mto +from modelopt.torch.quantization.conversion import quantizer_state +from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer from modelopt.torch.quantization.utils import get_quantizer_state_dict +from modelopt.torch.utils import get_unwrapped_name __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -29,34 +32,122 @@ def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, ): - """Exports the torch model weights and amax values separately. + """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload. - This function: - 1. Extracts amax values for calibration - 2. Deletes all quantizer parameters from state dict to store only weights in original dtype - 3. 
Saves the model weights + Folds fake-quant weights into a ``state_dict()`` copy (optional + ``pre_quant_scale`` into weight when input fake-quant is off), drops quantizer + keys from the HF save, briefly disables weight quantizers to snapshot + ModelOpt/quantizer state, then re-enables them. Writes ``export_dir`` via + ``save_pretrained(..., save_modelopt_state=False)``. Args: - model: The quantized model to export - export_dir: Directory to save the amax values - + model: In-memory quantized model. + export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - amax_dict = { - name + "._amax": param["_amax"].detach().clone().cpu() - for name, param in get_quantizer_state_dict(model).items() - if "_amax" in param - } + # Step 1: Build the folded HF state dict. + # model.state_dict() returns detached copies of all tensors, so model + # parameters are never modified. Apply each weight quantizer's fake-quant + # to the corresponding weight tensor in the copy. + state_dict = model.state_dict() + fakequant_weights = set() + input_quantizers_folded_pqs = ( + set() + ) # keys for input_quantizers where pre_quant_scale was folded + with torch.inference_mode(): + for module_name, module in model.named_modules(): + if not isinstance(module, QuantModule): + continue + for attr_name, quantizer in module.named_children(): + if not ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.fake_quant + and quantizer.is_enabled + ): + continue + weight_name = attr_name.removesuffix("_quantizer") + prefix = f"{module_name}." 
if module_name else "" + sd_key = f"{prefix}{weight_name}" + assert sd_key not in fakequant_weights, ( + f"Weight {sd_key} has already been fakequantized" + ) + if sd_key in state_dict: + w = state_dict[sd_key] + w_quant = quantizer(w.float()).to(w.dtype).cpu() + # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) + # Only valid when input_quantizer does NOT fake-quant activations. If it does + # fake_quant(x*s), the non-linearity prevents folding s into W. + inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") + if hasattr(module, inp_attr): + inp_q = getattr(module, inp_attr) + if ( + hasattr(inp_q, "_pre_quant_scale") + and inp_q._pre_quant_scale is not None + and inp_q._disabled + ): + scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) + w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) + inp_q_key = get_unwrapped_name( + f"{module_name}.{inp_attr}" if module_name else inp_attr, model + ) + input_quantizers_folded_pqs.add(inp_q_key) + state_dict[sd_key] = w_quant + fakequant_weights.add(sd_key) + + # Filter quantizer tensors out for a clean HF checkpoint. + clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} - # remove quantizer from model + # Step 2: Disable weight quantizers, save modelopt state + quantizer state + # dict, then re-enable. The _disabled=True flag is captured in modelopt_state + # so that on vLLM reload weight quantizers stay off while input/output/ + # attention quantizers remain active. 
+ wqs_to_restore = [] for _, module in model.named_modules(): - if is_quantlinear(module): - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - delattr(module, attr) - module.export() - torch.save(amax_dict, f"{export_dir}/quant_amax.pth") - # Save model - model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False) + if isinstance(module, QuantModule): + for attr_name, quantizer in module.named_children(): + if ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.is_enabled + ): + quantizer.disable() + wqs_to_restore.append(quantizer) + + quantizer_state_dict = get_quantizer_state_dict(model) + for key in list(quantizer_state_dict): + if key.endswith("weight_quantizer"): + # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. + quantizer_state_dict.pop(key) + elif key in input_quantizers_folded_pqs: + # pre_quant_scale was folded into the weight; keep the buffer for strict load but + # save identity so activations are not scaled twice. + qstate_val = quantizer_state_dict[key] + if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: + quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( + qstate_val["_pre_quant_scale"] + ) + modelopt_state = mto.modelopt_state(model) + # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild + # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). + qstate = quantizer_state(model) + for key in list(qstate): + if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): + qstate.pop(key) + + for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): + if mode_str == "quantize" and "metadata" in m_state: + m_state["metadata"]["quantizer_state"] = qstate + break + + # Per-quantizer tensor dict loaded alongside metadata on reload. 
+ modelopt_state["modelopt_state_weights"] = quantizer_state_dict + torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") + + # Step 3: Save HF weights using the pre-built folded state dict. + model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) + + for wq in wqs_to_restore: + wq.enable() diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 3f69271b0..9a41ae2ba 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -16,66 +16,80 @@ import os import tempfile +from collections.abc import Mapping from pathlib import Path +from typing import Any import torch from modelopt.torch.export.model_config import QUANTIZATION_NONE from modelopt.torch.export.unified_export_megatron import GPTModelExporter +from modelopt.torch.quantization.utils import get_quantizer_state_dict +from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master __all__ = ["export_mcore_gpt_to_hf_vllm_fq"] def gather_mcore_vllm_fq_quantized_state_dict( - model, state_dict: dict[str, torch.Tensor], save_directory: str | os.PathLike -): - """Gather all quantized state dict from all ranks and save them to a file. + _model, + layer_state_dicts: Mapping[Any, dict[str, torch.Tensor]], + save_directory: str | os.PathLike, +) -> None: + """Gather quantizer tensors from every per-layer export shard, sync across ranks, and save. - Args: - state_dict: The state dictionary of the module. - save_directory: The directory to save the quantized state dict. + Megatron export stores one ``OrderedDict`` per decoder layer in ``layer_state_dicts``; the + ``GPTModelExporter.state_dict`` property only references the last shard after build, so + quantizer sidecars must be collected from all shards. - Returns: - The state dictionary of the module without quantized state. 
+ Args: + _model: Unused; kept for a stable call signature with export entry points. + layer_state_dicts: Mapping from layer index to that shard's flat export state dict. + save_directory: Directory for ``quantizer_state.pth``. """ - amax_state_dict = { - k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax") - } - - # Gather all amax dicts to rank 0 - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - if rank == 0: - # Rank 0 will collect all amax values - all_amax_dicts = [None] * world_size - torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0) - - # Merge all amax dicts into one - merged_amax_dict = {} - for amax_dict in all_amax_dicts: - if amax_dict is not None: - merged_amax_dict.update(amax_dict) - - print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}") - torch.save(merged_amax_dict, save_directory + "/quant_amax.pth") - else: - # Other ranks just send their amax values - torch.distributed.gather_object(amax_state_dict, None, dst=0) - - torch.distributed.barrier() + quantizer_state_dict: dict[str, torch.Tensor] = {} + for sd in layer_state_dicts.values(): + for k, v in sd.items(): + if "quantizer" in k: + quantizer_state_dict[k] = v.detach().clone().cpu() + + def _merge_quantizer_states(objs: list) -> dict: + merged: dict = {} + for d in objs: + if d is not None: + merged.update(d) + return merged + + merged_quantizer_state_dict = DistributedProcessGroup.get_dist_syncd_obj( + quantizer_state_dict, + DistributedProcessGroup(None), + _merge_quantizer_states, + ) + if is_master(): + torch.save(merged_quantizer_state_dict, Path(save_directory) / "quantizer_state.pth") class VllmFqGPTModelExporter(GPTModelExporter): """VLLM fakequant GPTModel exporter.""" + @staticmethod + def _pop_quantizer_keys(state_dict: dict) -> None: + """Remove quantizer tensors from an export shard (OrderedDict-safe).""" + for k in [k for k in state_dict if "quantizer" in k]: + 
state_dict.pop(k, None) + def save_pretrained( self, save_directory: str | os.PathLike, pretrained_model_name_or_path: str | os.PathLike, ): - os.makedirs(save_directory, exist_ok=True) - gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) + """Save HF shards + sidecar ``quantizer_state.pth``; then delegate to base export. + + Pipeline-parallel placement of ``config.json``, tokenizer, and multimodal tensors + remains handled by ``GPTModelExporter.save_pretrained`` (via ``super()``). + """ + save_dir = os.fspath(save_directory) + os.makedirs(save_dir, exist_ok=True) + assert not (self.is_multimodal and pretrained_model_name_or_path is not None), ( "Exporting weights in bf16 and amax values is not supported for multimodal models " "when pretrained_model_name_or_path is not None" @@ -83,11 +97,54 @@ def save_pretrained( assert not self.export_extra_modules, ( "Exporting extra modules is not supported for vLLM fakequant" ) + + gather_mcore_vllm_fq_quantized_state_dict(self.model, self.layer_state_dicts, save_dir) + + self._pop_quantizer_keys(self.state_dict) + for _layer_sd in self.layer_state_dicts.values(): + self._pop_quantizer_keys(_layer_sd) + super().save_pretrained(save_directory, pretrained_model_name_or_path) def _get_quantization_format(self, module: torch.nn.Module): return QUANTIZATION_NONE + def _get_quantized_state( + self, + module: torch.nn.Module, + dtype: torch.dtype = torch.float16, + prefix: str = "", + ) -> tuple[dict[str, torch.Tensor], str, int]: + """Return a state_dict, quantization format, and block_size of the module. + + Args: + module: The target module to perform real quantization. + dtype: The default data type. + + Returns: + Tuple: state_dict, quantization format, and block_size of the module. + """ + name_to_value = {} + qformat: str = self._get_quantization_format(module) + if qformat is None and "norm" not in prefix: + # Add exclude layers for vllm fakequant config. 
Note that if the prefix is not an empty + # string then it usually ends with "." which needs to be removed. + self.exclude_modules.append(prefix.removesuffix(".")) + block_size = 0 + + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight.to(dtype).cpu() + name_to_value["weight"] = weight + else: + return name_to_value, qformat, block_size + + if hasattr(module, "bias") and module.bias is not None: + name_to_value["bias"] = module.bias.to(dtype).cpu() + for name, param in get_quantizer_state_dict(module).items(): + for key, value in param.items(): + name_to_value[name + "." + key] = value.to(dtype).cpu() + return name_to_value, qformat, block_size + def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, @@ -96,6 +153,7 @@ def export_mcore_gpt_to_hf_vllm_fq( dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), moe_router_dtype: torch.dtype | None = None, + trust_remote_code: bool = False, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir. 
@@ -116,5 +174,6 @@ def export_mcore_gpt_to_hf_vllm_fq( export_extra_modules=export_extra_modules, dtype=dtype, moe_router_dtype=moe_router_dtype, + trust_remote_code=trust_remote_code, ) exporter.save_pretrained(export_dir, pretrained_model_name_or_path) diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index e1209607a..fef38093b 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -16,17 +16,24 @@ """Support quantization for VLLM layers.""" import importlib +from contextlib import contextmanager +from itertools import chain import torch -import vllm.attention as vllm_attention + +# Try multiple import paths for vLLM compatibility across versions +if importlib.util.find_spec("vllm.attention"): + import vllm.attention as vllm_attention # vllm < 0.16.0 +else: + import vllm.model_executor.layers.attention as vllm_attention # vllm >= 0.16.0 + import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer import vllm.model_executor.layers.linear as vllm_linear -from vllm.attention.layers.cross_attention import CrossAttention -from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.distributed.parallel_state import get_dp_group, get_ep_group, get_tp_group from ...utils.distributed import ParallelState from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer +from .custom import CUSTOM_MODEL_PLUGINS # Try multiple import paths for vLLM compatibility across versions vllm_shared_fused_moe_layer = None @@ -40,14 +47,126 @@ except ImportError: continue +if importlib.util.find_spec("vllm.attention.layers"): # vllm < 0.15.0 + from vllm.attention.layers.cross_attention import CrossAttention + from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +else: + try: + from vllm.model_executor.layers.attention.cross_attention import CrossAttention + except ImportError: + 
CrossAttention = None + try: + from vllm.model_executor.layers.attention.encoder_only_attention import EncoderOnlyAttention + except ImportError: + EncoderOnlyAttention = None + +if importlib.util.find_spec("vllm.attention.layer"): + import vllm.attention.layer as vllm_attention + try: - from vllm.attention.layer import MLAAttention as VllmMLAAttention + VllmMLAAttention = vllm_attention.MLAAttention except ImportError: VllmMLAAttention = None +_ATTENTION_TYPES = tuple( + t + for t in [vllm_attention.Attention, CrossAttention, EncoderOnlyAttention, VllmMLAAttention] + if t is not None +) + vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") +@contextmanager +def disable_compilation(model): + """Disable compilation for a model. + + Args: + model: The model to disable compilation for. + """ + do_not_compile = True + if hasattr(model, "model"): + do_not_compile = model.model.do_not_compile + model.model.do_not_compile = True + elif hasattr(model, "language_model"): + do_not_compile = model.language_model.model.do_not_compile + model.language_model.model.do_not_compile = True + else: + raise ValueError("Model does not have a model or language_model attribute") + + try: + yield + finally: + if hasattr(model, "model"): + model.model.do_not_compile = do_not_compile + elif hasattr(model, "language_model"): + model.language_model.model.do_not_compile = do_not_compile + + +# vLLM Attention stores ``device``/``dtype`` as plain attrs; ``dtype`` may be a string +# (e.g. ``"float16"``, ``"auto"``). We resolve and stamp concrete torch types before +# QuantModule replacement. Priority: explicit attrs → KV-cache → shallow tensor scan. +# No model-wide fallback: a tensor from a different shard gives the wrong device under TP. 
+ + +def _vllm_attr_dtype_to_torch(dtype) -> torch.dtype | None: + """Resolve vLLM dtype attr to ``torch.dtype``; ``None`` for ``"auto"`` (caller falls through).""" + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str) and dtype != "auto": + resolved = getattr(torch, dtype, None) + if resolved is None: + raise ValueError(f"Unrecognized vLLM dtype string: {dtype!r}") + return resolved + return None + + +def _get_device_dtype(module: torch.nn.Module) -> tuple: + """Return ``(device, dtype)`` for a vLLM Attention module, or ``(None, None)`` if unresolvable.""" + # Explicit attrs set by vLLM at construction — primary path. + dev, dt = getattr(module, "device", None), getattr(module, "dtype", None) + if dev is not None and dt is not None: + dt_resolved = _vllm_attr_dtype_to_torch(dt) + if dt_resolved is not None: + return dev, dt_resolved + + # KV-cache tensors are available after allocation; respect kv_cache_dtype when set. + kv = getattr(module, "kv_cache", None) + if kv and kv[0] is not None: + t0 = kv[0] + spec = getattr(module, "kv_cache_dtype", t0.dtype) + out_dtype = t0.dtype if spec == "auto" else (_vllm_attr_dtype_to_torch(spec) or t0.dtype) + return t0.device, out_dtype + + # Shallow scan: weights often live on child modules rather than the attention module itself. 
+ for mod in (module, *module.children()): + for t in chain(mod.parameters(recurse=False), mod.buffers(recurse=False)): + return t.device, t.dtype + + return None, None + + +def vllm_replace_quant_module_hook(model: torch.nn.Module) -> None: + """Stamp resolved (device, dtype) onto Attention modules before QuantModule replacement.""" + for _n, m in model.named_modules(): + if isinstance(m, _ATTENTION_TYPES): + m.device, m.dtype = _get_device_dtype(m) + + +CUSTOM_MODEL_PLUGINS.add(vllm_replace_quant_module_hook) + + +def _vllm_attention_modelopt_post_restore(self) -> None: + """Move Attention module to its correct device after ModelOpt state restore.""" + device, dtype = _get_device_dtype(self) + if device is None or dtype is None: + raise RuntimeError( + "Could not determine device/dtype for vLLM Attention. " + "Ensure vllm_replace_quant_module_hook runs before replace_quant_module." + ) + self.to(device=device) + + class FakeQuantMethod: """A class that implements fake quantization methods for vLLM models. @@ -79,7 +198,8 @@ def apply( Returns: torch.Tensor: The quantized output tensor. 
""" - x = layer.input_quantizer(x) + if layer.input_quantizer.is_enabled: + x = layer.input_quantizer(x) if layer.weight_quantizer.is_enabled: original_weight = layer.weight quantized_tensor = layer.weight_quantizer(layer.weight) @@ -119,6 +239,21 @@ def _setup(self): self.fake_quant_method = FakeQuantMethod(self.quant_method) self.parallel_state = create_parallel_state() + def _sync_input_pre_quant_scale_to_weight(self) -> None: + """Align pre_quant_scale to weight (vLLM CUTLASS expects matching device/dtype).""" + pqs = getattr(self.input_quantizer, "_pre_quant_scale", None) + if pqs is None: + return + w = getattr(self, "weight", None) + if w is None or not isinstance(w, torch.Tensor) or w.is_meta: + return + if pqs.device != w.device or pqs.dtype != w.dtype: + self.input_quantizer._pre_quant_scale.data = pqs.data.to(device=w.device, dtype=w.dtype) + + def modelopt_post_restore(self, prefix: str = "") -> None: + super().modelopt_post_restore(prefix=prefix) + self._sync_input_pre_quant_scale_to_weight() + def forward(self, input_): # This context manager will conflict with torch.compile # with replace_function(self, "quant_method", self.fake_quant_method): @@ -130,6 +265,17 @@ def forward(self, input_): return output +def post_restore_vllm_parallel_linears(model: torch.nn.Module) -> None: + """Re-run modelopt_post_restore on vLLM parallel linears after set_quantizer_state_dict. + + restore_quantizer_state already calls modelopt_post_restore on all QuantModules, but vLLM + reload paths that load modelopt_state_weights via set_quantizer_state_dict do not. 
+ """ + for module in model.modules(): + if isinstance(module, _VLLMParallelLinear): + module.modelopt_post_restore("") + + @QuantModuleRegistry.register({vllm_linear.RowParallelLinear: "vllm_RowParallelLinear"}) class _QuantVLLMRowParallelLinear(_VLLMParallelLinear): pass @@ -274,9 +420,11 @@ def forward(self, query, key, value, *args, **kwargs): query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - return super().forward(query, key, value, *args, **kwargs) + def modelopt_post_restore(self, prefix: str = "") -> None: + _vllm_attention_modelopt_post_restore(self) + @QuantModuleRegistry.register({CrossAttention: "vllm_CrossAttention"}) class _QuantVLLMCrossAttention(_QuantVLLMAttention): @@ -303,3 +451,6 @@ def forward(self, query, kv_c, k_pe, *args, **kwargs): kv_c = self.kv_c_bmm_quantizer(kv_c) k_pe = self.k_pe_bmm_quantizer(k_pe) return super().forward(query, kv_c, k_pe, *args, **kwargs) + + def modelopt_post_restore(self, prefix: str = "") -> None: + _vllm_attention_modelopt_post_restore(self) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py index a156ad126..8f6071796 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py @@ -21,6 +21,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export import export_hf_vllm_fq_checkpoint +from modelopt.torch.quantization.model_quant import fold_weight @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) @@ -28,13 +29,14 @@ def test_hf_vllm_export(tmp_path, quant_cfg): """Test HuggingFace model export for vLLM with fake quantization. This test verifies: - 1. Model weights match before and after export - 2. quant_amax.pth file is created, huggingface config file does not exist - 3. Amax values are correctly extracted and saved in quant_amax.pth file + 1. 
Input model is NOT mutated by export (weights and quantizer state unchanged) + 2. Exported weights match folded (fake-quantized) weights + 3. vllm_fq_modelopt_state.pth is created; hf_quant_config.json is not + 4. Weight quantizer states are empty in saved state dict; input quantizer amaxes preserved """ # Create a tiny LLaMA model for testing - tiny_model_dir = create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2) + tiny_model_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) # Load the model model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) @@ -48,19 +50,35 @@ def forward_loop(model): model(input_ids) model = mtq.quantize(model, quant_cfg, forward_loop) + quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model) - model_state_dict = deepcopy(model.state_dict()) + # Compute expected exported weights: deepcopy → fold (export writes folded weights) + folded_model = deepcopy(model) + fold_weight(folded_model) + expected_weights = {k: v for k, v in folded_model.state_dict().items() if "quantizer" not in k} + del folded_model + + # Snapshot model state before export to verify it is not mutated + state_dict_before_export = {k: v.clone() for k, v in model.state_dict().items()} # Export directory export_dir = tmp_path / "vllm_export" export_dir.mkdir(exist_ok=True) - # Export for vLLM export_hf_vllm_fq_checkpoint(model, export_dir=export_dir) - # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + # Verify the input model is not mutated: all state dict values unchanged + state_dict_after_export = model.state_dict() + for key, param_before in state_dict_before_export.items(): + assert torch.equal(param_before, state_dict_after_export[key]), ( + f"Model was mutated by export: {key} changed" + ) + + # check if vllm_fq_modelopt_state.pth file exists + modelopt_state_file = export_dir / 
"vllm_fq_modelopt_state.pth" + assert modelopt_state_file.exists(), ( + f"vllm_fq_modelopt_state.pth file should be created in {export_dir}" + ) # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json" @@ -68,26 +86,28 @@ def forward_loop(model): f"hf_quant_config.json file should not be created in {export_dir}" ) - # check weights match before and after export + # check folded weights match exported model weights model_after = AutoModelForCausalLM.from_pretrained(export_dir) model_after = model_after.cuda() model_after.eval() model_after_state_dict = model_after.state_dict() - amax_state_dict = {} - for key, param in model_state_dict.items(): - if key.endswith("_amax"): - amax_state_dict[key] = param - continue - + for key, param in expected_weights.items(): assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( f"Weight mismatch for {key}: " f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" ) - # Verify amax values are correct - amax_dict = torch.load(quant_amax_file) - assert len(amax_dict) > 0, "amax_dict should not be empty" - assert amax_dict.keys() == amax_state_dict.keys(), ( - "amax keys mismatch between before and after export" + # Verify quantizer state dict: same keys, weight quantizer amaxes cleared, input amaxes kept + # weights_only=False required: modelopt_state contains Python objects (dicts, strings, etc.) 
+ quantizer_state_dict = torch.load(modelopt_state_file, weights_only=False)["modelopt_state_weights"] + assert len(quantizer_state_dict) > 0, ( + f"modelopt_state_weights should not be empty in {modelopt_state_file}" + ) + for name, state in quantizer_state_dict.items(): + if "weight_quantizer" in name: + assert state == {}, f"weight quantizer {name} should have empty state after fold" + elif "input_quantizer" in name and any( + "_amax" in k for k in quantizer_state_dict_before[name] + ): + assert any("_amax" in k for k in state), f"input quantizer {name} should preserve _amax" diff --git a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py index 7462f89e5..61fdabed8 100644 --- a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py @@ -95,8 +95,8 @@ def forward_loop(model): ) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + quant_amax_file = export_dir / "quantizer_state.pth" + assert quant_amax_file.exists(), f"quantizer_state.pth file should be created in {export_dir}" # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json"