diff --git a/docs/source/guides/1_quantization.rst b/docs/source/guides/1_quantization.rst index a838bfb10..38ce0956b 100644 --- a/docs/source/guides/1_quantization.rst +++ b/docs/source/guides/1_quantization.rst @@ -19,6 +19,7 @@ Below, you can find the documentation for the quantization toolkit in ModelOpt: ./_basic_quantization.rst ./_choosing_quant_methods.rst ./_pytorch_quantization.rst + ./_quant_cfg.rst ./_customized_model_quantization.rst ./_compress_quantized_models.rst ./_onnx_quantization.rst diff --git a/docs/source/guides/_pytorch_quantization.rst b/docs/source/guides/_pytorch_quantization.rst index 15a7da9f1..1b454e70e 100644 --- a/docs/source/guides/_pytorch_quantization.rst +++ b/docs/source/guides/_pytorch_quantization.rst @@ -237,14 +237,16 @@ For debugging purposes or simple customizations, you can modify an existing conf .. code-block:: python - # Create a copy of the default INT8 configuration - config = mtq.INT8_DEFAULT_CFG.copy() + import copy - # Disable input quantizers for all layers - config["quant_cfg"]["*input_quantizer"]["enable"] = False + # Create a deep copy of the default INT8 configuration + config = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + + # Disable input quantizers for all layers (appended last, so it takes precedence) + config["quant_cfg"].append({"quantizer_path": "*input_quantizer", "enable": False}) # Disable all quantizers for layers matching the pattern "layer1.*" - config["quant_cfg"]["*layer1.*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*layer1.*", "enable": False}) Advanced Configuration Creation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -255,16 +257,19 @@ For exploring new quantization recipes, you can compose a completely new configu # Custom configuration for INT4 block-wise weights and INT8 dynamic activations MY_CUSTOM_CONFIG = { - "quant_cfg": { + "quant_cfg": [ + # Disable all quantizers by default, then enable selectively + {"quantizer_path": "*", "enable": False}, + # Configure weight 
quantizers with 4-bit precision and 128-element blocks - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, "enable": True}, # Configure input quantizers with 8-bit dynamic quantization - "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}}, # Include default disabled quantizer configurations - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } @@ -394,8 +399,10 @@ You can specify ``custom_calib`` as ``algorithm`` in ``quant_cfg`` to use it. He # create quantization configuration with "custom_calib" method quant_cfg = { - 'quant_cfg': {'*weight_quantizer': ..}, - 'algorithm': {"method": 'custom_calib'}, + 'quant_cfg': [ + {"quantizer_path": "*weight_quantizer", "cfg": {...}}, + ], + 'algorithm': {"method": 'custom_calib'}, } diff --git a/docs/source/guides/_quant_cfg.rst b/docs/source/guides/_quant_cfg.rst new file mode 100644 index 000000000..b3d37cdb3 --- /dev/null +++ b/docs/source/guides/_quant_cfg.rst @@ -0,0 +1,307 @@ +.. _quant-cfg: + +====================================== +Quantization Configuration (quant_cfg) +====================================== + +The ``quant_cfg`` field is the primary mechanism for controlling which quantizers are active in a +model and how they are configured. This guide explains the format, ordering semantics, and common +patterns for composing quantization configurations. + +.. tip:: + + For the list of built-in configs and supported formats, see :any:`quantization-formats`. + For how to apply a config to a model, see :any:`_pytorch_quantization`. + +---------- + +Overview +======== + +A quantization config is a Python dictionary with two top-level keys: + +.. 
code-block:: python + + config = { + "quant_cfg": [...], # ordered list of QuantizerCfgEntry dicts + "algorithm": "max", # calibration algorithm + } + +The ``quant_cfg`` value is an **ordered list** of :class:`QuantizerCfgEntry +` dicts. Each entry targets a set of +quantizer modules in the model and specifies their configuration. + +---------- + +Entry Format +============ + +Each entry in the list is a dictionary with the following fields: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``quantizer_path`` + - Yes + - Wildcard string matched against quantizer module names (e.g. ``"*weight_quantizer"``). + Uses :func:`fnmatch` rules. + * - ``parent_class`` + - No + - Restricts matching to quantizers whose immediate parent module is of this PyTorch class + (e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class. + * - ``cfg`` + - No + - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig + `, or a list of such dicts + for sequential quantization (see :ref:`sequential-quantizers`). + * - ``enable`` + - No + - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``. + When ``cfg`` is absent, **only** the enabled/disabled state is changed — all other + attributes remain untouched. When ``cfg`` is present, ``enable`` sets the enabled state + of the newly-configured quantizer. When ``cfg`` is present and ``enable`` is omitted, + the quantizer is implicitly enabled (``True``). + +.. note:: + + Every entry must specify at least one of ``cfg`` or ``enable`` in addition to + ``quantizer_path``. An entry with only ``quantizer_path`` and no other keys is **invalid** + and will raise a ``ValueError`` at config-processing time. This prevents subtle bugs where + a bare ``{"quantizer_path": "*"}`` would silently behave as ``enable=True`` for all + quantizers. 
+ +---------- + +Default Quantizer Configuration +================================ + +When a quantizer is enabled but has never been touched by a ``cfg`` entry — either because no +entry in the list matched it, or because it was only reached by enable-only entries — it operates +with the default attributes of +:class:`QuantizerAttributeConfig `: + +.. code-block:: python + + { + "num_bits": 8, # 8-bit integer quantization + "axis": None, # per-tensor scale (no per-channel axis) + "fake_quant": True, # simulate quantization in forward pass (PTQ / QAT) + "unsigned": False, # signed integer range, e.g. [-128, 127] for INT8 + "narrow_range": False, # full range; True would restrict to [-127, 127] for INT8 + "type": "static", # static calibration (not dynamic per-inference) + "block_sizes": None, # no block quantization; set for NF4 / MXFP formats + "bias": None, # no affine bias correction + "calibrator": "max", # use max-abs calibration to determine amax + "rotate": False, # no Hadamard rotation (QuaRot / SpinQuant) + "pass_through_bwd": True, # straight-through estimator for QAT gradients + "trt_high_precision_dtype": "Float", # cast QDQ nodes to fp32 for TRT StronglyType export + "backend": None, # use the built-in quantization backend + "backend_extra_args": None, # no extra args for custom backends + "use_constant_amax": False, # calibrate amax; True hard-codes FP8 E4M3 max (448.0) + } + +In practice this means an un-configured but enabled quantizer performs **INT8 per-tensor static +fake-quantization** with a max-calibrated scale. This is rarely the intended behavior — every +quantizer you want active should be explicitly configured with a ``cfg`` entry. + +---------- + +Ordering and Precedence +======================= + +Entries are applied **in list order**. Later entries override earlier ones for any quantizer they +match. This gives a clear, composable precedence model: + +- Put broad rules (e.g. deny-all) **first**. 
+- Put format-specific enable rules **after**. +- Put fine-grained exclusions (specific layers, classes) **last**. + +The recommended pattern used by all built-in configs is: + +.. code-block:: python + + "quant_cfg": [ + # 1. Deny all quantizers by default + {"quantizer_path": "*", "enable": False}, + + # 2. Enable and configure the target quantizers + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + + # 3. Apply standard exclusions last (BatchNorm, LM head, MoE routers, etc.) + *mtq.config._default_disabled_quantizer_cfg, + ] + +.. note:: + + The deny-all entry ``{"quantizer_path": "*", "enable": False}`` is available as + :data:`modelopt.torch.quantization.config._base_disable_all` and is prepended to every + built-in config. This ensures quantizers not explicitly targeted remain disabled. + +---------- + +Entry Atomicity +=============== + +Each ``cfg``-bearing entry in ``quant_cfg`` is a **complete, self-contained configuration unit**. +When an entry with ``cfg`` matches a quantizer, it **completely replaces** that quantizer's +configuration — it does not merge with or incrementally update settings left by earlier entries. + +Concretely, if an entry specifies only a subset of quantizer attributes (e.g. only ``num_bits``), +all unspecified attributes are filled in with their default values from +:class:`QuantizerAttributeConfig `. +The resulting *complete* config is then written to the quantizer, discarding whatever any prior +matching entry had set. + +This means: + +- **Last cfg-entry wins, fully.** If two entries both match ``*weight_quantizer`` and both carry + a ``cfg``, the second entry does not inherit the first entry's settings — it replaces them entirely. +- **No hidden state accumulation.** The final configuration of a quantizer depends only on the + *last* ``cfg``-bearing entry in the list that matched it, making behavior easy to reason about. 
+- **Changing one field requires a full spec.** Because each ``cfg`` entry is a complete replacement, + to change only one attribute of a quantizer that was already configured, you must reproduce the + full desired config in the new entry. Any attribute omitted from the entry will revert to its + default, not to the value set by an earlier entry. + +**Enable-only entries are the exception.** An entry with no ``cfg`` (only ``enable``) is *not* a +full replacement — it solely flips the on/off state of matched quantizers, leaving all other +attributes unchanged: + +- ``{"quantizer_path": "*", "enable": False}`` disables all quantizers without touching their + configured attributes. Use this as the first step in a deny-all-then-configure pattern. +- ``{"quantizer_path": "*weight_quantizer", "enable": True}`` (no ``cfg``) re-enables weight + quantizers using whatever attributes they currently carry (or their defaults if they were never + configured by a ``cfg`` entry). + +For example, given the following two entries both matching ``*weight_quantizer``: + +.. code-block:: python + + # Entry 1 — sets FP8 per-channel + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, + + # Entry 2 — sets INT4 blockwise (axis is NOT inherited from Entry 1) + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}}, + +After Entry 2 is applied, the quantizer has ``num_bits=4``, ``block_sizes={-1: 128}``, and +``axis=None`` (the default). The ``axis=0`` set by Entry 1 is gone. + +.. note:: + + The deny-all-then-configure pattern is safe and predictable precisely because + ``{"quantizer_path": "*", "enable": False}`` **only** disables quantizers without resetting + their attributes. Subsequent ``cfg`` entries then configure targets from a known default state. 
+ +---------- + +Common Patterns +=============== + +Skipping Specific Layers +------------------------ + +Append a disable entry after the existing config to exclude layers matched by a path pattern. +Because it is appended last, it takes precedence over all earlier entries: + +.. code-block:: python + + import copy + import modelopt.torch.quantization as mtq + + config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + + # Skip the final projection layer + config["quant_cfg"].append({"quantizer_path": "*lm_head*", "enable": False}) + + model = mtq.quantize(model, config, forward_loop) + +Skipping Layers by Module Class +-------------------------------- + +Use ``parent_class`` to target quantizers only within a specific type of layer, leaving the +same quantizer path in other layer types unaffected: + +.. code-block:: python + + config["quant_cfg"].append({ + "quantizer_path": "*input_quantizer", + "parent_class": "nn.LayerNorm", + "enable": False, + }) + +Overriding Quantizer Precision for Specific Layers +--------------------------------------------------- + +A later entry with a matching ``quantizer_path`` replaces the configuration set by an earlier +entry. This allows per-layer precision overrides without restructuring the entire config: + +.. code-block:: python + + config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + + # Quantize attention output projections in higher-precision INT8 instead of FP8 + config["quant_cfg"].append({ + "quantizer_path": "*o_proj*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }) + +Building a Config from Scratch +------------------------------- + +For entirely custom recipes, compose the list directly: + +.. 
code-block:: python + + from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg + + MY_CUSTOM_CFG = { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } + + model = mtq.quantize(model, MY_CUSTOM_CFG, forward_loop) + +---------- + +.. _sequential-quantizers: + +Sequential Quantization +======================= + +When ``cfg`` is a **list** of attribute dicts, the matched +:class:`TensorQuantizer ` +is replaced with a +:class:`SequentialQuantizer ` +that applies each format in sequence. This is used, for example, in W4A8 quantization where weights +are quantized first in INT4 and then in FP8: + +.. code-block:: python + + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, # FP8 + ], + "enable": True, + } + +---------- + +Reference +========= + +- :class:`QuantizerCfgEntry ` +- :class:`QuantizerAttributeConfig ` +- :class:`QuantizeConfig ` +- :func:`set_quantizer_by_cfg ` diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index bcfd9de40..faad47eca 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -309,38 +309,70 @@ def calibrate_loop(model): mtq_cfg = getattr(mtq, quant_cfg) # disable head that corresponds to lm_head (for the huggingface checkpoint) - mtq_cfg["quant_cfg"]["*head*"] = {"enable": False} + mtq_cfg["quant_cfg"].append({"quantizer_path": "*head*", "enable": False}) allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"] assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}" if not mla_quant: - mtq_cfg["quant_cfg"]["*attn*"] = {"enable": False} + mtq_cfg["quant_cfg"].append({"quantizer_path": "*attn*", "enable": False}) elif mla_quant == 
"per_tensor_fp8": - mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None} - mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None} + mtq_cfg["quant_cfg"].extend( + [ + { + "quantizer_path": "*attn*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*attn*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + ] + ) elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] for layer in mla_linear_layers: if layer in mla_nvfp4_linear_layers: # wq_a, wkv_a, wq_b, wo use NVFP4 quantization - mtq_cfg["quant_cfg"][layer + "_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - } + mtq_cfg["quant_cfg"].append( + { + "quantizer_path": layer + "_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) else: - mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False} + mtq_cfg["quant_cfg"].append( + {"quantizer_path": layer + "_quantizer", "enable": False} + ) # Disable BMM quantizers - mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} + mtq_cfg["quant_cfg"].extend( + [ + {"quantizer_path": "*attn.kv_bmm_quantizer*", "enable": False}, + {"quantizer_path": "*attn.pe_bmm_quantizer*", "enable": False}, + ] + ) if not args.disable_wo_quant and "FP4" in quant_cfg: - mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] - mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] + # Find the default input/weight quantizer cfgs to swap for wo layers + input_cfg = next( + e["cfg"] for e in 
mtq_cfg["quant_cfg"] if e.get("quantizer_path") == "*input_quantizer" + ) + weight_cfg = next( + e["cfg"] for e in mtq_cfg["quant_cfg"] if e.get("quantizer_path") == "*weight_quantizer" + ) + mtq_cfg["quant_cfg"].extend( + [ + {"quantizer_path": "*wo*weight_quantizer", "cfg": input_cfg}, + {"quantizer_path": "*wo*input_quantizer", "cfg": weight_cfg}, + ] + ) ## ptq transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 94063ffd9..9f24ec15f 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -17,82 +17,79 @@ from calib.plugin_calib import PercentileCalibrator FP8_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"enable": False}, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } INT8_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - "*output_quantizer": {"enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + {"quantizer_path": "*output_quantizer", "enable": False}, + ], "algorithm": "max", } NVFP4_DEFAULT_CONFIG = { 
- "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": { - "**weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "**weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "**input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "**input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "*bmm2_output_quantizer": { - "num_bits": 
(4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*[qkv]_bmm_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*bmm2_output_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": {"method": "svdquant", "lowrank": 32}, } @@ -106,8 +103,9 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, ** algo_cfg["lowrank"] = kwargs["lowrank"] quant_config["algorithm"] = algo_cfg - for p in quant_config["quant_cfg"].values(): - if "num_bits" in p and "trt_high_precision_dtype" not in p: + for entry in quant_config["quant_cfg"]: + p = entry.get("cfg", {}) + if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p: p["trt_high_precision_dtype"] = trt_high_precision_dtype @@ -127,18 +125,23 @@ def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, bac for name, module in backbone.named_modules(): if isinstance(module, nn.Conv2d): aq_name = f"*{name}*input_quantizer*" - quant_config["quant_cfg"][aq_name] = { - "num_bits": 8, - "axis": None, - "calibrator": ( - PercentileCalibrator, - (), - { + quant_config["quant_cfg"].append( + { + "quantizer_path": aq_name, + "cfg": { "num_bits": 8, "axis": None, - "percentile": percentile, - "total_step": n_steps, - "collect_method": collect_method, + "calibrator": ( + PercentileCalibrator, + (), + { + "num_bits": 8, + "axis": None, + "percentile": percentile, + "total_step": n_steps, + "collect_method": collect_method, + }, + ), }, - ), - } + } + ) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 612357f6e..cb4b1e003 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -137,7 +137,12 @@ def get_quant_config(self, n_steps: int, 
backbone: torch.nn.Module) -> Any: else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: - quant_config["quant_cfg"]["*[qkv]_bmm_quantizer"] = {"num_bits": (4, 3), "axis": None} # type: ignore[index] + quant_config["quant_cfg"].append( + { + "quantizer_path": "*[qkv]_bmm_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) set_quant_config_attr( quant_config, self.model_config.trt_high_precision_dtype.value, diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py index e9ecb0731..73308ed7f 100644 --- a/examples/llm_autodeploy/run_auto_quantize.py +++ b/examples/llm_autodeploy/run_auto_quantize.py @@ -100,11 +100,21 @@ def loss_func(output, data): if enable_kv_cache_quantization: mtq.set_quantizer_by_cfg( model, - quant_cfg={"*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}}, + quant_cfg=[ + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + } + ], ) # Lets calibrate only the output quantizer this time. Let's disable all other quantizers. with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, "*output_quantizer": {"enable": True}} + model, + [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "enable": True}, + ], ): mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) return model diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 3df44115a..466f65ced 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -33,12 +33,20 @@ # Modify your custom config for debugging or research purposes. 
CUSTOM_CONFIG = { "MY_QUANT_CONFIG": { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + "quant_cfg": [ + *mtq.config._base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + }, # Disable sensitive layers such as `lm_head`, gate layers in MoE etc. - **mtq.config._default_disabled_quantizer_cfg, - }, + *mtq.config._default_disabled_quantizer_cfg, + ], "algorithm": "max", }, } diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 58eb67611..f73936a81 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -205,7 +205,12 @@ def build_quant_cfg( ) -> dict[str, Any]: quant_cfg = copy.deepcopy(quant_cfg) if "awq" in str(quant_cfg.get("algorithm")): - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + weight_quantizer_entry = next( + e + for e in quant_cfg["quant_cfg"] + if isinstance(e, dict) and e.get("quantizer_path") == "*weight_quantizer" + ) + weight_quantizer = weight_quantizer_entry.get("cfg", {}) if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] # If awq_block_size argument is provided, update weight_quantizer @@ -236,10 +241,10 @@ def build_quant_cfg( if model_type == "phi4mm": # Only quantize the language model - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"].append({"quantizer_path": "*speech*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*audio*", "enable": False}) + 
quant_cfg["quant_cfg"].append({"quantizer_path": "*image*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*vision*", "enable": False}) return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5620ddf6a..9c6335b9d 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -77,16 +77,27 @@ RAND_SEED = 1234 -def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: +def _set_kv_cache_constant_amax(quant_cfg: list) -> None: """Set use_constant_amax on KV cache quantizers. Creates a new dict for the KV bmm quantizer config to avoid mutating shared references. """ - if "*[kv]_bmm_quantizer" in quant_cfg: - quant_cfg["*[kv]_bmm_quantizer"] = { - **quant_cfg["*[kv]_bmm_quantizer"], - "use_constant_amax": True, - } + for i, entry in enumerate(quant_cfg): + pattern = ( + entry["quantizer_path"] + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[0] + ) + if pattern == "*[kv]_bmm_quantizer": + assert isinstance(entry, dict) and isinstance(entry.get("cfg", {}), dict) + new_entry = { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}, + } + if entry.get("enable") is not None: + new_entry["enable"] = entry["enable"] + quant_cfg[i] = new_entry + break QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { @@ -144,7 +155,7 @@ def extract_and_prepare_language_model_from_vl(full_model): # Apply disabled quant to all modules that are not part of language_model # This excludes them during HF export disabled_quant_cfg = { - "quant_cfg": {"default": {"enable": False}}, + "quant_cfg": [{"quantizer_path": "*", "enable": False}], "algorithm": "max", } @@ -318,7 +329,7 @@ def forward_step(model, batch): ), verbose=True, # Disable all default disabled layers such as lm_head, mlp.gate, router etc. 
- disabled_layers=list(_default_disabled_quantizer_cfg.keys()), + disabled_layers=[entry["quantizer_path"] for entry in _default_disabled_quantizer_cfg], method=auto_quantize_method, checkpoint=auto_quantize_checkpoint, ) @@ -331,7 +342,9 @@ def forward_step(model, batch): kv_cache_quant_cfg = copy.deepcopy( getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] ) - kv_cache_quant_cfg.pop("default", None) # keep other quantizers from auto_quantize + kv_cache_quant_cfg = [ + e for e in kv_cache_quant_cfg if e["quantizer_path"] != "*" + ] # keep other quantizers from auto_quantize if args.kv_cache_qformat in _KV_CAST_FORMATS: _set_kv_cache_constant_amax(kv_cache_quant_cfg) @@ -340,7 +353,8 @@ def forward_step(model, batch): if args.kv_cache_qformat not in _KV_CAST_FORMATS: # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( - language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} + language_model, + [{"quantizer_path": "*", "enable": False}, *kv_cache_quant_cfg], ): mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) return language_model @@ -543,13 +557,17 @@ def mono_quantize( # For Nemotron VL models, disable quantization of vision components if is_nemotron_vl_model: print("Disabling quantization for vision components in Nemotron VL model") - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"].append({"quantizer_path": "*vision*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*image*", "enable": False}) # Also disable radio model components specifically (for Nemotron-Parse) - quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder - quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific + 
quant_cfg["quant_cfg"].append({"quantizer_path": "*radio*", "enable": False}) +        quant_cfg["quant_cfg"].append({"quantizer_path": "*visual*", "enable": False}) +        quant_cfg["quant_cfg"].append( +            {"quantizer_path": "*encoder*", "enable": False} +        )  # Disable encoder +        quant_cfg["quant_cfg"].append( +            {"quantizer_path": "*model_encoder*", "enable": False} +        )  # Nemotron-Parse specific          print("Quantization will only be applied to the decoder (text generation) component")      if not model_is_already_quantized or calibration_only: @@ -968,7 +986,7 @@ def quantize_main(              for prefix in mtp_layer_prefixes:                  # Add exclusion pattern for this MTP layer (e.g., "*layers.92*")                  pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*" -                quant_cfg["quant_cfg"][pattern] = {"enable": False} +                quant_cfg["quant_cfg"].append({"quantizer_path": pattern, "enable": False})                  print(f"Excluding MTP layer from quantization: {pattern}")      # Use constant amax for KV quantizers when a cast format is selected. diff --git a/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb b/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb index fc055cf84..0892cec63 100644 --- a/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb +++ b/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb @@ -192,7 +192,7 @@ "source": [ "# Get default AWQ config and optionally adjust block size\n", "quant_cfg = mtq.INT4_AWQ_CFG\n", - "weight_quantizer = quant_cfg[\"quant_cfg\"][\"*weight_quantizer\"]\n", + "weight_quantizer = next(e[\"cfg\"] for e in quant_cfg[\"quant_cfg\"] if e.get(\"quantizer_path\") == \"*weight_quantizer\")\n", "if isinstance(weight_quantizer, list):\n", " weight_quantizer = weight_quantizer[0]\n", "weight_quantizer[\"block_sizes\"][-1] = 128 # Optional: override block size\n", diff --git a/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb b/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb index 122569489..9634c615d 100644 --- a/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb +++ 
b/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb @@ -288,7 +288,9 @@ " mtq.set_quantizer_by_cfg(model, quant_cfg=kv_cfg)\n", "\n", " # Calibrate **only** those quantizers\n", - " with mtq.set_quantizer_by_cfg_context(model, {\"*\": {\"enable\": False}, **kv_cfg}):\n", + " with mtq.set_quantizer_by_cfg_context(\n", + " model, [{\"quantizer_path\": \"*\", \"enable\": False}, *kv_cfg]\n", + " ):\n", " mtq.calibrate(model, algorithm=\"max\", forward_loop=forward_loop)\n", "else:\n", " print(\"KV cache left unquantized.\")" @@ -427,4 +429,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py index 943515725..14d5a5c82 100644 --- a/examples/llm_qat/main.py +++ b/examples/llm_qat/main.py @@ -54,12 +54,20 @@ CUSTOM_QUANT_CFG = { "INT4_WEIGHT_INT8_ACTIVATIONS": { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*lm_head*": {"enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + {"quantizer_path": "*lm_head*", "enable": False}, + ], "algorithm": "max", } } diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe66..284aba8f7 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -155,7 +155,7 @@ def disable_compilation(model): } -def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]: +def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list: """Update KV cache quantization config for MLA models. 
MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate @@ -170,9 +170,18 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) if not any(isinstance(m, MLAAttention) for m in model.modules()): return kv_quant_cfg - if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): - kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config - kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config + kv_entry = next( + ( + e + for e in kv_quant_cfg + if isinstance(e, dict) and e.get("quantizer_path") == "*[kv]_bmm_quantizer" + ), + None, + ) + if kv_entry is not None: + kv_config = kv_entry.get("cfg", {}) + kv_quant_cfg.append({"quantizer_path": "*kv_c_bmm_quantizer", "cfg": kv_config}) + kv_quant_cfg.append({"quantizer_path": "*k_pe_bmm_quantizer", "cfg": kv_config}) print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config") return kv_quant_cfg diff --git a/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py b/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py index a861493b3..4c66de1d4 100644 --- a/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py +++ b/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py @@ -257,26 +257,18 @@ def build_quant_config( if exclude_blocks is None: exclude_blocks = [0, 1, 46, 47] - quant_cfg = { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + _nvfp4_cfg = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, } - - for pattern in SENSITIVE_LAYER_PATTERNS: - quant_cfg[pattern] = {"enable": False} - - for block_idx in exclude_blocks: - 
quant_cfg[f"*transformer_blocks.{block_idx}.*"] = {"enable": False} + quant_cfg = [ + ("*weight_quantizer", _nvfp4_cfg), + ("*input_quantizer", _nvfp4_cfg), + *[(pattern, {"enable": False}) for pattern in SENSITIVE_LAYER_PATTERNS], + *[(f"*transformer_blocks.{i}.*", {"enable": False}) for i in exclude_blocks], + ] return { "quant_cfg": quant_cfg, diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 61f551b63..a8fdcb98c 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -68,24 +68,47 @@ def get_quant_config(precision, lm_head_precision="fp16"): else: raise ValueError(f"Unsupported precision: {precision}") - config_dict = quant_cfg["quant_cfg"] # type: dict + quant_cfg_list: list = [ + e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_path" in e + ] if lm_head_precision == "fp8": - config_dict["*lm_head.input_quantizer"] = {"num_bits": (4, 3), "axis": None} - config_dict["*lm_head.weight_quantizer"] = {"num_bits": (4, 3), "axis": None} + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) elif lm_head_precision == "nvfp4": - config_dict["*lm_head.input_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - } - config_dict["*lm_head.weight_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - } + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) + 
quant_cfg_list.append( + { + "quantizer_path": "*lm_head.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) + quant_cfg["quant_cfg"] = quant_cfg_list return quant_cfg diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 14a12bcdf..3433fe5f7 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -218,7 +218,10 @@ def _output_hook(module, input, output): # Run dummy forward pass to collect modules sharing same input try: - with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + with ( + torch.no_grad(), + set_quantizer_by_cfg_context(model, [{"quantizer_path": "*", "enable": False}]), + ): dummy_forward_fn() finally: # Always remove hooks diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 339e9d0bb..c00b39f6a 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -62,9 +62,22 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): if isinstance(quantizer_attr_cfg, list): + if not quantizer_attr_cfg: + return 1.0 return min(estimate_quant_compression_for_quantizer(q) for q in quantizer_attr_cfg) if isinstance(quantizer_attr_cfg, dict): - return estimate_quant_compression_for_quantizer(list(quantizer_attr_cfg.values())) + # Handle raw quantizer cfg dicts (e.g. 
{"num_bits": (4, 3), "axis": None}) + if not quantizer_attr_cfg.get("enable", True): + return 1.0 + num_bits = quantizer_attr_cfg.get("num_bits") + if num_bits is None: + return 1.0 + if isinstance(num_bits, tuple): + return (sum(num_bits) + 1) / 16 + elif isinstance(num_bits, int): + return num_bits / 16 + else: + raise ValueError(f"Unknown quantization config {num_bits}") if isinstance(quantizer_attr_cfg, QuantizerAttributeConfig): if not quantizer_attr_cfg.enable: @@ -80,7 +93,9 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): raise ValueError(f"Unknown type {type(quantizer_attr_cfg)}, {quantizer_attr_cfg}") - return estimate_quant_compression_for_quantizer(list(quant_cfg.quant_cfg.values())) + cfgs = [e.get("cfg", {}) for e in quant_cfg.quant_cfg] + cfgs = [c for c in cfgs if c is not None] + return estimate_quant_compression_for_quantizer(cfgs) if cfgs else 1.0 class QuantRecipe(CustomHPType): @@ -97,7 +112,7 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No name = self.get_auto_name_for_config(quant_cfg) or name if quant_cfg is None: - quant_cfg = {"quant_cfg": {"*": {"enable": False}}} + quant_cfg = {"quant_cfg": [{"quantizer_path": "*", "enable": False}]} elif isinstance(quant_cfg, str): assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" quant_cfg = getattr(mtq_config, quant_cfg) @@ -109,9 +124,7 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig( - enable=False - ) + self.config.quant_cfg.append({"quantizer_path": "*output_quantizer", "enable": False}) self.compression = estimate_quant_compression(self.config) @@ -1299,29 +1312,36 
@@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): else: best_recipe = search_state["best"]["recipe"] - quant_cfg: dict[str, Any] = {"*": {"enable": False}} + def _cfg_to_dict(v): + if isinstance(v, mtq_config.QuantizerAttributeConfig): + return ( + { + "num_bits": v.num_bits, + **v.model_dump(exclude_defaults=True), + }, + ) + if isinstance(v, list): + return [_cfg_to_dict(c) for c in v] + return v + + quant_cfg: list[dict] = [{"quantizer_path": "*", "enable": False}] for hparam_name, recipe in best_recipe.items(): if recipe == QuantRecipe(quant_cfg=None): continue module_names = search_state["candidate_stats"][hparam_name]["module_names"] for module_name in module_names: for quantizer_attr in ("input_quantizer", "weight_quantizer"): - matched_cfg = _match_quantizer_cfg(recipe.config.quant_cfg, quantizer_attr) + matched_cfg, matched_enable = _match_quantizer_cfg( + recipe.config.quant_cfg, quantizer_attr + ) if matched_cfg is not None: - quant_cfg[f"{module_name}.{quantizer_attr}"] = matched_cfg - - def _cfg_to_dict(v): - if isinstance(v, mtq_config.QuantizerAttributeConfig): - return { - "enable": v.enable, - "num_bits": v.num_bits, - **v.model_dump(exclude_defaults=True), - } - if isinstance(v, list): - return [_cfg_to_dict(c) for c in v] - return v - - quant_cfg = {k: _cfg_to_dict(v) for k, v in quant_cfg.items()} + quant_cfg.append( + { + "quantizer_path": f"{module_name}.{quantizer_attr}", + "cfg": _cfg_to_dict(matched_cfg), + "enable": matched_enable, + } + ) warnings.warn( "get_auto_quantize_config: returned config uses algorithm='max'. " "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. 
" @@ -1363,7 +1383,13 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): def _match_quantizer_cfg(quant_cfg, quantizer_attr): # Last-match-wins to mirror set_quantizer_by_cfg behavior matched = None - for pattern, cfg in quant_cfg.items(): + matched_enable = False + for entry in quant_cfg: + pattern = entry["quantizer_path"] + cfg = entry.get("cfg", {}) + enable = entry.get("enable", True) if fnmatch.fnmatch(quantizer_attr, pattern): matched = cfg - return matched + matched_enable = enable + + return matched, matched_enable diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index cc5be9d56..a668b33b8 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -15,8 +15,6 @@ """This module provides a GEMM function for fp8 per tensor quantization.""" -from typing import Any - import torch from torch.autograd import Function @@ -99,9 +97,23 @@ def fp8_per_tensor_gemm(quant_module, input, bias=None): def _fp8_availability_check(module, input, args, kwargs): """Comprehensive check for FP8 GEMM availability.""" # Quantizer configs - quant_cfg: dict[str, Any] = FP8_DEFAULT_CFG["quant_cfg"] - input_cfg = quant_cfg["*input_quantizer"] - weight_cfg = quant_cfg["*weight_quantizer"] + quant_cfg_list = FP8_DEFAULT_CFG["quant_cfg"] + input_cfg = next( + e.get("cfg", {}) + for e in quant_cfg_list + if isinstance(e, dict) + and "quantizer_path" in e + and e["quantizer_path"] == "*input_quantizer" + ) + weight_cfg = next( + e.get("cfg", {}) + for e in quant_cfg_list + if isinstance(e, dict) + and "quantizer_path" in e + and e["quantizer_path"] == "*weight_quantizer" + ) + assert isinstance(input_cfg, dict) + assert isinstance(weight_cfg, dict) # Check hardware support if not torch.cuda.is_available() or not fp8_compatible(): diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py 
b/modelopt/torch/quantization/backends/nvfp4_gemm.py index ffc18fea3..e70d51ea1 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -15,8 +15,6 @@ """This module provides a GEMM function for nvfp4 quantization.""" -from typing import Any - import torch from torch.autograd import Function @@ -213,10 +211,24 @@ def _nvfp4_availability_check(module, input, args, kwargs): if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): return False - quant_cfg: dict[str, Any] = mtq.NVFP4_DEFAULT_CFG["quant_cfg"] + quant_cfg_list = mtq.NVFP4_DEFAULT_CFG["quant_cfg"] # Quantizer configs - input_cfg = quant_cfg["*input_quantizer"] - weight_cfg = quant_cfg["*weight_quantizer"] + input_cfg = next( + e.get("cfg", {}) + for e in quant_cfg_list + if isinstance(e, dict) + and "quantizer_path" in e + and e["quantizer_path"] == "*input_quantizer" + ) + weight_cfg = next( + e.get("cfg", {}) + for e in quant_cfg_list + if isinstance(e, dict) + and "quantizer_path" in e + and e["quantizer_path"] == "*weight_quantizer" + ) + assert isinstance(input_cfg, dict) + assert isinstance(weight_cfg, dict) # Check input quantizer config for key, value in input_cfg.items(): diff --git a/modelopt/torch/quantization/compress.py b/modelopt/torch/quantization/compress.py index 5477d0b61..2a5cbbee9 100644 --- a/modelopt/torch/quantization/compress.py +++ b/modelopt/torch/quantization/compress.py @@ -30,7 +30,7 @@ from .backends.gemm_registry import disable_real_quant_gemm, enable_real_quant_gemm from .config import CompressCfgType, CompressConfig -from .conversion import _replace_quant_module, set_quantizer_attribute +from .conversion import _replace_quant_module, set_quantizer_attributes_partial from .nn.modules.quant_linear import RealQuantLinear from .qtensor import QTensorWrapper, pack_real_quantize_weight from .utils import is_quantized_linear @@ -87,7 +87,7 @@ def compress_convert( compress_cfg = 
config.compress if "default" in compress_cfg and isinstance(compress_cfg["default"], bool): - set_quantizer_attribute( + set_quantizer_attributes_partial( model, "*weight_quantizer*", {"fake_quant": not compress_cfg["default"]} ) @@ -99,7 +99,7 @@ def compress_convert( def filter_func(name): return fnmatch.fnmatch(name, pattern) and "weight_quantizer" in name - set_quantizer_attribute(model, filter_func, {"fake_quant": not to_compress}) + set_quantizer_attributes_partial(model, filter_func, {"fake_quant": not to_compress}) else: raise ValueError( f"Invalid compression configuration: {to_compress}, expected a boolean as value." diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cf2336bf4..7968c56ba 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -50,40 +50,52 @@ Quantization Configs ================================ -Quantization config is dictionary specifying the values for keys ``"quant_cfg"`` and -``"algorithm"``. The ``"quant_cfg"`` key specifies the quantization configurations. The -``"algorithm"`` key specifies the ``algorithm`` argument to -:meth:`calibrate `. Please see :class:`QuantizeConfig` -for the quantization config definition. - -'Quantization configurations' is a dictionary mapping wildcards or filter functions -to its 'quantizer attributes'. The wildcards or filter functions are matched -against the quantizer module names. The quantizer modules have names ending with -``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and -input quantization (or activation quantization) respectively. The quantizer modules are generally -instances of -:class:`TensorQuantizer `. -The quantizer attributes are defined by :class:`QuantizerAttributeConfig`. See :class:`QuantizerAttributeConfig` -for details on the quantizer attributes and their values. 
- -The key `"default"` from the quantization configuration dictionary is applied if no other wildcard or filter functions -match the quantizer module name. - -The quantizer attributes are applied in the order they are specified. For the missing attributes, the default attributes -as defined by :class:`QuantizerAttributeConfig` are used. - -Quantizer attributes can also be a list of dictionaries. In this case, the matched quantizer module -is replaced with a -:class:`SequentialQuantizer ` -module which is used to quantize a tensor in multiple formats sequentially. Each quantizer attribute -dictionary in the list specifies the quantization formats for each quantization step of the -sequential quantizer. For example, `SequentialQuantizer` is used in 'INT4 Weights, FP8 Activations' -quantization in which the weights are quantized in INT4 followed by FP8. - -In addition, the dictionary entries could also be pytorch module class names mapping the class specific -quantization configurations. The pytorch modules should have a quantized equivalent. - -To get the string representation of a module class, do: +Quantization config is a dictionary with two top-level keys: + +- ``"quant_cfg"``: an ordered list of :class:`QuantizerCfgEntry` dicts that specify which + quantizers to configure and how. +- ``"algorithm"``: the calibration algorithm passed to + :meth:`calibrate `. + +Please see :class:`QuantizeConfig` for the full config schema. + +``quant_cfg`` — Entry Format +----------------------------- + +Each entry in the ``quant_cfg`` list is a :class:`QuantizerCfgEntry` with the following fields: + +- ``quantizer_path`` *(required)*: a wildcard string matched against quantizer module names. + Quantizer modules are instances of + :class:`TensorQuantizer ` + and have names ending with ``weight_quantizer``, ``input_quantizer``, etc. +- ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent module is + of this PyTorch class (e.g. ``"nn.Linear"``). 
If omitted, all matching quantizers are targeted + regardless of their parent class. +- ``cfg`` *(optional)*: a dict of quantizer attributes as defined by + :class:`QuantizerAttributeConfig`, or a list of such dicts. When a list is given, the matched + :class:`TensorQuantizer ` + is replaced with a + :class:`SequentialQuantizer ` + that applies each format in sequence. This is used for example in W4A8 quantization where weights + are quantized first in INT4 and then in FP8. +- ``enable`` *(optional)*: toggles matched quantizers on (``True``) or off (``False``), + independently of ``cfg``. When ``cfg`` is present and ``enable`` is absent, the quantizer is + implicitly enabled. When ``enable`` is the only field (no ``cfg``), it only flips the on/off + state — all other attributes remain unchanged. + +``quant_cfg`` — Ordering and Precedence +----------------------------------------- + +Entries are applied **in list order**; later entries override earlier ones for any quantizer they +match. The recommended pattern is: + +1. Start with a deny-all entry ``{"quantizer_path": "*", "enable": False}`` (provided as + :data:`_base_disable_all`) to disable every quantizer by default. +2. Follow with format-specific entries that selectively enable and configure the desired quantizers. +3. Append :data:`_default_disabled_quantizer_cfg` to enforce standard exclusions (e.g. BatchNorm + layers, LM head, MoE routers). + +To get the string representation of a module class for use in ``parent_class``, do: .. code-block:: @@ -97,15 +109,17 @@ .. 
code-block:: MY_QUANT_CFG = { - "quant_cfg": { - # Quantizer wildcard strings mapping to quantizer attributes - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, + "quant_cfg": [ + # Deny all quantizers by default + {"quantizer_path": "*", "enable": False}, - # Module class names mapping to quantizer configurations - "nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, + # Enable and configure weight and input quantizers + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - } + # Disable input quantizers specifically for LeakyReLU layers + {"quantizer_path": "*input_quantizer", "parent_class": "nn.LeakyReLU", "enable": False}, + ] } .. _example-quantization-configs: @@ -129,157 +143,219 @@ # Create custom config CUSTOM_INT4_AWQ_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG) - CUSTOM_INT4_AWQ_CFG["quant_cfg"]["*lm_head*"] = {"enable": False} + CUSTOM_INT4_AWQ_CFG["quant_cfg"].append({"quantizer_path": "*lm_head*", "enable": False}) # quantize model model = mtq.quantize(model, CUSTOM_INT4_AWQ_CFG, forward_loop) """ -from collections.abc import Callable -from typing import Literal +from typing import Any, Literal, cast from pydantic import ValidationInfo, field_validator, model_validator +from typing_extensions import TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.utils.network import ConstructorLike -_default_disabled_quantizer_cfg = { - "nn.BatchNorm1d": {"*": {"enable": False}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "nn.BatchNorm3d": {"*": {"enable": False}}, - "nn.LeakyReLU": {"*": {"enable": False}}, - "*lm_head*": {"enable": False}, - "*proj_out.*": {"enable": False}, # In Whisper model, lm_head has key name proj_out - "*block_sparse_moe.gate*": {"enable": False}, # Skip the MOE router - "*router*": {"enable": False}, # Skip the MOE router - 
"*mlp.gate.*": {"enable": False}, # Skip the MOE router - "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router - "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d - "*output_layer*": {"enable": False}, - "output.*": {"enable": False}, - "default": {"enable": False}, -} -_mamba_moe_disabled_quantizer_cfg = { - "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE - "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE - "*q_proj*": {"enable": False}, # Skip QKV Linear - "*k_proj*": {"enable": False}, # Skip QKV Linear - "*v_proj*": {"enable": False}, # Skip QKV Linear - "*o_proj*": {"enable": False}, # Skip QKV Output Projection -} +class QuantizerCfgEntry(TypedDict, total=False): + """A single entry in a ``quant_cfg`` list.""" + + quantizer_path: str # required; matched against quantizer module names + parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") + cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) + enable: bool | None # toggles matched quantizers on/off; independent of cfg + + +_base_disable_all: list[QuantizerCfgEntry] = [ + {"quantizer_path": "*", "enable": False}, +] + +_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ + {"parent_class": "nn.BatchNorm1d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.BatchNorm2d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.BatchNorm3d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.LeakyReLU", "quantizer_path": "*", "enable": False}, + {"quantizer_path": "*lm_head*", "enable": False}, + { + "quantizer_path": "*proj_out.*", + "enable": False, + }, # In Whisper model, lm_head has key name proj_out + { + "quantizer_path": "*block_sparse_moe.gate*", + "enable": False, + }, # Skip the MOE router + {"quantizer_path": "*router*", "enable": False}, # Skip the MOE router + {"quantizer_path": "*mlp.gate.*", "enable": 
False}, # Skip the MOE router + { + "quantizer_path": "*mlp.shared_expert_gate.*", + "enable": False, + }, # Skip the MOE router + {"quantizer_path": "*linear_attn.conv1d*", "enable": False}, + {"quantizer_path": "*mixer.conv1d*", "enable": False}, # Skip mamba conv1d + {"quantizer_path": "*output_layer*", "enable": False}, + {"quantizer_path": "output.*", "enable": False}, +] + +_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ + {"quantizer_path": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE + {"quantizer_path": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE + {"quantizer_path": "*q_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*k_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*v_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*o_proj*", "enable": False}, # Skip QKV Output Projection +] INT8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT8_SMOOTHQUANT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "smoothquant", } INT8_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + 
*_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } FP8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_FP8_AGGRESSIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_FP8_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_path": "*mixer.in_proj*", 
"enable": False}, # Skip mamba linear + {"quantizer_path": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], "algorithm": "max", } FP8_PER_CHANNEL_PER_TOKEN_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": 0}, - "*input_quantizer": { - "num_bits": (4, 3), - "type": "dynamic", - "block_sizes": {-1: None}, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "type": "dynamic", + "block_sizes": {-1: None}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } # FP8 2D blockwise fake quantization config for deepseek models FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (4, 3), - "block_sizes": {-1: 128, -2: 128}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 128, -2: 128}, + }, "enable": True, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 4, + "block_sizes": {-1: 128}, + }, "enable": True, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT4_AWQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": 
{ + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, "enable": True, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, @@ -288,137 +364,179 @@ # W4A8 currently uses INT4 blockwise quantization (block size = 128) followed by FP8 quantization # for weights. This could change in the future W4A8_AWQ_BETA_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - "enable": True, - }, - { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, + { + "num_bits": (4, 3), + }, + ], + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": { "num_bits": (4, 3), - "enable": True, }, - ], - "*input_quantizer": { - "num_bits": (4, 3), "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "awq_lite", } MXFP8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + 
*_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXFP6_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (3, 2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (3, 2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXFP4_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } W4A8_MXFP4_FP8_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - 
**_default_disabled_quantizer_cfg, - }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXINT8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } FP8_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - "num_bits": (4, 3), + "quant_cfg": [ + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, "enable": True, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } FP8_AFFINE_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - "num_bits": (4, 3), - "bias": {-2: None, -4: None, "type": "static"}, + "quant_cfg": [ + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + "bias": {-2: None, -4: None, "type": "static"}, + }, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } -_nvfp4_quantizer = { +_nvfp4_cfg = { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, } -_nvfp4_quantizer_bs32 = { +_nvfp4_cfg_bs32 = { "num_bits": (2, 1), "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, "enable": True, @@ -428,32 +546,37 @@ def _nvfp4_selective_quant_cfg( layer_patterns: list[str], *, - quantizer: dict = _nvfp4_quantizer, + quantizer: dict = _nvfp4_cfg, 
weight_only: bool = False, algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: dict[str, object] = {} + quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: - quant_cfg[f"{pattern}weight_quantizer"] = quantizer + quant_cfg.append({"quantizer_path": f"{pattern}weight_quantizer", "cfg": quantizer}) if not weight_only: - quant_cfg[f"{pattern}input_quantizer"] = quantizer - quant_cfg.update(_default_disabled_quantizer_cfg) + quant_cfg.append({"quantizer_path": f"{pattern}input_quantizer", "cfg": quantizer}) + quant_cfg.extend(_default_disabled_quantizer_cfg) return {"quant_cfg": quant_cfg, "algorithm": algorithm} NVFP4_DEFAULT_CFG = _nvfp4_selective_quant_cfg(["*"]) NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, "enable": True, }, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + *_default_disabled_quantizer_cfg, + ], "algorithm": { "method": "mse", "fp8_scale_sweep": True, @@ -461,15 +584,19 @@ def _nvfp4_selective_quant_cfg( } NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, "enable": True, }, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": 
"*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + *_default_disabled_quantizer_cfg, + ], "algorithm": { "method": "local_hessian", "fp8_scale_sweep": True, @@ -477,27 +604,28 @@ def _nvfp4_selective_quant_cfg( } MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": _nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": _nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_path": "*mixer.in_proj*", "enable": False}, # Skip mamba linear + {"quantizer_path": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], "algorithm": "max", } - NVFP4_AWQ_LITE_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm="awq_lite") NVFP4_AWQ_CLIP_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm={"method": "awq_clip"}) @@ -506,65 +634,85 @@ def _nvfp4_selective_quant_cfg( ["*"], algorithm={"method": "awq_full", "alpha_step": 0.1} ) - NVFP4_AFFINE_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - **_nvfp4_quantizer, - "bias": {-2: None, -4: None, "type": "static"}, + "quant_cfg": [ + { + "quantizer_path": 
"*[kv]_bmm_quantizer", + "cfg": { + **_nvfp4_cfg, + "bias": {-2: None, -4: None, "type": "static"}, + }, + "enable": True, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } NVFP4_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": _nvfp4_quantizer, - "default": {"enable": False}, - }, - "algorithm": "max", + "quant_cfg": [ + {"quantizer_path": "*[kv]_bmm_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + ] } # Moved from examples/diffusers/quantization/config.py to here NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": { - "*weight_quantizer": _nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - "*output_quantizer": {"enable": False}, - "*q_bmm_quantizer": { - "num_bits": (4, 3), + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + {"quantizer_path": "*output_quantizer", "enable": False}, + { + "quantizer_path": "*q_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*k_bmm_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*k_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*v_bmm_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*v_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*softmax_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*softmax_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "transformer_blocks*bmm2_output_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "transformer_blocks*bmm2_output_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } NVFP4_KV_ROTATE_CFG = { - "quant_cfg": { - "*q_bmm_quantizer": { + "quant_cfg": [ + { + "quantizer_path": "*q_bmm_quantizer", + "cfg": { + "rotate": True, + }, "enable": False, - "rotate": True, }, - "*k_bmm_quantizer": { - **_nvfp4_quantizer, - "rotate": True, + { + "quantizer_path": "*k_bmm_quantizer", + "cfg": { 
+ **_nvfp4_cfg, + "rotate": True, + }, + "enable": True, }, - "*v_bmm_quantizer": _nvfp4_quantizer, - }, - "algorithm": "max", + {"quantizer_path": "*v_bmm_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + ] } NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg( @@ -572,40 +720,54 @@ def _nvfp4_selective_quant_cfg( ) W4A8_NVFP4_FP8_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } MXFP4_MLP_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*mlp*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*mlp*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - "*block_sparse_moe*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_path": "*block_sparse_moe*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg( - ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_quantizer_bs32, weight_only=True + ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True ) NVFP4_EXPERTS_ONLY_CFG = 
_nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"]) NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"]) @@ -1346,23 +1508,106 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): ) -QuantizeQuantCfgType = dict[ - str | Callable, - QuantizerAttributeConfig - | list[QuantizerAttributeConfig] - | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], -] +QuantizeQuantCfgType = list[QuantizerCfgEntry] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None +def normalize_quant_cfg_list(v: list) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg list into a list of :class:`QuantizerCfgEntry` dicts. + + Supports the following input forms per entry: + + - New format: ``{"quantizer_path": ..., "enable": ..., "cfg": ...}`` — passed through. + - Legacy single-key format: ``{"": }`` — converted to new format. + - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted + to a new-format entry with ``parent_class`` set. + + **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither + ``cfg`` nor ``enable``. Concretely, the following are invalid: + + - An empty entry ``{}``. + - An entry with only ``quantizer_path`` and no other keys — the only effect would be an + implicit ``enable=True``, which must be stated explicitly. + + **Normalization** — after conversion and validation every entry is put into canonical form: + + - ``enable`` is set to ``True`` if not explicitly specified. + - ``cfg`` is set to ``None`` if not present in the entry. + + Every returned entry is therefore guaranteed to have the keys ``quantizer_path``, ``enable``, + and ``cfg`` (plus optionally ``parent_class``). + + Args: + v: A list of raw quant_cfg entries in any supported format. + + Returns: + A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. 
+ + Raises: + ValueError: If any entry has only ``quantizer_path`` with neither ``cfg`` nor ``enable``, + or if the entry format is not recognized. + """ + + def _dict_to_entry(key: str, value) -> QuantizerCfgEntry: + if isinstance(key, str) and key.startswith("nn."): + assert isinstance(value, dict) and len(value) == 1 + q_path, sub_cfg = next(iter(value.items())) + sub_cfg = dict(sub_cfg) + enable = sub_cfg.pop("enable", None) + entry: QuantizerCfgEntry = { + "parent_class": key, + "quantizer_path": q_path, + "cfg": sub_cfg, + } + if enable is not None: + entry["enable"] = enable + return entry + else: + if isinstance(value, dict): + cfg = {k: val for k, val in value.items() if k != "enable"} + enable = value.get("enable") + else: + cfg = value + enable = None + entry = {"quantizer_path": key, "cfg": cfg} + if enable is not None: + entry["enable"] = enable + return entry + + result: list[QuantizerCfgEntry] = [] + for raw in v: + if isinstance(raw, dict) and "quantizer_path" in raw: + entry: dict = dict(raw) # copy to avoid mutating caller's data + elif isinstance(raw, dict) and len(raw) == 1: + key, val = next(iter(raw.items())) + entry = dict(_dict_to_entry(key, val)) + else: + raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") + + # Validate: must carry at least one instruction beyond the path selector. + if "cfg" not in entry and "enable" not in entry: + raise ValueError( + f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " + "or both. An entry with only 'quantizer_path' has no effect (implicit " + "enable=True is not allowed; set it explicitly)." + ) + + # Normalize: make enable and cfg always explicit. 
+ entry.setdefault("enable", True) + entry.setdefault("cfg", None) + + result.append(cast("QuantizerCfgEntry", entry)) + return result + + class QuantizeConfig(ModeloptBaseConfig): """Default configuration for ``quantize`` mode.""" quant_cfg: QuantizeQuantCfgType = ModeloptField( - default={"default": {"num_bits": 8, "axis": None}}, + default=[{"quantizer_path": "*", "cfg": {"num_bits": 8, "axis": None}}], title="Quantization configuration", validate_default=True, ) @@ -1374,6 +1619,25 @@ class QuantizeConfig(ModeloptBaseConfig): validate_default=True, ) + @field_validator("quant_cfg", mode="before") + @classmethod + def normalize_quant_cfg(cls, v): + """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" + if not isinstance(v, list): + return v + return normalize_quant_cfg_list(v) + + @field_validator("quant_cfg", mode="after") + @classmethod + def validate_quant_cfg_entries(cls, v): + """Validate quantizer attribute configs to surface errors (e.g. invalid axis/block_sizes).""" + qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) + for entry in v: + cfg = entry.get("cfg", {}) + if isinstance(cfg, dict) and qac_fields & set(cfg.keys()): + QuantizerAttributeConfig.model_validate(cfg) + return v + class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1410,7 +1674,19 @@ def _not_dynamic(cfg): and cfg.get("*", {}).get("enable", True) ) - for name, cfg in config.get("quant_cfg", {}).items(): + quant_cfg: list = config.get("quant_cfg") or [] + for entry in quant_cfg: + name = ( + entry["quantizer_path"] + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[0] + ) + if isinstance(entry, dict) and "quantizer_path" in entry: + cfg = dict(entry.get("cfg") or {}) + if "enable" in entry: + cfg["enable"] = entry["enable"] + else: + cfg = entry[1] if "weight_quantizer" in name: # We don't calibrate weight quantizer continue @@ -1418,10 +1694,8 @@ def _not_dynamic(cfg): 
if isinstance(cfg, list): for _config in cfg: if _not_dynamic(_config): - print(f"{cfg}: True") return True - elif _not_dynamic(cfg): - print(f"{cfg}: True") + elif isinstance(cfg, dict) and _not_dynamic(cfg): return True return False diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 472252e1c..47552c663 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -19,7 +19,7 @@ import warnings from collections.abc import Callable from contextlib import contextmanager -from typing import Any +from typing import Any, cast import torch.nn as nn @@ -33,6 +33,7 @@ QuantizeQuantCfgType, QuantizerAttributeConfig, _QuantizeExportConfig, + normalize_quant_cfg_list, ) from .nn import ( NVFP4StaticQuantizer, @@ -47,7 +48,8 @@ __all__ = [ "register", "replace_quant_module", - "set_quantizer_attribute", + "set_quantizer_attributes_full", + "set_quantizer_attributes_partial", "set_quantizer_by_cfg", "set_quantizer_by_cfg_context", "unregister", @@ -60,7 +62,7 @@ def convert_to_quantized_model(model: ModelLikeModule, config: QuantizeConfig) - model = model.init_modellike() if isinstance(model, ModelLikeModule) else model replace_quant_module(model, version=ModeloptStateManager(model).state_version) - set_quantizer_by_cfg(model, config.get("quant_cfg", {})) + set_quantizer_by_cfg(model, config.get("quant_cfg", [])) metadata = {} update_quantize_metadata(model, config, metadata) @@ -76,7 +78,7 @@ def convert_to_quantized_model_svdquant( model = model.init_modellike() if isinstance(model, ModelLikeModule) else model create_and_replace_svdquant_linear_on_the_fly(model) - set_quantizer_by_cfg(model, config.get("quant_cfg", {})) + set_quantizer_by_cfg(model, config.get("quant_cfg", [])) metadata = {} update_quantize_metadata(model, config, metadata) @@ -211,116 +213,234 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe 
_replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType | dict): - """Update the quantizer attributes based on the specified `quant_cfg`. +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): + """Apply a quantization config list to the quantizers in ``quant_model``. - `quant_cfg` is a dictionary mapping wildcards or filter functions - to its quantizer attributes which are defined in - :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>`. - The wildcards or filter functions are matched against the quantizer module names. - The specified quantizer attributes of the matched quantizer modules are set accordingly. - The key ``"default"`` is a special key that sets the quantizer attributes of all the quantizers for which - no other wildcard or filter functions match the quantizer module name. + ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` + dicts. Each entry has the following fields: - In addition, the dictionary entries could also be pytorch module class names mapping the class specific - quantization configuration. The pytorch modules should have a quantized equivalent. + - ``quantizer_path`` *(required)*: wildcard matched against quantizer module names via + :func:`fnmatch`. + - ``cfg`` *(optional)*: a dict of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` + fields, or a list of such dicts for sequential quantization. + - ``enable`` *(optional)*: ``True`` or ``False`` to toggle matched quantizers on or off. + When omitted but ``cfg`` is present, defaults to ``True``. Every entry must specify at + least one of ``cfg`` or ``enable`` — an entry with only ``quantizer_path`` is invalid. + - ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent + module is of this PyTorch class name. 
- See :meth:`set_quantizer_attribute ` - for more details. + **Ordering and atomicity:** entries are applied in list order; later entries override earlier + ones for any quantizer they match. Each entry with a ``cfg`` is a **complete replacement** — + unspecified attributes revert to their defaults rather than inheriting from a prior entry. + The typical pattern is to deny all first (``{"quantizer_path": "*", "enable": False}``), then + selectively enable and configure target quantizers in subsequent entries. + + **``enable`` and ``cfg`` are independent:** + + - An entry with ``cfg`` (and optionally ``enable``) fully replaces the matched quantizer's + attributes. If ``enable`` is omitted, the quantizer is implicitly enabled. + - ``{"enable": False}`` without ``cfg`` **only** toggles the matched quantizers off, leaving + all other attributes unchanged. + - ``{"enable": True}`` without ``cfg`` **only** toggles the matched quantizers on, using + whatever attributes they currently have (or their defaults if never configured). + + See :ref:`quant-cfg` for the full format reference and common patterns. """ - quant_cfg = quant_cfg.copy() - if "default" in quant_cfg: - set_quantizer_attribute(quant_model, "*", quant_cfg["default"]) - quant_cfg.pop("default") - - for pattern, cfg in quant_cfg.items(): - if str(pattern) in QuantModuleRegistry: - parent_class = QuantModuleRegistry[str(pattern)] - assert isinstance(cfg, dict), ( - f"Expected a dictionary for quantizer configuration for child tensor quantizers of {parent_class}." 
+ quant_cfg = normalize_quant_cfg_list(quant_cfg) + + for entry in quant_cfg: + quantizer_path: str = entry["quantizer_path"] + cfg = entry["cfg"] # None, dict, or list — always explicit after normalization + enable: bool = entry["enable"] # always explicit after normalization + parent_class_name = entry.get("parent_class") + parent_class = QuantModuleRegistry[parent_class_name] if parent_class_name else None + + if not cfg: + # No cfg: only toggle the enable state, leave all other attributes unchanged. + set_quantizer_attributes_partial( + quant_model, quantizer_path, {"enable": enable}, parent_class ) - for sub_pattern, sub_cfg in cfg.items(): - set_quantizer_attribute(quant_model, sub_pattern, sub_cfg, parent_class) - continue - set_quantizer_attribute(quant_model, pattern, cfg) + else: + # Has cfg: apply full replacement with the explicit enable value. + if isinstance(cfg, dict): + attributes = QuantizerAttributeConfig(**cfg, enable=enable) + else: + attributes = [QuantizerAttributeConfig(**c, enable=enable) for c in cfg] + set_quantizer_attributes_full(quant_model, quantizer_path, attributes, parent_class) -def set_quantizer_attribute( +def _match_quantizer( + wildcard_or_filter_func: str | Callable, + name: str, + module: nn.Module, + parent_class: type[nn.Module] | None, + full_model: nn.Module, +): + if not isinstance(module, (TensorQuantizer, SequentialQuantizer)): + return False + if isinstance(wildcard_or_filter_func, str): + if not fnmatch.fnmatch(name, wildcard_or_filter_func): + return False + elif callable(wildcard_or_filter_func): + if not wildcard_or_filter_func(name): + return False + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + + return parent_class is None or isinstance( + full_model.get_submodule(".".join(name.split(".")[:-1])), parent_class + ) + + +def set_quantizer_attributes_full( quant_model: nn.Module, wildcard_or_filter_func: str | Callable, - attribute: QuantizerAttributeConfig - | 
list[QuantizerAttributeConfig] - | dict[ - str | Callable, - QuantizerAttributeConfig | list[QuantizerAttributeConfig], - ] - | dict - | list[dict], - parent_class: type | None = None, + attributes: QuantizerAttributeConfig | list[QuantizerAttributeConfig], + parent_class: type[nn.Module] | None = None, ): - """Finegrained adjustment of quantizer attribute by wildcard or filter function. + """Set quantizer attributes by wildcard or filter function, fully overwriting existing attributes. + + Unlike :func:`set_quantizer_attributes_partial`, this function requires a complete + :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` and **replaces** the + matched quantizer's attributes entirely rather than merging with existing ones. Args: - quant_model: A pytorch model - wildcard_or_filter_func: a wildcard string or a filter function. The wildcard string is matched - against the quantizer module names. The quantizer modules are - instances of + quant_model: A pytorch model. + wildcard_or_filter_func: A wildcard string or a filter function. The wildcard string is + matched against the quantizer module names. The quantizer modules are instances of :class:`TensorQuantizer `. - The filter function takes a quantized module name as input and returns ``True`` if the + The filter function takes a quantizer module name as input and returns ``True`` if the quantizer should be adjusted and ``False`` otherwise. - attribute: An instance of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` or an equivalent - dictionary or a list of these two types. - If ``attribute`` is a list, the matched + attributes: A :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` (or a + list of them) that **fully replaces** the matched quantizer's current attributes. All + fields of the config are applied — unspecified fields revert to their defaults. 
+ If ``attributes`` is a list, the matched :class:`TensorQuantizer ` - modules will be replaced with :class:`SequentialQuantizer ` - modules having one quantizer for each attribute instance from the list. + modules will be replaced with + :class:`SequentialQuantizer ` + modules having one quantizer per attribute instance in the list. See :meth:`set_from_attribute_config() ` - for more details on the supported attributes and their types. - parent_class: (Optional) The parent class of the quantizer modules matching ``wildcard_or_filter_func`` which - should be adjusted. If ``None``, all the matching quantizer modules will be adjusted. + for details on supported attributes and their types. + parent_class: (Optional) Restrict matching to quantizers whose immediate parent module is + an instance of this class. If ``None``, all quantizers matching + ``wildcard_or_filter_func`` are adjusted. """ + if not isinstance(attributes, (QuantizerAttributeConfig, list)): + raise ValueError( + f"Invalid type for attributes: {type(attributes)}, " + "expected QuantizerAttributeConfig or list of QuantizerAttributeConfig." + ) + if isinstance(attributes, list) and not all( + isinstance(attr, QuantizerAttributeConfig) for attr in attributes + ): + raise ValueError( + "All elements in attributes list must be of type QuantizerAttributeConfig." 
+ ) for name, module in quant_model.named_modules(): - if isinstance(module, (TensorQuantizer, SequentialQuantizer)): - if isinstance(wildcard_or_filter_func, str): - if not fnmatch.fnmatch(name, wildcard_or_filter_func): - continue - elif callable(wildcard_or_filter_func): - if not wildcard_or_filter_func(name): - continue + if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model): + if isinstance(attributes, list): + if not isinstance(module, SequentialQuantizer): + parent_module = quant_model.get_submodule(name.rpartition(".")[0]) + module = SequentialQuantizer( + *(TensorQuantizer() for _ in range(len(attributes))) + ) + setattr(parent_module, name.split(".")[-1], module) + elif len(attributes) != len(module): + warnings.warn( + f"The number of attributes ({len(attributes)}) does not match the number of " + f"quantizers of {module} leading to partial assignment.", + ) + module.set_from_attribute_config(attributes) else: - raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + cast("TensorQuantizer", module).set_from_attribute_config(attributes) - if parent_class is not None and not isinstance( - quant_model.get_submodule(".".join(name.split(".")[:-1])), parent_class - ): - continue - if isinstance(attribute, list) and not isinstance(module, SequentialQuantizer): - parent_module = quant_model.get_submodule(name.rpartition(".")[0]) - module = SequentialQuantizer(*(TensorQuantizer() for _ in range(len(attribute)))) - setattr(parent_module, name.split(".")[-1], module) - elif isinstance(attribute, list) and len(attribute) != len(module): - warnings.warn( - f"The number of attributes ({len(attribute)}) does not match the number of " - f"quantizers of {module} leading to partial assignment.", - ) - module.set_from_attribute_config(attribute) +def set_quantizer_attributes_partial( + quant_model: nn.Module, + wildcard_or_filter_func: str | Callable, + partial_attributes: dict[str, Any] | list[dict[str, Any]], + 
parent_class: type[nn.Module] | None = None, +): + """Update a subset of quantizer attributes by wildcard or filter function, merging with existing attributes. + + Unlike :func:`set_quantizer_attributes_full`, this function accepts an arbitrary subset of + quantizer attributes as a plain ``dict`` and **merges** them into the matched quantizer's + current attributes, leaving unspecified attributes unchanged. + + Args: + quant_model: A pytorch model. + wildcard_or_filter_func: A wildcard string or a filter function. The wildcard string is + matched against the quantizer module names. The quantizer modules are instances of + :class:`TensorQuantizer `. + The filter function takes a quantizer module name as input and returns ``True`` if the + quantizer should be adjusted and ``False`` otherwise. + partial_attributes: A ``dict`` (or a list of ``dict``) containing only the attributes to + update. Keys must be valid fields of + :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>`. Only the + specified keys are written; all other attributes on the quantizer remain unchanged. + When a ``dict`` is passed and the matched module is a + :class:`SequentialQuantizer `, + the dict is broadcast to every sub-quantizer. + When a ``list`` is passed, the matched module must already be a + :class:`SequentialQuantizer ` — + unlike :func:`set_quantizer_attributes_full`, this function will **not** replace a + :class:`TensorQuantizer ` with a + ``SequentialQuantizer``. + See + :meth:`set_from_attribute_config() ` + for details on supported attributes and their types. + parent_class: (Optional) Restrict matching to quantizers whose immediate parent module is + an instance of this class. If ``None``, all quantizers matching + ``wildcard_or_filter_func`` are adjusted. + """ + if not isinstance(partial_attributes, (dict, list)): + raise ValueError( + f"Invalid type for attributes: {type(partial_attributes)}, expected dictionary or list of dict." 
+ ) + if isinstance(partial_attributes, list) and not all( + isinstance(attr, dict) for attr in partial_attributes + ): + raise ValueError("All elements in attributes list must be of type dict.") + + for name, module in quant_model.named_modules(): + if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model): + module = cast("TensorQuantizer | SequentialQuantizer", module) # for type checker + if isinstance(partial_attributes, list): + if not isinstance(module, SequentialQuantizer): + raise ValueError( + f"Attributes is a list but {module} is not a SequentialQuantizer." + ) + module.set_from_attribute_config(partial_attributes) + elif isinstance(module, SequentialQuantizer): + # Broadcast the dict to all sub-quantizers. + module.set_from_attribute_config([partial_attributes] * len(module)) + else: + module.set_from_attribute_config(partial_attributes) @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType | dict): - """Context manager for setting quantizer attributes using `quant_cfg`. +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): + """Context manager that temporarily applies a quantization config and restores the original state on exit. + + Calls :func:`set_quantizer_by_cfg` on entry and reverts every + :class:`TensorQuantizer ` in + ``quant_model`` to its original attributes on exit. - The set attributes will be reset to the original attributes after exiting the context manager. - See :meth:`set_quantizer_by_cfg` for more details. + .. caution:: + Changing stateful attributes such as ``calibrator`` inside this context may produce + unexpected behavior because those objects are not deep-copied during save/restore. - Use this context manager with caution. Changing certain attributes of the quantizer such as - `calibrator` can lead to unexpected behavior. 
+ Args: + quant_model: A quantized PyTorch model whose quantizers will be temporarily reconfigured. + quant_cfg: A quantization config (or list of + :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` dicts) passed directly to + :func:`set_quantizer_by_cfg`. Sequential ``cfg`` lists are not allowed. + + Yields: + None — the context body runs with the new quantizer attributes active. """ - assert not any(cfg for cfg in quant_cfg.values() if isinstance(cfg, (list, tuple))), ( - "list of config not support." - ) + quant_cfg = normalize_quant_cfg_list(quant_cfg) original_attributes = {} for name, module in quant_model.named_modules(): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc..4616c82fc 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1101,7 +1101,9 @@ def forward(self, input, *args, **kwargs): self.awq_lite.num_cache_steps += 1 self.awq_lite.num_tokens += input.numel() / input.shape[-1] if self.awq_lite.is_input_quantized: - with set_quantizer_by_cfg_context(self.input_quantizer, {"*": {"enable": True}}): + with set_quantizer_by_cfg_context( + self.input_quantizer, [{"quantizer_path": "*", "enable": True}] + ): max_calibrate(self.input_quantizer, lambda quantizer: quantizer(input), False) return out_actual diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 4aa1ff46b..1d0314185 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -30,13 +30,15 @@ from modelopt.torch.opt.searcher import ForwardLoop from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.quantization.config import QuantizeConfig -from modelopt.torch.quantization.conversion import set_quantizer_by_cfg +from modelopt.torch.quantization.conversion import ( + set_quantizer_attributes_partial, + set_quantizer_by_cfg, +) from 
modelopt.torch.utils import atomic_print from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .algorithms import get_auto_quantize_config as _get_auto_quantize_config from .config import QuantizeAlgoCfgType -from .conversion import set_quantizer_attribute from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg from .nn import QuantModule, TensorQuantizer from .utils import is_quantized @@ -178,17 +180,15 @@ def quantize( .. code-block::python config = { - - "quant_cfg": { + "quant_cfg": [ + # Disable all quantizers by default + {"quantizer_path": "*", "enable": False}, # "num_bits" specifies the number of bits for quantization # "axis" specifies the axis for quantization - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": -1}, - - # Default quantization settings - "default": {"num_bits": 8, "axis": None}, - } - "algorithm": "max" + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": -1}}, + ], + "algorithm": "max", } See :ref:`Quantization Formats ` to learn more about the supported @@ -323,10 +323,13 @@ def auto_quantize( .. 
code-block:: python INT8_CUSTOM_QUANT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + ], "algorithm": "smoothquant", } @@ -527,7 +530,7 @@ def forward_backward_step(model, batch) -> None: "checkpoint": checkpoint, } # Disable all quantizers; AutoQuantize will enable the needed ones - set_quantizer_by_cfg(model, {"*": {"enable": False}}) + set_quantizer_by_cfg(model, [{"quantizer_path": "*", "enable": False}]) searcher.search(model, constraints, config=search_config) # type: ignore[arg-type] return model, searcher.state_dict() @@ -574,12 +577,12 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): def disable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): """Disable quantizer by wildcard or filter function.""" - set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": False}) + set_quantizer_attributes_partial(model, wildcard_or_filter_func, {"enable": False}) def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): """Enable quantizer by wildcard or filter function.""" - set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": True}) + set_quantizer_attributes_partial(model, wildcard_or_filter_func, {"enable": True}) @atomic_print diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ec2c3cfc5..3ff7401ec 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -203,8 +203,8 @@ def __init__( # Optional quantizer cache for caching quantizer related encoding or tensors. 
self._quantizer_cache = None - def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict): - """Set quantizer attributes from attribute_dict. + def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict[str, Any]): + """Set quantizer attributes from attribute_cfg. The attributes are defined in :class:`QuantizerAttributeConfig `. @@ -218,12 +218,27 @@ def _calibrator_setter(val): calib_cls, args, kwargs = standardize_constructor_args(val) return calib_cls(*args, **kwargs) + def _axis_setter(val): + if getattr(self, "_calibrator", None) is not None: + self._calibrator._axis = val + return val + + def _block_sizes_setter(val): + if val is not None: + # block_sizes and axis are mutually exclusive; clear axis when block_sizes is set + setattr(self, "_axis", None) + if getattr(self, "_calibrator", None) is not None: + self._calibrator._axis = None + return val + # Some attributes need custom handling. # By default, attributes from config are mapped to a name ``f"_{attribute}"`` _custom_setters: dict[str, tuple[str, Callable]] = { "enable": ("_disabled", lambda val: val is False), "type": ("_dynamic", lambda val: val == "dynamic"), "calibrator": ("_calibrator", _calibrator_setter), + "axis": ("_axis", _axis_setter), + "block_sizes": ("_block_sizes", _block_sizes_setter), "backend": ("backend", lambda val: val), "backend_extra_args": ("backend_extra_args", lambda val: val or {}), "use_constant_amax": ("_use_constant_amax", lambda val: val), @@ -1408,10 +1423,7 @@ def get_modelopt_state(self) -> dict[str, Any]: return {"num_quantizers": len(self), "is_sequential_quantizer": True} def set_from_attribute_config( - self, - attributes: list[dict[str, Any] | QuantizerAttributeConfig] - | dict[str, Any] - | QuantizerAttributeConfig, + self, attributes: list[QuantizerAttributeConfig] | list[dict[str, Any]] ): """Set the attributes of contained quantizers from a list of attribute_dicts.""" if not isinstance(attributes, (list, 
tuple)): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4340b8dc1..b9008a702 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -27,6 +27,7 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate +from modelopt.torch.quantization.config import QuantizerCfgEntry from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -310,11 +311,15 @@ def calibrate_with_adapters(model, args): def disable_lora_quantizers_in_config(config, layers): """Turns off input, weight, and output quantizers for LoRA weights and LoRALinear layers in config.""" - config["quant_cfg"]["*lora*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*lora*", "enable": False}) for layer in layers: - config["quant_cfg"][f"*{layer}.input_quantizer"] = {"enable": False} - config["quant_cfg"][f"*{layer}.weight_quantizer"] = {"enable": False} - config["quant_cfg"][f"*{layer}.output_quantizer"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": f"*{layer}.input_quantizer", "enable": False}) + config["quant_cfg"].append( + {"quantizer_path": f"*{layer}.weight_quantizer", "enable": False} + ) + config["quant_cfg"].append( + {"quantizer_path": f"*{layer}.output_quantizer", "enable": False} + ) return config @@ -823,13 +828,25 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): def update_quant_cfg_with_kv_cache_quant( - quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any] + quant_cfg: dict[str, Any], kv_cache_quant_cfg: list[QuantizerCfgEntry] ) -> dict[str, Any]: - """Update the quant_cfg with the kv cache quant_cfg.""" + """Update the quant_cfg with the kv cache quant_cfg. + + Args: + quant_cfg: The outer quantization config dict (with ``"quant_cfg"`` and ``"algorithm"`` keys). 
+ kv_cache_quant_cfg: A list of :class:`QuantizerCfgEntry + ` dicts for KV cache quantization, + typically ``some_kv_cfg["quant_cfg"]``. + + Returns: + A deep copy of ``quant_cfg`` with the KV cache entries appended to ``quant_cfg["quant_cfg"]``. + """ # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg") or {"default": {"enable": False}} - quant_cfg["quant_cfg"].update(kv_cache_quant_cfg) + inner: list[QuantizerCfgEntry] = quant_cfg.get("quant_cfg") or [ + {"quantizer_path": "*", "enable": False} + ] + quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg) # Set default algorithm for kv cache quantization if not provided. if not quant_cfg.get("algorithm"): diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index cdc2aed94..0255caf4e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -194,7 +194,7 @@ def set_sparse_attention_attribute( ): """Set sparse attention attributes for modules matching pattern. - Similar to quantization's set_quantizer_attribute. + Similar to quantization's set_quantizer_attributes_partial. 
Args: model: Model to configure diff --git a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml index 72630965b..1024a60c1 100644 --- a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml @@ -19,46 +19,49 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*input_quantizer': - num_bits: e4m3 - axis: - '*weight_quantizer': - num_bits: e4m3 - axis: - default: + - quantizer_path: '*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*input_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*weight_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*lm_head*': + - quantizer_path: '*lm_head*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*output_layer*': + - quantizer_path: '*output_layer*' enable: false - '*proj_out.*': + - quantizer_path: '*proj_out.*' enable: false - '*router*': + - quantizer_path: '*router*' enable: false - output.*: + - quantizer_path: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git 
a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml index 73e84b1bc..524fb6d97 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml @@ -19,54 +19,57 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*input_quantizer' enable: true - '*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': 
- enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml index fd502e2c3..33fee0e3e 100644 --- a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml @@ -19,68 +19,73 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*mlp*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*mlp*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*input_quantizer' enable: true - '*mlp*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*weight_quantizer' enable: true - '*block_sparse_moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*input_quantizer' enable: true - '*block_sparse_moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: 
false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml index 4a19f874a..29cb76bb5 100644 --- a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml @@ -19,82 +19,89 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*mlp*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*mlp*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*input_quantizer' enable: true - '*mlp*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*weight_quantizer' enable: true - '*block_sparse_moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: 
e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*input_quantizer' enable: true - '*block_sparse_moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*o_proj*weight_quantizer' enable: true - '*o_proj*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*o_proj*input_quantizer' enable: true - '*o_proj*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false 
- nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/tests/_test_utils/torch/export/utils.py b/tests/_test_utils/torch/export/utils.py index 8011eb72e..e0867bad7 100644 --- a/tests/_test_utils/torch/export/utils.py +++ b/tests/_test_utils/torch/export/utils.py @@ -85,162 +85,241 @@ def forward(self, x): # Quantization configs partial_fp8_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "default": {"num_bits": 8, "enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } partial_w4a8_config = { - "quant_cfg": { - "*.2.weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": (4, 3), "axis": None, "enable": True}, - ], - "*.2.input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"num_bits": 8, "enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3), "axis": None}, + ], + "enable": True, + }, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } partial_nvfp4_config = { - "quant_cfg": { - "*.1.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.1.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + 
}, "enable": True, }, - "*.1.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.1.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } partial_nvfp4_awq_config = { - "quant_cfg": { - "*.2.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.1.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.1.weight_quantizer", + "cfg": { + 
"num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": False, }, - "*.1.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.1.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": False, }, - "default": {"enable": False}, - }, + ], "algorithm": "awq_lite", } partial_int4_awq_config = { - "quant_cfg": { - "*.2.weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, "enable": True, }, - "*.2.input_quantizer": {"enable": False}, - "default": {"enable": False}, - }, + {"quantizer_path": "*.2.input_quantizer", "enable": False}, + ], "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, } partial_fp8_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } partial_int8_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - 
"*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } partial_nvfp4_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*[kv]_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } only_weight_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": 
{"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + ], "algorithm": "max", } only_input_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + ], "algorithm": "max", } only_output_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } diff --git a/tests/_test_utils/torch/quantization/onnx_export.py b/tests/_test_utils/torch/quantization/onnx_export.py index 5c74e656c..57ee92ad0 100644 --- a/tests/_test_utils/torch/quantization/onnx_export.py +++ b/tests/_test_utils/torch/quantization/onnx_export.py @@ -29,11 +29,11 @@ def onnx_export_tester(model, device, num_bits, per_channel_quantization, constant_folding, dtype): axis = 0 if 
per_channel_quantization else None config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": num_bits, "axis": axis}, - "*input_quantizer": {"num_bits": num_bits}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": num_bits, "axis": axis}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": num_bits}}, + ], "algorithm": "max", } diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index ae56dd299..03290dfab 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -47,7 +47,18 @@ def get_awq_config(algorithm="awq_lite", block_size=8): config = copy.deepcopy(mtq.INT4_AWQ_CFG) - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: block_size} + for entry in config["quant_cfg"]: + pat = ( + entry["quantizer_path"] + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[0] + ) + if pat == "*weight_quantizer": + if isinstance(entry, dict) and "quantizer_path" in entry: + entry.setdefault("cfg", {})["block_sizes"] = {-1: block_size} + else: + entry[1]["block_sizes"] = {-1: block_size} + break if "algorithm" not in config or not isinstance(config["algorithm"], dict): config["algorithm"] = {} diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py index 93d3e8ccb..430d7ddf6 100644 --- a/tests/gpu/torch/quantization/test_hadamard.py +++ b/tests/gpu/torch/quantization/test_hadamard.py @@ -77,7 +77,7 @@ def test_kv_rotate(rotate_fp32): model = nn.Sequential(SDPAAttention()) mtq.replace_quant_module(model) - set_quantizer_by_cfg(model, {"*": {"enable": False}}) + set_quantizer_by_cfg(model, [{"quantizer_path": "*", "enable": False}]) dummy_input = SDPAAttention.get_input(device="cuda") output_ref = model(dummy_input) if 
rotate_fp32: @@ -86,11 +86,9 @@ def test_kv_rotate(rotate_fp32): rotate = True with set_quantizer_by_cfg_context( model, - { - "*[qk]_bmm_quantizer": { - "rotate": rotate, - }, - }, + [ + {"quantizer_path": "*[qk]_bmm_quantizer", "cfg": {"rotate": rotate}}, + ], ): output_test = model(dummy_input) assert torch.allclose(output_ref, output_test, atol=0.05) @@ -98,11 +96,9 @@ def test_kv_rotate(rotate_fp32): # Test the rotation is actually applied by turning on only one of the query, key quantizers with set_quantizer_by_cfg_context( model, - { - "*k_bmm_quantizer": { - "rotate": rotate, - }, - }, + [ + {"quantizer_path": "*k_bmm_quantizer", "cfg": {"rotate": rotate}}, + ], ): output_test1 = model(dummy_input) assert not torch.allclose(output_ref, output_test1, atol=0.05) diff --git a/tests/gpu/torch/quantization/test_quant_rnn_cuda.py b/tests/gpu/torch/quantization/test_quant_rnn_cuda.py index be40de8e5..8a245336f 100644 --- a/tests/gpu/torch/quantization/test_quant_rnn_cuda.py +++ b/tests/gpu/torch/quantization/test_quant_rnn_cuda.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn -from modelopt.torch.quantization import set_quantizer_attribute +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry @@ -44,7 +44,7 @@ def test_no_quant_proj(original_cls, bidirectional, bias): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn((3, 2, 8), device="cuda") diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 3e9ff4256..984aa5b2b 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -29,20 
+29,26 @@ from modelopt.torch.quantization.extensions import get_cuda_ext_mx NVFP4_WEIGHT_ACT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - }, + ], "algorithm": { "method": "mse", "step_size": 0.25, @@ -52,17 +58,18 @@ } NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "enable": False, - }, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + ], "algorithm": { "method": "mse", "fp8_scale_sweep": True, @@ -123,7 +130,9 @@ def test_quantize(model_cls, config): if config == mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, -2: 8} + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, -2: 8} model = model_cls().cuda() calib_data = [model.get_input().cuda() for _ in range(8)] quantize_model_and_forward(model, config, calib_data) diff --git 
a/tests/gpu/torch/quantization/test_real_quantize_cuda.py b/tests/gpu/torch/quantization/test_real_quantize_cuda.py index 2c6512896..e94210ff7 100644 --- a/tests/gpu/torch/quantization/test_real_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_real_quantize_cuda.py @@ -47,10 +47,13 @@ def test_real_quantize(model_cls, config): # update config to fit test cases if config == mtq.INT4_AWQ_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = { - -1: 16, - "scale_bits": 8, - } + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = { + -1: 16, + "scale_bits": 8, + } + break if model_cls is SimpleConv or model_cls is SimpleConvLinear: pytest.skip( "INT4_AWQ_CFG requires even number of elements on last dimension for weights." @@ -101,10 +104,13 @@ def test_save_restore(model_cls, config): # update config to fit test cases if config == mtq.INT4_AWQ_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = { - -1: 16, - "scale_bits": 8, - } + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = { + -1: 16, + "scale_bits": 8, + } + break if model_cls is SimpleConv or model_cls is SimpleConvLinear: pytest.skip( "INT4_AWQ_CFG requires even number of elements on last dimension for weights." 
diff --git a/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py index b71eaeb21..cfa678b1a 100644 --- a/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py +++ b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py @@ -33,23 +33,32 @@ from modelopt.torch.utils.plugins import megatron_prefill NVFP4_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*output_layer*": {"enable": False}, # Note: only output_layer is disabled. - "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + { + "quantizer_path": "*output_layer*", + "enable": False, + }, # Note: only output_layer is disabled. 
+ ], "algorithm": "max", } diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_apex.py b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py index 1c9bf1ec6..144c05f6d 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_apex.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py @@ -84,15 +84,15 @@ def test_convert_apex_parallel_linear(distributed_setup_size_1): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = model_ref.get_dummy_input().cuda() out_1 = model_ref(x) out_2 = model_test(x) assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", {"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = RegularQuantModelForTP().cuda() model_ref.load_state_dict(model_test.state_dict()) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index d8ba6fbed..8075ddc13 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -82,15 +82,15 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = model_ref.get_dummy_input().cuda() out_1 = model_ref(x) out_2 = model_test(x) assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", 
{"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = RegularQuantModelForTP().cuda() model_ref.load_state_dict(model_test.state_dict(), strict=False) @@ -304,7 +304,7 @@ def _test_sharded_state_dict( ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. - config["quant_cfg"]["*output_layer*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*output_layer*", "enable": False}) if modelopt_version is not None: mto.conversion.__version__ = modelopt_version @@ -383,36 +383,44 @@ def _test_sharded_state_dict( mixed_precision_config = copy.deepcopy(mtq.W4A8_AWQ_BETA_CFG) -mixed_precision_config["quant_cfg"].update( - { - "*.1.*": {"enable": False}, - "*.2.*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.2.*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.3.*weight_quantizer.0": {"num_bits": 8, "axis": 0}, - "*.3.*weight_quantizer.1": {"enable": False}, - "*.3.*input_quantizer": {"num_bits": 8, "axis": None}, - } +mixed_precision_config["quant_cfg"].extend( + [ + {"quantizer_path": "*.1.*", "enable": False}, + {"quantizer_path": "*.2.*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.2.*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.3.*weight_quantizer.0", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*.3.*weight_quantizer.1", "enable": False}, + {"quantizer_path": "*.3.*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ] ) mixed_block_size_config = copy.deepcopy(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) -mixed_block_size_config["quant_cfg"].update( - { - "*.1.*": {"enable": False}, - 
"*.2.*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 64}, "enable": True}, - "*.2.*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.3.*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128, -2: 64}, "enable": True}, - "*.3.*input_quantizer": {"num_bits": 8, "axis": None}, - } +mixed_block_size_config["quant_cfg"].extend( + [ + {"quantizer_path": "*.1.*", "enable": False}, + { + "quantizer_path": "*.2.*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 64}}, + "enable": True, + }, + {"quantizer_path": "*.2.*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*.3.*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128, -2: 64}}, + "enable": True, + }, + {"quantizer_path": "*.3.*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ] ) # Combined NVFP4 GEMM + KV cache quantization config NVFP4_GEMM_KV_CFG = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) -NVFP4_GEMM_KV_CFG["quant_cfg"].update(mtq.NVFP4_KV_CFG["quant_cfg"]) +NVFP4_GEMM_KV_CFG["quant_cfg"].extend(mtq.NVFP4_KV_CFG["quant_cfg"]) # Combined FP8 GEMM + KV cache quantization config FP8_GEMM_KV_CFG = copy.deepcopy(mtq.FP8_DEFAULT_CFG) -FP8_GEMM_KV_CFG["quant_cfg"].update(mtq.FP8_KV_CFG["quant_cfg"]) +FP8_GEMM_KV_CFG["quant_cfg"].extend(mtq.FP8_KV_CFG["quant_cfg"]) @pytest.mark.parametrize( diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py index 288cc7519..348d89af2 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py @@ -73,7 +73,10 @@ def test_quantize(model_cls, config): if config == mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, -2: 8} + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == 
"*weight_quantizer": + entry["cfg"]["block_sizes"] = {-1: 8, -2: 8} + break model = model_cls().cuda() calib_data = [model.get_input().cuda() for _ in range(1)] quantize_model_and_forward(model, config, calib_data) diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index e52617861..251fc7fdc 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -15,6 +15,8 @@ """Unit tests for modelopt.recipe.loader and modelopt.recipe.loader.load_config.""" +import re + import pytest from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType @@ -164,11 +166,11 @@ def test_load_recipe_dir(tmp_path): (tmp_path / "recipe.yml").write_text( "metadata:\n recipe_type: ptq\n description: Dir test.\n" ) - (tmp_path / "ptq_cfg.yml").write_text("algorithm: max\nquant_cfg: {}\n") + (tmp_path / "ptq_cfg.yml").write_text("algorithm: max\nquant_cfg: []\n") recipe = load_recipe(tmp_path) assert recipe.recipe_type == RecipeType.PTQ assert recipe.description == "Dir test." 
- assert recipe.ptq_cfg == {"algorithm": "max", "quant_cfg": {}} + assert recipe.ptq_cfg == {"algorithm": "max", "quant_cfg": []} def test_load_recipe_dir_missing_recipe_raises(tmp_path): @@ -200,13 +202,49 @@ def test_load_recipe_dir_missing_ptq_cfg_raises(tmp_path): ], ) def test_general_ptq_yaml_matches_config_dicts(yaml_path, model_cfg_name, kv_cfg_name): - """Each general/ptq YAML's merged quant_cfg matches the corresponding config.py dicts.""" + """Each general/ptq YAML's quant_cfg list matches the merged Python config dicts.""" + import json + import modelopt.torch.quantization.config as qcfg + from modelopt.torch.quantization.config import normalize_quant_cfg_list model_cfg = getattr(qcfg, model_cfg_name) kv_cfg = getattr(qcfg, kv_cfg_name) yaml_data = load_config(yaml_path) - ptq = yaml_data["ptq_cfg"] - assert {**model_cfg["quant_cfg"], **kv_cfg["quant_cfg"]} == ptq["quant_cfg"] - assert model_cfg["algorithm"] == ptq["algorithm"] + def _normalize_fpx(val): + """Normalize FPx representations to a canonical ``[E, M]`` list. + + Python configs may use tuple form ``(E, M)`` or string alias ``"eEmM"``; + YAML always uses the string form. Both are converted to ``[E, M]`` so the + comparison is representation-agnostic. 
+ """ + if isinstance(val, str): + m = re.fullmatch(r"e(\d+)m(\d+)", val) + if m: + return [int(m.group(1)), int(m.group(2))] + if isinstance(val, tuple) and len(val) == 2 and all(isinstance(x, int) for x in val): + return list(val) + if isinstance(val, dict): + return {str(k): _normalize_fpx(v) for k, v in val.items()} + return val + + def _normalize_entries(raw_entries): + """Normalize a raw quant_cfg list to a canonical, JSON-serialisable form.""" + entries = normalize_quant_cfg_list(list(raw_entries)) + result = [] + for entry in entries: + e = {k: v for k, v in entry.items() if v is not None} + if "cfg" in e and e["cfg"] is not None: + e["cfg"] = _normalize_fpx(e["cfg"]) + result.append(e) + return result + + def _sort_key(entry): + return json.dumps(entry, sort_keys=True, default=str) + + python_entries = _normalize_entries(model_cfg["quant_cfg"] + kv_cfg["quant_cfg"]) + yaml_entries = _normalize_entries(yaml_data["ptq_cfg"]["quant_cfg"]) + + assert sorted(python_entries, key=_sort_key) == sorted(yaml_entries, key=_sort_key) + assert model_cfg["algorithm"] == yaml_data["ptq_cfg"]["algorithm"] diff --git a/tests/unit/torch/quantization/plugins/test_attention_quant.py b/tests/unit/torch/quantization/plugins/test_attention_quant.py index 9526f80ac..302e39496 100644 --- a/tests/unit/torch/quantization/plugins/test_attention_quant.py +++ b/tests/unit/torch/quantization/plugins/test_attention_quant.py @@ -61,10 +61,10 @@ def forward(self, hidden_states, **kwargs): kv_cache_config = { - "quant_cfg": { - "*[kv]_bmm_quantizer": {"num_bits": 4, "enable": True}, - "*softmax_quantizer": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*[kv]_bmm_quantizer", "cfg": {"num_bits": 4}, "enable": True}, + {"quantizer_path": "*softmax_quantizer", "enable": False}, + ], "algorithm": "max", } diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 33730409a..771feb31a 100644 --- 
a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -87,7 +87,7 @@ def test_convert_conv1d(): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = torch.randn(2, 3) out_1 = model_ref(x) @@ -95,8 +95,8 @@ def test_convert_conv1d(): assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", {"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = PytorchModel() model_ref.load_state_dict(model_test.state_dict()) @@ -136,7 +136,7 @@ def test_dbrx(): expertglu_ref.w1, ) - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = torch.randn(1, 4, 32) out_1 = model_ref(x) @@ -193,7 +193,21 @@ def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config): tiny_llama_dir = create_tiny_llama_dir(tmp_path) # update config to fit test cases if quant_config == mtq.INT4_AWQ_CFG: - quant_config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 16} + import copy + + quant_config = copy.deepcopy(quant_config) + for entry in quant_config["quant_cfg"]: + pat = ( + entry["quantizer_path"] + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[0] + ) + if pat == "*weight_quantizer": + if isinstance(entry, dict) and "quantizer_path" in entry: + entry.setdefault("cfg", {})["block_sizes"] = {-1: 16} + else: + entry[1]["block_sizes"] = {-1: 16} + break else: raise ValueError(f"Unsupported quant_config: {quant_config}") diff --git 
a/tests/unit/torch/quantization/plugins/test_peft.py b/tests/unit/torch/quantization/plugins/test_peft.py index c794c67bc..fda0e3bec 100644 --- a/tests/unit/torch/quantization/plugins/test_peft.py +++ b/tests/unit/torch/quantization/plugins/test_peft.py @@ -48,7 +48,7 @@ def test_convert_loralinear(): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) tf_output_tester(model_ref, model_test) diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index c0f049174..e619c7e7b 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -28,7 +28,7 @@ QuantRecipeHparam, estimate_quant_compression, ) -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg +from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg from modelopt.torch.utils.distributed import DistributedProcessGroup @@ -110,11 +110,12 @@ def test_quant_recipe_hparam(): # use this config to test custom quantization config INT8_CUSTOM_QUANT_TEST_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "smoothquant", } @@ -230,14 +231,22 @@ def test_auto_quantize_disabled_layers_no_poison(): INT4INT8_AWQ_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": None, "enable": True}, - ], - "*input_quantizer": 
{"num_bits": 8, "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": None}, + ], + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } @@ -480,7 +489,22 @@ def test_get_auto_quantize_config(method): # Use stored best recipe config = mtq.get_auto_quantize_config(search_state) assert "quant_cfg" in config - assert config["quant_cfg"]["*"] == {"enable": False} + assert isinstance(config["quant_cfg"], list) + assert any( + ( + entry["quantizer_path"] + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[0] + ) + == "*" + and ( + entry.get("enable") + if isinstance(entry, dict) and "quantizer_path" in entry + else entry[1].get("enable") + ) + is False + for entry in config["quant_cfg"] + ) assert config["algorithm"] == "max" # Re-solve with different constraints diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py index 9a9a81a61..26aa7144a 100644 --- a/tests/unit/torch/quantization/test_compute_quantization_mse.py +++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py @@ -22,10 +22,10 @@ from modelopt.torch.quantization.nn import TensorQuantizer INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], "algorithm": "max", } diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 
6ed0c918a..cc8077ef2 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,8 @@ """Test of quantization config validations.""" +import pytest + from modelopt.torch.quantization.config import ( FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, FP8_DEFAULT_CFG, @@ -23,6 +25,7 @@ NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, need_calibration, + normalize_quant_cfg_list, ) @@ -33,3 +36,94 @@ def test_need_calibration(): assert need_calibration(INT4_AWQ_CFG) assert need_calibration(W4A8_AWQ_BETA_CFG) assert need_calibration(NVFP4_DEFAULT_CFG) + + +class TestNormalizeQuantCfgList: + def test_new_format_passthrough(self): + """New-format entries are returned unchanged (only canonical defaults added).""" + raw = [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["enable"] is True # defaulted + + def test_new_format_enable_false(self): + """Explicit enable=False is preserved.""" + raw = [{"quantizer_path": "*", "enable": False}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is False + assert result[0]["cfg"] is None # defaulted + + def test_new_format_explicit_enable_true_no_cfg(self): + """Explicit enable=True with no cfg is valid and cfg defaults to None.""" + raw = [{"quantizer_path": "*", "enable": True}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is True + assert result[0]["cfg"] is None + + def test_legacy_single_key_dict(self): + """Legacy {'*path': {attrs}} is converted to new format.""" + raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert 
result[0]["enable"] is True # defaulted + + def test_legacy_single_key_dict_with_enable(self): + """Legacy {'*path': {'enable': False}} splits enable out from cfg.""" + raw = [{"*input_quantizer": {"enable": False}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*input_quantizer" + assert result[0]["enable"] is False + assert result[0]["cfg"] == {} + + def test_legacy_nn_class_scoped(self): + """Legacy {'nn.Linear': {'*': {attrs}}} is converted with parent_class.""" + raw = [{"nn.Linear": {"*": {"enable": False}}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["parent_class"] == "nn.Linear" + assert result[0]["quantizer_path"] == "*" + assert result[0]["enable"] is False + + def test_normalization_cfg_defaults_to_none(self): + """Entries without cfg get cfg=None after normalization.""" + raw = [{"quantizer_path": "*lm_head*", "enable": False}] + result = normalize_quant_cfg_list(raw) + assert "cfg" in result[0] + assert result[0]["cfg"] is None + + def test_normalization_enable_defaults_to_true(self): + """Entries with cfg but no enable get enable=True after normalization.""" + raw = [{"quantizer_path": "*", "cfg": {"num_bits": 4}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is True + + def test_empty_list(self): + """Empty list is returned unchanged.""" + assert normalize_quant_cfg_list([]) == [] + + def test_multiple_entries_order_preserved(self): + """The order of entries is preserved.""" + raw = [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}, + ] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*" + assert result[1]["quantizer_path"] == "*weight_quantizer" + + def test_error_on_quantizer_path_only(self): + """Entry with only quantizer_path and no cfg or enable is rejected.""" + with pytest.raises(ValueError, match="must specify 'cfg', 'enable'"): + 
normalize_quant_cfg_list([{"quantizer_path": "*"}]) + + def test_error_on_empty_dict(self): + """An empty dict entry is rejected.""" + with pytest.raises(ValueError): + normalize_quant_cfg_list([{}]) + + def test_error_on_multi_key_legacy_dict(self): + """A multi-key legacy dict (no quantizer_path) is rejected.""" + with pytest.raises(ValueError): + normalize_quant_cfg_list([{"*weight_quantizer": {}, "*input_quantizer": {}}]) diff --git a/tests/unit/torch/quantization/test_custom_backend.py b/tests/unit/torch/quantization/test_custom_backend.py index f42d6a5f9..1b9308559 100644 --- a/tests/unit/torch/quantization/test_custom_backend.py +++ b/tests/unit/torch/quantization/test_custom_backend.py @@ -42,16 +42,19 @@ def dummy_backend(inputs: torch.Tensor, tq) -> torch.Tensor: model = torch.nn.Linear(16, 16, bias=False) cfg = { - "quant_cfg": { - "*weight_quantizer": { + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 8, + "axis": None, + "backend": "dummy_backend", + "backend_extra_args": {"offset": 2.5}, + }, "enable": True, - "num_bits": 8, - "axis": None, - "backend": "dummy_backend", - "backend_extra_args": {"offset": 2.5}, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } @@ -88,10 +91,14 @@ def cached_backend(inputs: torch.Tensor, tq: TensorQuantizer) -> torch.Tensor: model = torch.nn.Linear(16, 16, bias=False) cfg = { - "quant_cfg": { - "*weight_quantizer": {"enable": True, "backend": "cached_backend"}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"backend": "cached_backend"}, + "enable": True, + }, + ], "algorithm": "max", } inputs = torch.randn(1, 16) diff --git a/tests/unit/torch/quantization/test_quant_activations.py b/tests/unit/torch/quantization/test_quant_activations.py index afc8decce..e27b85bb6 100644 --- 
a/tests/unit/torch/quantization/test_quant_activations.py +++ b/tests/unit/torch/quantization/test_quant_activations.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import set_quantizer_attributes_partial, tensor_quant from modelopt.torch.quantization.nn import QuantModuleRegistry @@ -42,7 +42,7 @@ def test_fake_quant_per_channel(self): negative_slope = 0.01 leaky_relu_object = nn.LeakyReLU(negative_slope=negative_slope) quant_leaky_relu_object = QuantModuleRegistry.convert(leaky_relu_object) - set_quantizer_attribute(quant_leaky_relu_object, lambda name: True, {"axis": (1)}) + set_quantizer_attributes_partial(quant_leaky_relu_object, lambda name: True, {"axis": (1)}) test_input = torch.randn(input_shape) quant_input = tensor_quant.fake_tensor_quant( diff --git a/tests/unit/torch/quantization/test_quant_batchnorm.py b/tests/unit/torch/quantization/test_quant_batchnorm.py index ee035dab1..c55b4b0b0 100644 --- a/tests/unit/torch/quantization/test_quant_batchnorm.py +++ b/tests/unit/torch/quantization/test_quant_batchnorm.py @@ -20,7 +20,8 @@ import torch.nn as nn import torch.nn.functional as F -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry NUM_CHANNELS = 3 @@ -90,7 +91,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape): def test_fake_quant_per_channel(self, original_cls, input_shape): batchnorm_object = original_cls(NUM_CHANNELS, affine=True) quant_batchnorm_object = QuantModuleRegistry.convert(batchnorm_object) - set_quantizer_attribute(quant_batchnorm_object, lambda name: True, {"axis": (1)}) + set_quantizer_attributes_partial(quant_batchnorm_object, lambda name: True, {"axis": 
(1)}) test_input = torch.randn(input_shape) reduce_dims = list(range(len(test_input.shape))) diff --git a/tests/unit/torch/quantization/test_quant_rnn.py b/tests/unit/torch/quantization/test_quant_rnn.py index 6f3d054c4..0ea6d755a 100644 --- a/tests/unit/torch/quantization/test_quant_rnn.py +++ b/tests/unit/torch/quantization/test_quant_rnn.py @@ -21,7 +21,8 @@ import torch import torch.nn as nn -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.nn.modules.quant_rnn import VFRNNForward @@ -52,7 +53,7 @@ def test_no_quant(self, original_cls, bidirectional, bias): quant_rnn_object = QuantModuleRegistry.convert(rnn_object) rnn_object.eval() rnn_object_original.eval() - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) assert torch.allclose( quant_rnn_object.weight_ih_l0, rnn_object_original.weight_ih_l0, atol=1e-6 @@ -86,7 +87,7 @@ def test_no_quant_packed_sequence(self, original_cls, bidirectional, bias): quant_rnn_object = QuantModuleRegistry.convert(rnn_object) rnn_object.eval() rnn_object_original.eval() - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) assert torch.allclose( quant_rnn_object.weight_ih_l0, rnn_object_original.weight_ih_l0, atol=1e-6 @@ -124,7 +125,7 @@ def test_no_quant_proj(self, original_cls, bidirectional, bias): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + 
set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn(INPUT_SHAPE) @@ -150,7 +151,7 @@ def test_no_quant_batch_first(self, original_cls, bidirectional): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn([INPUT_SHAPE[1], INPUT_SHAPE[0], INPUT_SHAPE[2]]) @@ -176,7 +177,7 @@ def test_fake_quant_per_tensor(self, original_cls, bidirectional): ) rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": None}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": None}) quant_rnn_object._disable_input_quantizers() for name, weight in rnn_object_original.named_parameters(): @@ -205,7 +206,7 @@ def test_fake_quant_per_channel(self, original_cls, bidirectional): rnn_object = original_cls(HIDDEN_SIZE, HIDDEN_SIZE, NUM_LAYERS, bidirectional=bidirectional) rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": (0)}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": (0)}) quant_rnn_object._disable_input_quantizers() for name, weight in rnn_object_original.named_parameters(): @@ -234,7 +235,7 @@ def test_input_quant_per_tensor(self, original_cls, bidirectional): HIDDEN_SIZE, HIDDEN_SIZE, NUM_LAYERS, bidirectional=bidirectional, bias=True ) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": None}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": None}) 
quant_rnn_object._disable_weight_quantizers() num_directions = 2 if bidirectional else 1 diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 641eafd2f..46f974a0c 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -32,41 +32,54 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.calib import MaxCalibrator +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.conversion import set_quantizer_attributes_full +from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + SequentialQuantizer, + TensorQuantizer, +) # A test config with double-quant (using `SequentialQuantizers`) WINT4INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": 0, "enable": True}, - ], - "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - }, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } # Test configs for per channel MSE calibration INT8_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], "algorithm": "mse", } STATIC_WEIGHT_DYNAMIC_ACTIVATION_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "axis": 0, + "quant_cfg": [ + {"quantizer_path": "*", "enable": 
False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, }, # Per-channel quantization - "*input_quantizer": { - "num_bits": 8, - "axis": (0, 1), - "type": "dynamic", + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": (0, 1), "type": "dynamic"}, }, # Dynamic per-token quantization - "default": {"enable": False}, - }, + ], "algorithm": "max", } @@ -77,14 +90,17 @@ def compute_amax(self): quant_cfg_custom_calib = { - "quant_cfg": { - "*": { - "num_bits": 4, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*", + "cfg": { + "num_bits": 4, + "axis": None, + "calibrator": (NewMaxCalibrator, (4, None, False)), + }, "enable": True, - "calibrator": (NewMaxCalibrator, (4, None, False)), } - }, + ], "algorithm": "max", } @@ -131,7 +147,9 @@ def test_save_restore(model_cls, quant_config): def test_quantize_invalid_cfg(): model = SimpleLinear() config_invalid = { - "quant_cfg": {"*": {"num_bits": 4, "axis": 0, "block_sizes": {-1: 128}}}, + "quant_cfg": [ + {"quantizer_path": "*", "cfg": {"num_bits": 4, "axis": 0, "block_sizes": {-1: 128}}} + ], "algorithm": "max", } with pytest.raises(ValidationError, match="axis must be None when block_sizes is not None."): @@ -170,12 +188,22 @@ def test_custom_calib_config(): def test_class_wise_config(): model = SimpleConvLinear() config = { - "quant_cfg": { - "nn.Linear": {"*": {"num_bits": 4, "axis": -1, "enable": True}}, - "nn.Conv2d": {"*": {"num_bits": 8, "enable": True}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "*output_quantizer": {"num_bits": 8, "enable": True}, - }, + "quant_cfg": [ + { + "parent_class": "nn.Linear", + "quantizer_path": "*", + "cfg": {"num_bits": 4, "axis": -1}, + "enable": True, + }, + { + "parent_class": "nn.Conv2d", + "quantizer_path": "*", + "cfg": {"num_bits": 8}, + "enable": True, + }, + {"parent_class": "nn.BatchNorm2d", "quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "cfg": {"num_bits": 8}, 
"enable": True}, + ], "algorithm": "max", } @@ -222,33 +250,28 @@ def test_static_weight_dynamic_activations(): def test_block_sizes_axis_model(): REF_QUANT_CFG = { # noqa: N806 - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "axis": 0, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None, "type": "dynamic"}, }, - "*input_quantizer": { - "num_bits": 8, - "axis": None, - "type": "dynamic", - }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } QUANT_CFG = { # noqa: N806 - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "block_sizes": {1: None}, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 8, "block_sizes": {1: None}}, }, - "*input_quantizer": { - "num_bits": 8, - "block_sizes": {0: None, 1: None}, - "type": "dynamic", + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "block_sizes": {0: None, 1: None}, "type": "dynamic"}, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } model_ref = SimpleLinear() @@ -283,3 +306,97 @@ def forward_loop(model): out2 = model(inputs) assert torch.allclose(out1, out2), "Re-quantization with same config should be idempotent" + + +class TestSetQuantizerAttributesFull: + """Tests for set_quantizer_attributes_full and its atomicity semantics.""" + + def _quantize(self, model): + return mtq.quantize(model, mtq.INT8_DEFAULT_CFG, lambda m: m(m.get_input())) + + def test_basic_full_replacement(self): + """set_quantizer_attributes_full replaces all attributes on matched quantizers.""" + model = self._quantize(SimpleLinear()) + attrs = QuantizerAttributeConfig(num_bits=4, axis=0) + set_quantizer_attributes_full(model, "*weight_quantizer", attrs) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): 
+ assert isinstance(module, TensorQuantizer) + assert module.num_bits == 4 + assert module.axis == 0 + + def test_atomicity_unset_fields_revert_to_defaults(self): + """A full replacement reverts unspecified fields to QuantizerAttributeConfig defaults.""" + model = self._quantize(SimpleLinear()) + # First configure with axis=0 (non-default) + set_quantizer_attributes_full( + model, "*weight_quantizer", QuantizerAttributeConfig(num_bits=8, axis=0) + ) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.axis == 0 + + # Now replace with only num_bits=4; axis should revert to default (None) + set_quantizer_attributes_full( + model, "*weight_quantizer", QuantizerAttributeConfig(num_bits=4) + ) + default_axis = QuantizerAttributeConfig().axis + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.num_bits == 4 + assert module.axis == default_axis + + def test_parent_class_filter(self): + """parent_class restricts which quantizers are affected.""" + model = self._quantize(SimpleConvLinear()) + # Only set num_bits=4 for quantizers inside nn.Linear modules + set_quantizer_attributes_full( + model, + "*weight_quantizer", + QuantizerAttributeConfig(num_bits=4), + parent_class=torch.nn.Linear, + ) + for name, module in model.named_modules(): + if not name.endswith("weight_quantizer"): + continue + parent_name = name.rpartition(".")[0] + parent = model.get_submodule(parent_name) + if isinstance(parent, torch.nn.Linear): + assert module.num_bits == 4 + else: + # Conv2d weight_quantizers should be unchanged (still 8-bit from INT8_DEFAULT_CFG) + assert module.num_bits == 8 + + def test_wildcard_no_match_is_noop(self): + """A wildcard that matches nothing silently does nothing.""" + model = self._quantize(SimpleLinear()) + # Record state before + bits_before = { + n: m.num_bits for n, m in model.named_modules() if isinstance(m, TensorQuantizer) + } + set_quantizer_attributes_full( + model, 
"*nonexistent_quantizer*", QuantizerAttributeConfig(num_bits=4) + ) + bits_after = { + n: m.num_bits for n, m in model.named_modules() if isinstance(m, TensorQuantizer) + } + assert bits_before == bits_after + + def test_invalid_attributes_type_raises(self): + """Passing a plain dict instead of QuantizerAttributeConfig raises ValueError.""" + model = self._quantize(SimpleLinear()) + with pytest.raises((ValueError, AttributeError)): + set_quantizer_attributes_full(model, "*weight_quantizer", {"num_bits": 4}) # type: ignore[arg-type] + + def test_list_attributes_creates_sequential_quantizer(self): + """A list of QuantizerAttributeConfig replaces TensorQuantizer with SequentialQuantizer.""" + model = self._quantize(SimpleLinear()) + attrs = [ + QuantizerAttributeConfig(num_bits=4, block_sizes={-1: 128}), + QuantizerAttributeConfig(num_bits=8, axis=0), + ] + set_quantizer_attributes_full(model, "*weight_quantizer", attrs) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert isinstance(module, SequentialQuantizer) + assert len(module) == 2 diff --git a/tests/unit/torch/quantization/test_quantize_replace.py b/tests/unit/torch/quantization/test_quantize_replace.py index 140da2b64..4b0f4edd2 100644 --- a/tests/unit/torch/quantization/test_quantize_replace.py +++ b/tests/unit/torch/quantization/test_quantize_replace.py @@ -47,7 +47,7 @@ def test_quantize_replace(model_cls): assert not isinstance(module, nn.Conv2d) or _is_quantized_linear_conv(module) assert not isinstance(module, nn.Linear) or _is_quantized_linear_conv(module) - mtq.set_quantizer_attribute(model_atq, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_atq, "*", {"enable": False}) out_ref = model_ref(dummy_input) out_atq = model_atq(dummy_input) diff --git a/tests/unit/torch/quantization/test_tensor_quant_cpu.py b/tests/unit/torch/quantization/test_tensor_quant_cpu.py index d5c6479cd..78a79bbcb 100644 --- 
a/tests/unit/torch/quantization/test_tensor_quant_cpu.py +++ b/tests/unit/torch/quantization/test_tensor_quant_cpu.py @@ -89,14 +89,18 @@ def test_num_bits(self): WINT4INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": 0, "enable": True}, - ], - "*input_quantizer": {"num_bits": 8, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + "enable": True, + }, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8}, "enable": True}, + ], "algorithm": "awq_full", } @@ -109,10 +113,14 @@ def test_set_quantizer_cxt(): state_dict = model.state_dict() output_ref = model(inputs) - mtq.set_quantizer_by_cfg(model, {"*output_quantizer": {"enable": True}}) + mtq.set_quantizer_by_cfg(model, [{"quantizer_path": "*output_quantizer", "enable": True}]) with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, "*output_quantizer": {"enable": True}} + model, + [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "enable": True}, + ], ): for name, module in model.named_modules(): if not isinstance(module, TensorQuantizer): @@ -123,7 +131,7 @@ def test_set_quantizer_cxt(): assert not module.is_enabled mtq.calibrate(model, "max", lambda model: model(inputs * 10)) - mtq.set_quantizer_by_cfg(model, {"*output_quantizer": {"enable": False}}) + mtq.set_quantizer_by_cfg(model, [{"quantizer_path": "*output_quantizer", "enable": False}]) output_test = model(inputs) assert torch.allclose(output_ref, output_test)