diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index d1e02acfba..f9f98b73d8 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -24,6 +24,7 @@ modelopt/torch/nas @NVIDIA/modelopt-torch-nas-prune-codeowners
modelopt/torch/opt @NVIDIA/modelopt-torch-opt-codeowners
modelopt/torch/peft @NVIDIA/modelopt-torch-peft-codeowners
modelopt/torch/prune @NVIDIA/modelopt-torch-nas-prune-codeowners
+modelopt/torch/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners
modelopt/torch/quantization @NVIDIA/modelopt-torch-quantization-codeowners
modelopt/torch/sparsity @NVIDIA/modelopt-torch-sparsity-codeowners
modelopt/torch/speculative @NVIDIA/modelopt-torch-speculative-codeowners
diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml
index f3f3908043..848b3d326d 100644
--- a/.github/workflows/example_tests.yml
+++ b/.github/workflows/example_tests.yml
@@ -125,14 +125,14 @@ jobs:
strategy: &nemo_strategy
fail-fast: false
matrix:
- example: [megatron_bridge]
+ example: [megatron_bridge, puzzletron]
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/nemo:26.02"
example: ${{ matrix.example }}
timeout_minutes: 30
- pip_install_extras: "[hf,dev-test]"
+ pip_install_extras: "[hf,puzzletron,dev-test]"
runner: linux-amd64-gpu-rtxpro6000-latest-1
nemo-non-pr:
@@ -144,7 +144,7 @@ jobs:
docker_image: "nvcr.io/nvidia/nemo:26.02"
example: ${{ matrix.example }}
timeout_minutes: 30
- pip_install_extras: "[hf,dev-test]"
+ pip_install_extras: "[hf,puzzletron,dev-test]"
runner: linux-amd64-gpu-rtxpro6000-latest-2
##### ONNX/TensorRT Example Tests #####
diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml
index 538e05e75f..542e948909 100644
--- a/.github/workflows/gpu_tests.yml
+++ b/.github/workflows/gpu_tests.yml
@@ -85,6 +85,8 @@ jobs:
- name: Setup environment variables
run: |
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV
+ - name: Install dependencies for mip
+ run: apt-get update && apt-get install -y libffi-dev
- name: Run gpu tests
run: pip install tox-current-env && tox -e cuda13-${{ matrix.example }} --current-env
gpu-tests-non-pr:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0a7821d42f..7810db7886 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -79,6 +79,7 @@ repos:
modelopt/onnx/quantization/ort_patching.py|
modelopt/torch/_deploy/utils/onnx_utils.py|
modelopt/torch/export/transformer_engine.py|
+ modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py|
modelopt/torch/quantization/export_onnx.py|
modelopt/torch/quantization/plugins/attention.py|
modelopt/torch/speculative/eagle/utils.py|
@@ -95,6 +96,7 @@ repos:
examples/llm_eval/modeling.py|
examples/llm_qat/main.py|
examples/llm_sparsity/weight_sparsity/finetune.py|
+ examples/puzzletron/evaluation/lm_eval_anymodel.py|
examples/specdec_bench/specdec_bench/models/specbench_medusa.py|
examples/speculative_decoding/main.py|
examples/speculative_decoding/medusa_utils.py|
diff --git a/examples/pruning/README.md b/examples/pruning/README.md
index 656d6315db..930e9c6d25 100644
--- a/examples/pruning/README.md
+++ b/examples/pruning/README.md
@@ -7,6 +7,7 @@ Pruning can involve removal (prune) of Linear and Conv layers; and Transformer a
This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model:
1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model.
+1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA using Mixed Integer Programming (MIP) based NAS search algorithm.
1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.
1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints.
diff --git a/examples/puzzletron/GPTOSS.md b/examples/puzzletron/GPTOSS.md
new file mode 100644
index 0000000000..7c160c8997
--- /dev/null
+++ b/examples/puzzletron/GPTOSS.md
@@ -0,0 +1,14 @@
+
+## GptOss
+
+With this release Puzzle algorithm supports only experts removal for `Gpt-Oss`.
+
+This model comes as a quantized checkpoint i.e. MoE experts matrices are quantized with _MXFP4_ format.
+In the pruning steps puzzle utilizes decompressed model (back to BF16) for statistics and scores computation.
+This means, during the conversion to puzzle format we decompress the model and store it as a BF16.
+Once the pruning is done i.e. experts to be removed are identified and the process is finished, user may want to get back the _MXFP4_ format of the checkpoint.
+To do so, there is an additional script, that takes the original and the pruned checkpoint and outputs pruned checkpoint in _MXFP4_ format.
+
+```bash
+python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_pruned_to_mxfp4 --student-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24
+```
diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md
new file mode 100644
index 0000000000..a7e3aedfc1
--- /dev/null
+++ b/examples/puzzletron/README.md
@@ -0,0 +1,284 @@
+# Puzzletron Algorithm Tutorial
+
+This tutorial demonstrates how to compress large language models using the puzzletron algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146).
+The goal of the algorithm it to find the most optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture.
+The supported modifications are:
+
+- `ffn_intermediate_size`: different FFN intermediate sizes
+- `attention op/noop`: complete removal of attention layers
+
+To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements.
+
+In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md).
+
+> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
+
+## Environment
+
+- Install Model-Optimizer in editable mode with the corresponding dependencies (run from the repo root):
+
+```bash
+pip install -e .[hf,puzzletron]
+pip install -r examples/puzzletron/requirements.txt
+```
+
+> **Note:** NeMo containers may ship `nvidia-lm-eval` which may conflict with `lm-eval` that is used for evaluation.
+> If so, run `pip uninstall nvidia-lm-eval -y` before installing requirements.
+
+- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use a single GPU.
+
+- To make use of [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), you need to accept the terms and conditions for the corresponding model and the dataset in the Huggingface Hub. Log in to the Huggingface Hub and enter your HF token.
+
+```bash
+hf auth login
+```
+
+## Compress the Model
+
+1. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).
+
+ dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB)
+
+ ```bash
+ python -m modelopt.torch.puzzletron.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2
+ ```
+
+2. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file.
+
+ - `puzzle_dir` indicates a new directory for saving the resulting model.
+ - `input_hf_model_path` indicates the local directory with the input model checkpoint.
+ - `dataset_path` indicates the directory with the dataset downloaded earlier.
+
+ **_NOTE:_**
+ How to choose `intermediate_size_list`?
+ The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by a 256) to avoid issues with tensor operations in subsequent steps.
+
+ Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with highest accuracy that also meet the specified requirements.
+
+ We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model.
+
+3. Run the puzzletron pipeline.
+
+ ```bash
+ torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress"
+ ```
+
+ This will save the full output to `log.txt` and display the following progress on screen:
+
+ ```bash
+ [2025-11-02 12:06:34][rank-0][main.py:71] Puzzletron Progress 1/8: starting puzzletron pipeline
+ [2025-11-02 12:06:45][rank-0][puzzletron_nas_plugin.py:123] Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu)
+ [2025-11-02 12:07:07][rank-0][puzzletron_nas_plugin.py:132] Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)
+ [2025-11-02 12:11:36][rank-0][puzzletron_nas_plugin.py:137] Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)
+ [2025-11-02 12:12:20][rank-0][puzzletron_nas_plugin.py:217] Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)
+ [2025-11-02 12:12:21][rank-0][puzzletron_nas_plugin.py:222] Puzzletron Progress 6/8: calculating one block scores (multi-gpu)
+ [2025-11-02 12:50:41][rank-0][puzzletron_nas_plugin.py:226] Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)
+ [2025-11-02 12:52:34][rank-0][main.py:115] Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)
+ ```
+
+ Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review:
+
+ ```bash
+ ...
+ block_0: attention gqa_4 ffn intermediate_14336
+ block_1: attention gqa_4 ffn intermediate_14336
+ block_2: attention gqa_4 ffn intermediate_14336
+ block_3: attention gqa_4 ffn intermediate_14336
+ block_4: attention gqa_4 ffn intermediate_14336
+ block_5: attention gqa_4 ffn intermediate_14336
+ block_6: attention gqa_4 ffn intermediate_14336
+ block_7: attention gqa_4 ffn intermediate_14336
+ block_8: attention gqa_4 ffn intermediate_14336
+ block_9: attention gqa_4 ffn intermediate_14336
+ block_10: attention gqa_4 ffn intermediate_14336
+ block_11: attention gqa_4 ffn intermediate_14336
+ block_12: attention gqa_4 ffn intermediate_14336
+ block_13: attention gqa_4 ffn intermediate_14336
+ block_14: attention gqa_4 ffn intermediate_14336
+ block_15: attention gqa_4 ffn intermediate_14336
+ block_16: attention gqa_4 ffn intermediate_14336
+ block_17: attention no_op ffn intermediate_14336
+ block_18: attention no_op ffn intermediate_14336
+ block_19: attention no_op ffn intermediate_14336
+ block_20: attention no_op ffn intermediate_14336
+ block_21: attention no_op ffn intermediate_14336
+ block_22: attention no_op ffn intermediate_14336
+ block_23: attention no_op ffn intermediate_14336
+ block_24: attention no_op ffn intermediate_14336
+ block_25: attention no_op ffn intermediate_14336
+ block_26: attention no_op ffn intermediate_14336
+ block_27: attention no_op ffn intermediate_14336
+ block_28: attention no_op ffn intermediate_14336
+ block_29: attention gqa_4 ffn intermediate_14336
+ block_30: attention gqa_4 ffn intermediate_14336
+ block_31: attention gqa_4 ffn intermediate_14336
+
+ [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32}
+ ...
+ ################################################################
+ validate_model_and_extract_token_probs(model_name='teacher')
+ ################################################################
+ ...
+ Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883}
+ ...
+ ################################################################
+ validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True)
+ ################################################################
+ ....
+ Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227}
+ ```
+
+ 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942).
+
+## Re-run MIP Search with different constraints
+
+If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag.
+This assumes pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed.
+
+For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`.
+
+```bash
+torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress"
+```
+
+This will generate the following network architecture (see `log.txt`):
+
+```bash
+block_0: attention gqa_4 ffn intermediate_14336
+block_1: attention gqa_4 ffn intermediate_14336
+block_2: attention gqa_4 ffn intermediate_14336
+block_3: attention gqa_4 ffn intermediate_14336
+block_4: attention gqa_4 ffn intermediate_14336
+block_5: attention gqa_4 ffn intermediate_14336
+block_6: attention gqa_4 ffn intermediate_14336
+block_7: attention gqa_4 ffn intermediate_14336
+block_8: attention gqa_4 ffn intermediate_14336
+block_9: attention gqa_4 ffn intermediate_14336
+block_10: attention gqa_4 ffn intermediate_14336
+block_11: attention gqa_4 ffn intermediate_14336
+block_12: attention gqa_4 ffn intermediate_14336
+block_13: attention gqa_4 ffn intermediate_14336
+block_14: attention gqa_4 ffn intermediate_14336
+block_15: attention gqa_4 ffn intermediate_14336
+block_16: attention gqa_4 ffn intermediate_14336
+block_17: attention gqa_4 ffn intermediate_14336
+block_18: attention no_op ffn intermediate_14336
+block_19: attention no_op ffn intermediate_14336
+block_20: attention no_op ffn intermediate_14336
+block_21: attention gqa_4 ffn intermediate_14336
+block_22: attention no_op ffn intermediate_14336
+block_23: attention no_op ffn intermediate_14336
+block_24: attention no_op ffn intermediate_14336
+block_25: attention gqa_4 ffn intermediate_14336
+block_26: attention gqa_4 ffn intermediate_14336
+block_27: attention gqa_4 ffn intermediate_14336
+block_28: attention gqa_4 ffn intermediate_14336
+block_29: attention gqa_4 ffn intermediate_14336
+block_30: attention gqa_4 ffn intermediate_14336
+block_31: attention gqa_4 ffn intermediate_14336
+
+[2025-11-02 12:50:42,024]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208}
+...
+################################################################
+validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True)
+################################################################
+Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992
+```
+
+On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details):
+
+```bash
+block_5: attention no_op ffn intermediate_11520
+block_6: attention no_op ffn intermediate_14336
+block_7: attention no_op ffn intermediate_8704
+block_8: attention no_op ffn intermediate_14336
+block_9: attention no_op ffn intermediate_3072
+block_10: attention no_op ffn intermediate_11520
+block_11: attention no_op ffn intermediate_11520
+block_12: attention no_op ffn intermediate_11520
+block_13: attention no_op ffn intermediate_11520
+block_14: attention no_op ffn intermediate_3072
+```
+
+### MIP Sweep Mode
+
+The **MIP sweep mode** lets you explore multiple memory compression rates in a single run and compare the accuracy-memory trade-offs.
+
+#### Quick Start
+
+1. Enable sweep in your config YAML (e.g., `llama-3_1-8B_pruneffn_memory.yaml`):
+
+ ```yaml
+ mip:
+ sweep:
+ enabled: true
+ memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
+ output_csv: ${puzzle_dir}/mip_sweep_results.csv
+ ```
+
+2. Run the sweep:
+
+ ```bash
+ torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress"
+ ```
+
+3. View results: The CSV file contains compression rates, memory usage, and accuracy metrics for each configuration.
+
+#### Example Results
+
+
+
+The plot shows how token accuracy changes with different compression rates. Higher compression (0.5 = 50% of original memory) reduces accuracy, while lower compression maintains accuracy closer to the teacher model.
+
+## Evaluation
+
+Evaluate AnyModel checkpoints using [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) directly.
+
+```bash
+python examples/puzzletron/evaluation/lm_eval_anymodel.py \
+ --model hf \
+ --model_args pretrained=path/to/checkpoint,dtype=bfloat16,parallelize=True \
+ --tasks mmlu \
+ --num_fewshot 5 \
+ --batch_size 4
+```
+
+For a quick smoke test, add `--limit 10`.
+
+> **Alternative:** For server-based evaluation via an OpenAI-compatible endpoint,
+> see [evaluation/nemo_evaluator_instructions.md](./evaluation/nemo_evaluator_instructions.md).
+
+## Inference Performance Benchmarking
+
+Now let's evaluate how much speedup we get with the compressed model in terms of throughput and latency.
+
+- Install [vLLM from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source).
+- Rearrange the model safetensors to be used for vLLM.
+
+```bash
+cd path/to/model
+mv subblocks_safetensors/* .
+sed -i 's+subblocks_safetensors/++g' model.safetensors.index.json
+```
+
+- Benchmark latency
+
+```bash
+vllm bench latency --model path/to/model --load-format safetensors --trust-remote-code
+```
+
+- Benchmark throughput
+
+```bash
+vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code
+```
+
+## Knowledge Distillation
+
+To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one.
+
+See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation.
+
+## Advanced Usage
+
+Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios.
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
new file mode 100644
index 0000000000..b48f1de78c
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
@@ -0,0 +1,110 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: gpt_oss
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+ runtime_stats:
+ backend: trt_torch
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 45_000
+ num_params: 3_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml
new file mode 100644
index 0000000000..8ed06e9568
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - gptoss-20b
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 16_000 # 45 GiB
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..258e6c38a3
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,21 @@
+defaults:
+ - pruning_defaults
+
+eval_samples: 2500 #10
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor
+ target_name: "mlp.router"
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook}
+activation_hooks_kwargs: # Additional kwargs to pass to the hook init
+
+num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128
+mlp_init_mode: "ExpertRemoval"
+mlp_init_config_yaml:
+ expert_scores_key: "expert_ranks"
+ layer_prefix_template: "model.layers.{layer_idx}.mlp.router"
+
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..0eff799d7e
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - /validate_model_defaults
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+descriptor: ${descriptor}
+
+# Data:
+eval_samples: 10_000
+micro_batch_size: 1
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate" # PruneByActivationsLog
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..b80faea5f5
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
@@ -0,0 +1,18 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
+
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ab8c892182
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,11 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
+
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
new file mode 100644
index 0000000000..21903db162
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
@@ -0,0 +1,109 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: llama
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # ppath to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+ runtime_stats:
+ backend: trt_torch
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 78_000
+ num_params: 7_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml
new file mode 100644
index 0000000000..ad16dbc5ea
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml
@@ -0,0 +1,26 @@
+defaults:
+ - Llama-3_1-8B
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for puzzletron outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 78_000 # 78 GiB
+ # Memory sweep configuration (optional)
+ sweep:
+ enabled: false
+ memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9]
+ output_csv: ${puzzle_dir}/mip_sweep_results.csv
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+pruning:
+ intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..01886607e4
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# n_heads_in_group: 4
+# num_attention_heads: 32 # num query heads
+# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..da0b972070
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - pruning_defaults
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..407c835d8c
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..e05e775bee
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,33 @@
+defaults:
+ - /validate_model_defaults
+
+descriptor: ${descriptor}
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate" # PruneByActivationsLog
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..ce1749d969
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
@@ -0,0 +1,17 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
new file mode 100644
index 0000000000..7de281e788
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
@@ -0,0 +1,110 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: llama
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+ runtime_stats:
+ backend: trt_torch
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 45_000
+ num_params: 3_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml
new file mode 100644
index 0000000000..b5303d318a
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - Llama-3_2-3B
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.2-3B-Instruct
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 45_000 # 45 GiB
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 8192, so we use proportionally smaller values
+pruning:
+ intermediate_size_list: [2048, 4096, 6144]
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..05de8bfdcc
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,21 @@
+defaults:
+ - pruning_defaults
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+# Llama-3.2-3B has intermediate_size=8192, so we use proportionally smaller pruning sizes
+intermediate_size_list: [2048, 4096, 6144]
+mlp_init_mode: "PruneByActivationsLog"
+
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..e05e775bee
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,33 @@
+defaults:
+ - /validate_model_defaults
+
+descriptor: ${descriptor}
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate" # PruneByActivationsLog
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..b80faea5f5
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
@@ -0,0 +1,18 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
+
diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ab8c892182
--- /dev/null
+++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,11 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
+
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
new file mode 100644
index 0000000000..18213f9b7a
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
@@ -0,0 +1,109 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: mistral_small
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+ runtime_stats:
+ backend: trt_torch
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 78_000
+ num_params: 24_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml
new file mode 100644
index 0000000000..68a0652d6f
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml
@@ -0,0 +1,21 @@
+defaults:
+ - Mistral-Small-24B
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/mistralai/Mistral-Small-24B-Instruct-2501
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 234_000 # 234 GiB
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+pruning:
+ intermediate_size_list: [8192, 16384, 24576] # teacher_intermediate_size is 32768
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..cb24e1bc24
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# Mistral Small 24B: 32 query heads, 8 KV heads
+# n_heads_in_group = num_query_heads / num_kv_heads
+# num_kv_heads = num_query_heads / n_heads_in_group
+# Base: n_heads_in_group = 4, num_kv_heads = 8
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..5fb7fcbdd2
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 32768
+intermediate_size_list: [8192, 16384, 24576]
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..7de32621e0
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+# Mistral Small 24B: hidden_size is 5120
+hidden_size_list: [3072, 4096] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..e05e775bee
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,33 @@
+defaults:
+ - /validate_model_defaults
+
+descriptor: ${descriptor}
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate" # PruneByActivationsLog
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..ce1749d969
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
@@ -0,0 +1,17 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
new file mode 100644
index 0000000000..62b6ecb4cb
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
@@ -0,0 +1,109 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: nemotron_h_v2
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ runtime_stats:
+ backend: trt_torch
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 90_000
+ num_params: 12_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml
new file mode 100644
index 0000000000..3b880b2c7d
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - nemotron_nano_12b_v2
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/nvidia/Nemotron-Nano-12B-v2
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 90_000 # 90 GiB
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 20480
+pruning:
+ intermediate_size_list: [4352, 8448, 12544, 16384]
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..01886607e4
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# n_heads_in_group: 4
+# num_attention_heads: 32 # num query heads
+# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..1e2ecf07a0
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mixer.down_proj"
+ layer_input_descriptors_path:
+
+intermediate_size_list: [256] # teacher_intermediate_size is 14336
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..407c835d8c
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..8816eecc4a
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - /validate_model_defaults
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+descriptor: ${descriptor}
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate"
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
new file mode 100644
index 0000000000..ce1749d969
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
@@ -0,0 +1,17 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..3f7a248ee7
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# n_heads_in_group: 4
+# num_attention_heads: 32 # num query heads
+# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..18d7e234ac
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+intermediate_size_list: [256] # teacher_intermediate_size is 14336
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..af8af990b7
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..8816eecc4a
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - /validate_model_defaults
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+descriptor: ${descriptor}
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate"
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
new file mode 100644
index 0000000000..aa11499a3c
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
@@ -0,0 +1,109 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: qwen2
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ runtime_stats:
+ backend: trt_torch
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 78_000
+ num_params: 7_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml
new file mode 100644
index 0000000000..fb961033bc
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - qwen2_5_7b_instruct
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/Qwen/Qwen2.5-7B-Instruct
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 78_000 # 78 GiB
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 18944
+pruning:
+ intermediate_size_list: [4096, 7808, 11520, 15104]
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..ce1749d969
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
@@ -0,0 +1,17 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..01886607e4
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# n_heads_in_group: 4
+# num_attention_heads: 32 # num query heads
+# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..93590d13e5
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+intermediate_size_list: [256] # teacher_intermediate_size is 14336
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..407c835d8c
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..8816eecc4a
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - /validate_model_defaults
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+descriptor: ${descriptor}
+
+# Data:
+eval_samples: 1000 # default is 10000
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate"
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
new file mode 100644
index 0000000000..eec82a7d63
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
@@ -0,0 +1,109 @@
+defaults:
+ - pruning: ffn_pruning
+ - scoring: ../validate_solutions_defaults
+ - realize_model: ../validate_solutions_defaults
+ - bypass:
+ - override hydra/hydra_logging: disabled
+ - _self_
+
+puzzle_dir: ???
+descriptor: qwen3
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ runtime_stats:
+ backend: trt_torch
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 78_000
+ num_params: 8_000_000_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 128
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml
new file mode 100644
index 0000000000..4ee81286dd
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - qwen3_8b
+ - _self_
+
+# Input Hugging Face model to compress
+input_hf_model_path: /workspace/hf_models/Qwen/Qwen3-8B
+
+# Dataset path for pruning and NAS scoring
+dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2
+
+# Working directory for compression outputs
+puzzle_dir: /workspace/puzzle_dir
+
+# MIP memory constraint (in MiB)
+mip:
+ human_constraints:
+ target_memory: 78_000 # 78 GiB
+
+# FFN intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 12288
+pruning:
+ intermediate_size_list: [2560, 5120, 7424, 9984]
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
new file mode 100644
index 0000000000..ce1749d969
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
@@ -0,0 +1,17 @@
+model_dtype: torch.bfloat16 # dtype to cast the model for validate_model
+autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model
+block_size: 8192
+bos_rate: 0.5
+data_column: messages
+val_dataset_name: valid
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/examples/puzzletron/evaluation/hf_deployable_anymodel.py b/examples/puzzletron/evaluation/hf_deployable_anymodel.py
new file mode 100644
index 0000000000..3ca8dd7581
--- /dev/null
+++ b/examples/puzzletron/evaluation/hf_deployable_anymodel.py
@@ -0,0 +1,724 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+import json
+import logging
+from typing import Any
+
+import numpy as np
+import torch
+from nemo_deploy import ITritonDeployable
+from nemo_deploy.utils import broadcast_list, cast_output, str_ndarray2list
+from nemo_export_deploy_common.import_utils import (
+ MISSING_TRITON_MSG,
+ UnavailableError,
+ null_decorator,
+)
+from peft import PeftModel
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
+ resolve_descriptor_from_pretrained,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
+try:
+ from pytriton.decorators import batch
+ from pytriton.model_config import Tensor
+
+ HAVE_TRITON = True
+except (ImportError, ModuleNotFoundError):
+ from unittest.mock import MagicMock
+
+ HAVE_TRITON = False
+ batch = MagicMock()
+ Tensor = MagicMock()
+ batch = null_decorator
+
+
+LOGGER = logging.getLogger("NeMo")
+
+SUPPORTED_TASKS = ["text-generation"]
+
+
+class HuggingFaceLLMDeploy(ITritonDeployable):
+ """A Triton inference server compatible wrapper for HuggingFace models.
+
+ This class provides a standardized interface for deploying HuggingFace models
+ in Triton inference server. It supports various NLP tasks and handles model
+ loading, inference, and deployment configurations.
+
+ Args:
+ hf_model_id_path (Optional[str]): Path to the HuggingFace model or model identifier.
+ Can be a local path or a model ID from HuggingFace Hub.
+ hf_peft_model_id_path (Optional[str]): Path to the PEFT model or model identifier.
+ Can be a local path or a model ID from HuggingFace Hub.
+ tokenizer_id_path (Optional[str]): Path to the tokenizer or tokenizer identifier.
+ If None, will use the same path as hf_model_id_path.
+ model (Optional[AutoModel]): Pre-loaded HuggingFace model.
+ tokenizer (Optional[AutoTokenizer]): Pre-loaded HuggingFace tokenizer.
+ tokenizer_padding (bool): Whether to enable padding in tokenizer. Defaults to True.
+ tokenizer_truncation (bool): Whether to enable truncation in tokenizer. Defaults to True.
+ tokenizer_padding_side (str): Which side to pad on ('left' or 'right'). Defaults to 'left'.
+ task (str): HuggingFace task type (e.g., "text-generation"). Defaults to "text-generation".
+ **hf_kwargs: Additional keyword arguments to pass to HuggingFace model loading.
+ """
+
+ def __init__(
+ self,
+ hf_model_id_path: str | None = None,
+ hf_peft_model_id_path: str | None = None,
+ tokenizer_id_path: str | None = None,
+ model: AutoModel | None = None,
+ tokenizer: AutoTokenizer | None = None,
+ tokenizer_padding=True,
+ tokenizer_truncation=True,
+ tokenizer_padding_side="left",
+ task: str | None = "text-generation",
+ torch_dtype: torch.dtype | None = "auto",
+ device_map: str | None = "auto",
+ **hf_kwargs,
+ ):
+ if not HAVE_TRITON:
+ raise UnavailableError(MISSING_TRITON_MSG)
+
+ if hf_model_id_path is None and model is None:
+ raise ValueError("hf_model_id_path or model parameters has to be passed.")
+ elif hf_model_id_path is not None and model is not None:
+ LOGGER.warning(
+ "hf_model_id_path will be ignored and the HuggingFace model set with model parameter will be used."
+ )
+
+ assert task in SUPPORTED_TASKS, "Task {} is not a support task.".format(task)
+
+ self.hf_model_id_path = hf_model_id_path
+ self.hf_peft_model_id_path = hf_peft_model_id_path
+ self.task = task
+ self.model = model
+ self.tokenizer = tokenizer
+ self.tokenizer_padding = tokenizer_padding
+ self.tokenizer_truncation = tokenizer_truncation
+ self.tokenizer_padding_side = tokenizer_padding_side
+
+ if tokenizer_id_path is None:
+ self.tokenizer_id_path = hf_model_id_path
+ else:
+ self.tokenizer_id_path = tokenizer_id_path
+
+ if model is None:
+ self._load(torch_dtype=torch_dtype, device_map=device_map, **hf_kwargs)
+
+ def _load(
+ self, torch_dtype: torch.dtype | None = "auto", device_map: str | None = "auto", **hf_kwargs
+ ) -> None:
+ """Load the HuggingFace pipeline with the specified model and task.
+
+ This method initializes the HuggingFace AutoModel classes using the provided model
+ configuration and task type. It handles the model and tokenizer loading
+ process.
+
+ Args:
+ torch_dtype (torch.dtype): Data type for the model. Defaults to "auto".
+ device_map (str): Device map for the model. Defaults to "auto".
+ **hf_kwargs: Additional keyword arguments to pass to the HuggingFace model loading.
+
+ Raises:
+ AssertionError: If task is not specified.
+ """
+ assert self.task is not None, "A task has to be given for the generation task."
+
+ if self.task == "text-generation":
+ # =========================================================================
+ # BEGIN ANYMODEL PATCH
+ # Wraps model loading with deci_x_patcher for heterogeneous layer configs.
+ # See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
+ # =========================================================================
+
+ descriptor = resolve_descriptor_from_pretrained(
+ self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False)
+ )
+
+ with deci_x_patcher(model_descriptor=descriptor):
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.hf_model_id_path,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ **hf_kwargs,
+ )
+ # =========================================================================
+ # END ANYMODEL PATCH
+ # =========================================================================
+
+ if self.hf_peft_model_id_path is not None:
+ self.model = PeftModel.from_pretrained(self.model, self.hf_peft_model_id_path)
+ else:
+ raise ValueError("Task {} is not supported.".format(self.task))
+ num_gpus = torch.cuda.device_count()
+ # If there is only one GPU, move the model to GPU. If you are using device_map as "auto" or "balanced",
+ # the model will be moved to GPU automatically.
+ if device_map is None and num_gpus >= 1 and self.model.device.type != "cuda":
+ self.model.cuda()
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.tokenizer_id_path,
+ trust_remote_code=hf_kwargs.pop("trust_remote_code", False),
+ padding=self.tokenizer_padding,
+ truncation=self.tokenizer_truncation,
+ padding_side=self.tokenizer_padding_side,
+ )
+
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ def generate(
+ self,
+ **kwargs: Any,
+ ) -> list[str]:
+ """Generate text based on the provided input prompts.
+
+ This method processes input prompts through the loaded pipeline and
+ generates text according to the specified parameters.
+
+ Args:
+ **kwargs: Generation parameters including:
+ - text_inputs: List of input prompts
+ - max_length: Maximum number of tokens to generate
+ - num_return_sequences: Number of sequences to generate per prompt
+ - temperature: Sampling temperature
+ - top_k: Number of highest probability tokens to consider
+ - top_p: Cumulative probability threshold for token sampling
+ - do_sample: Whether to use sampling, default is False for greedy decoding
+ - echo: Whether to return prompt + generated text (True) or just generated text (False)
+ - return_full_text: Whether to return full text or only generated part
+
+ Returns:
+ If output logits and output scores are False:
+ List[str]: A list of generated texts, one for each input prompt.
+ If output logits and output scores are True:
+ Dict: A dictionary containing:
+ - sentences: List of generated texts
+ - logits: List of logits
+ - scores: List of scores
+ - input_lengths: List of input token lengths (for echo processing)
+
+ Raises:
+ RuntimeError: If the pipeline is not initialized.
+ """
+ if not self.model:
+ raise RuntimeError("Model is not initialized")
+
+ inputs = self.tokenizer(
+ kwargs["text_inputs"],
+ return_tensors="pt",
+ padding=self.tokenizer_padding,
+ truncation=self.tokenizer_truncation,
+ )
+
+ # Store input lengths to extract only generated tokens later
+ input_lengths = [len(input_ids) for input_ids in inputs["input_ids"]]
+
+ # Get echo parameter (default False - only return generated text)
+ echo = kwargs.pop("echo", False)
+ kwargs.pop("text_inputs") # Remove text_inputs as it's already been tokenized
+
+ kwargs = {**inputs, **kwargs}
+ for key, val in kwargs.items():
+ if torch.is_tensor(val):
+ kwargs[key] = val.cuda()
+
+ with torch.no_grad():
+ generated_ids = self.model.generate(**kwargs)
+ return_dict_in_generate = kwargs.get("return_dict_in_generate", False)
+ if return_dict_in_generate:
+ # Handle dict output (when logits/scores are requested)
+ sequences = generated_ids["sequences"]
+ output = {"sentences": [], "input_lengths": input_lengths, "sequences": sequences}
+
+ if echo:
+ # Return full text (prompt + generated).
+ # HF model's generate returns the input/prompt tokens as well by default.
+ for i, seq in enumerate(sequences):
+ full_text = self.tokenizer.decode(seq, skip_special_tokens=True)
+ output["sentences"].append(full_text)
+ else:
+ # Extract only the generated tokens (skip input tokens).
+ # This is required as HF model's generate returns the input/prompt tokens
+ # as well by default. (return_full_text is specific to some models)
+ for i, seq in enumerate(sequences):
+ input_len = input_lengths[i] if i < len(input_lengths) else 0
+ generated_tokens = seq[input_len:] # Skip input tokens
+ generated_text = self.tokenizer.decode(
+ generated_tokens, skip_special_tokens=True
+ )
+ output["sentences"].append(generated_text)
+
+ if kwargs.get("output_logits", False):
+ output["logits"] = generated_ids["logits"]
+ if kwargs.get("output_scores", False):
+ output["scores"] = generated_ids["scores"]
+ else:
+ # Handle list output (normal case)
+ output = []
+ if echo:
+ # Return full text (prompt + generated), which is the default in case of HF model generate.
+ for i, seq in enumerate(generated_ids):
+ full_text = self.tokenizer.decode(seq, skip_special_tokens=True)
+ output.append(full_text)
+ else:
+ # Extract only the generated tokens (skip input tokens) as the default
+ # behavior returns the input/prompt tokens as well.
+ for i, seq in enumerate(generated_ids):
+ input_len = input_lengths[i] if i < len(input_lengths) else 0
+ generated_tokens = seq[input_len:] # Skip input tokens
+ generated_text = self.tokenizer.decode(
+ generated_tokens, skip_special_tokens=True
+ )
+ output.append(generated_text)
+
+ return output
+
+ def generate_other_ranks(self):
+ """Generate function for ranks other than the rank 0."""
+ while True:
+ message = torch.empty(1, dtype=torch.long, device="cuda")
+ torch.distributed.broadcast(message, src=0)
+ if message == 0:
+ prompts = broadcast_list(data=[None], src=0)
+ (
+ temperature,
+ top_k,
+ top_p,
+ num_tokens_to_generate,
+ output_logits,
+ output_scores,
+ ) = broadcast_list(data=[None], src=0)
+
+ return_dict_in_generate = False
+ if output_logits or output_scores:
+ return_dict_in_generate = True
+
+ self.generate(
+ text_inputs=prompts,
+ do_sample=False, # do_sample=False for greedy decoding
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ max_new_tokens=num_tokens_to_generate,
+ output_logits=output_logits,
+ output_scores=output_scores,
+ return_dict_in_generate=return_dict_in_generate,
+ )
+ else:
+ return
+
+ @property
+ def get_triton_input(self):
+ inputs = (
+ Tensor(name="prompts", shape=(-1,), dtype=bytes),
+ Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True),
+ Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True),
+ Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True),
+ Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
+ Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
+ Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True),
+ Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True),
+ Tensor(name="output_logits", shape=(-1,), dtype=np.bool_, optional=True),
+ Tensor(name="output_scores", shape=(-1,), dtype=np.bool_, optional=True),
+ )
+ return inputs
+
+ @property
+ def get_triton_output(self):
+ return (
+ Tensor(name="sentences", shape=(-1,), dtype=bytes),
+ Tensor(name="logits", shape=(-1,), dtype=np.single),
+ Tensor(name="scores", shape=(-1,), dtype=np.single),
+ )
+
+ @batch
+ def triton_infer_fn(self, **inputs: np.ndarray):
+ output_infer = {}
+
+ try:
+ prompts = str_ndarray2list(inputs.pop("prompts"))
+ temperature = inputs.pop("temperature")[0][0] if "temperature" in inputs else 1.0
+ top_k = int(inputs.pop("top_k")[0][0] if "top_k" in inputs else 1)
+ top_p = inputs.pop("top_p")[0][0] if "top_p" in inputs else 0
+ num_tokens_to_generate = (
+ inputs.pop("max_length")[0][0] if "max_length" in inputs else 256
+ )
+ output_logits = (
+ inputs.pop("output_logits")[0][0] if "output_logits" in inputs else False
+ )
+ output_scores = (
+ inputs.pop("output_scores")[0][0] if "output_scores" in inputs else False
+ )
+ return_dict_in_generate = False
+ if output_logits or output_scores:
+ return_dict_in_generate = True
+
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_world_size() > 1:
+ torch.distributed.broadcast(
+ torch.tensor([0], dtype=torch.long, device="cuda"), src=0
+ )
+ broadcast_list(prompts, src=0)
+ broadcast_list(
+ data=[
+ temperature,
+ top_k,
+ top_p,
+ num_tokens_to_generate,
+ output_logits,
+ output_scores,
+ ],
+ src=0,
+ )
+
+ output = self.generate(
+ text_inputs=prompts,
+ do_sample=False, # do_sample=False for greedy decoding
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ max_new_tokens=num_tokens_to_generate,
+ output_logits=output_logits,
+ output_scores=output_scores,
+ return_dict_in_generate=return_dict_in_generate,
+ echo=False,
+ )
+
+ if isinstance(output, dict):
+ output_infer = {"sentences": cast_output(output["sentences"], np.bytes_)}
+
+ if "scores" in output:
+ output_scores = []
+ for r in output["scores"]:
+ lp = torch.tensor(r).cpu().detach().numpy()
+ if len(lp) == 0:
+ output_scores.append([0])
+ else:
+ output_scores.append(lp)
+ output_infer["scores"] = np.array(output_scores).transpose(1, 0, 2)
+
+ if "logits" in output:
+ output_logits = []
+ for r in output["logits"]:
+ lp = torch.tensor(r).cpu().detach().numpy()
+ if len(lp) == 0:
+ output_logits.append([0])
+ else:
+ output_logits.append(lp)
+ output_infer["logits"] = np.array(output_logits).transpose(1, 0, 2)
+ else:
+ output_infer = {"sentences": cast_output(output, np.bytes_)}
+
+ except Exception as error:
+ err_msg = "An error occurred: {}".format(str(error))
+ output_infer["sentences"] = cast_output([err_msg], np.bytes_)
+
+ return output_infer
+
+ def _compute_logprobs(
+ self,
+ prompts: list[str],
+ output_infer: dict[str, Any],
+ compute_logprob: bool,
+ n_top_logprobs: int,
+ echo: bool,
+ ):
+ """Compute log probabilities and top log probabilities from model scores.
+ Used by ray_infer_fn to provide OAI API compatible output for evaluations.
+
+ This method processes the raw scores from model generation to compute:
+ - Log probabilities for chosen tokens
+ - Top-k log probabilities for each position (if requested)
+ - Handles both prompt tokens (when echo=True) and generated tokens
+
+ Args:
+ prompts: List of input prompts
+ output_infer: Dictionary containing model outputs including scores, sequences, and input_lengths
+ compute_logprob: Whether to compute log probabilities
+ n_top_logprobs: Number of top log probabilities to return (0 to disable)
+ echo: Whether to include prompt token log probabilities
+
+ Returns:
+ Tuple[Optional[List], Optional[List]]:
+ - log_probs_list: List of log probabilities for each sample (None if not computed)
+ - top_logprobs_list: List of top-k log probabilities for each sample (None if not computed)
+ """
+ # Tokenize the prompts to get prompt token IDs (needed for echo)
+ prompt_token_ids = None
+ prompt_inputs = None
+ if echo:
+ prompt_inputs = self.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding=self.tokenizer_padding,
+ truncation=self.tokenizer_truncation,
+ )
+ prompt_token_ids = prompt_inputs["input_ids"]
+ # Move to same device as model
+ for key, val in prompt_inputs.items():
+ if torch.is_tensor(val):
+ prompt_inputs[key] = val.cuda()
+
+ # Process each sample
+ log_probs_list = []
+ top_logprobs_list = []
+
+ for sample_idx in range(len(prompts)):
+ sample_log_probs = []
+ sample_top_logprobs = []
+
+ # Get the generated sequence for this sample
+ sequences = output_infer["sequences"][sample_idx]
+
+ # For echo, compute prompt token logprobs by running forward pass
+ if echo and prompt_token_ids is not None:
+ prompt_len = len(prompt_token_ids[sample_idx])
+
+ # Run forward pass on prompt to get logits for prompt tokens as scores in output_infer contains
+ # logits only for generated tokens.
+ with torch.no_grad():
+ # Create input for this specific sample
+ sample_prompt_input = {
+ key: val[sample_idx : sample_idx + 1] for key, val in prompt_inputs.items()
+ }
+ prompt_outputs = self.model(**sample_prompt_input)
+ prompt_logits = prompt_outputs.logits[0] # Shape: [seq_len, vocab_size]
+
+ # Calculate log probs for each prompt token (except the first BOS token)
+ for token_pos in range(1, prompt_len): # Start from 1 to skip BOS
+ # The logit at position i-1 predicts token at position i
+ logit_for_current_token = prompt_logits[token_pos - 1]
+ current_token_id = prompt_token_ids[sample_idx][token_pos].item()
+
+ # Calculate log probabilities
+ log_probs = torch.nn.functional.log_softmax(logit_for_current_token, dim=-1)
+ chosen_log_prob = log_probs[current_token_id].item()
+ sample_log_probs.append(chosen_log_prob)
+
+ # Calculate top log probabilities if requested
+ if n_top_logprobs > 0:
+ top_log_probs_dict = {}
+ top_k_values, top_k_indices = torch.topk(
+ log_probs, min(n_top_logprobs, len(log_probs))
+ )
+ for k_idx in range(len(top_k_indices)):
+ token_id = top_k_indices[k_idx].item()
+ token_str = self.tokenizer.decode([token_id])
+ top_log_probs_dict[token_str] = top_k_values[k_idx].item()
+ sample_top_logprobs.append(top_log_probs_dict)
+
+ # Process the scores for generated tokens
+ for token_idx, score_tensor in enumerate(output_infer["scores"]):
+ # Get the chosen token ID from the sequence
+ # Scores start after the prompt, so we need to offset
+ input_len = (
+ output_infer.get("input_lengths", [0])[sample_idx]
+ if "input_lengths" in output_infer
+ else 0
+ )
+ seq_idx = input_len + token_idx
+
+ if seq_idx < len(sequences):
+ chosen_token_id = (
+ sequences[seq_idx].item()
+ if hasattr(sequences[seq_idx], "item")
+ else sequences[seq_idx]
+ )
+
+ # Calculate log probabilities
+ log_probs = torch.nn.functional.log_softmax(score_tensor[sample_idx], dim=-1)
+ chosen_log_prob = log_probs[chosen_token_id].item()
+ sample_log_probs.append(chosen_log_prob)
+
+ # Calculate top log probabilities if requested
+ if n_top_logprobs > 0:
+ top_log_probs_dict = {}
+ top_k_values, top_k_indices = torch.topk(
+ log_probs, min(n_top_logprobs, len(log_probs))
+ )
+ for k_idx in range(len(top_k_indices)):
+ token_id = top_k_indices[k_idx].item()
+ token_str = self.tokenizer.decode([token_id])
+ top_log_probs_dict[token_str] = top_k_values[k_idx].item()
+ sample_top_logprobs.append(top_log_probs_dict)
+
+ log_probs_list.append(sample_log_probs)
+ if n_top_logprobs > 0:
+ top_logprobs_list.append(sample_top_logprobs)
+
+ # Return log probs and top logprobs
+ return_log_probs = log_probs_list if compute_logprob else None
+ return_top_logprobs = top_logprobs_list if n_top_logprobs > 0 else None
+
+ return return_log_probs, return_top_logprobs
+
+ def ray_infer_fn(self, inputs: dict[Any, Any]):
+ """Perform inference using Ray with dictionary inputs and outputs.
+
+ Args:
+ inputs (Dict[Any, Any]): Dictionary containing input parameters:
+ - prompts: List of input prompts
+ - temperature: Sampling temperature (optional)
+ - top_k: Number of highest probability tokens to consider (optional)
+ - top_p: Cumulative probability threshold for token sampling (optional)
+ - max_tokens: Maximum number of tokens to generate (optional)
+ - compute_logprob: Whether to compute log probabilities (optional)
+ - n_top_logprobs: Number of top log probabilities to return (optional)
+ - echo: Whether to echo the prompt in output (optional)
+
+ Returns:
+ Dict[str, Any]: Dictionary containing:
+ - sentences: List of generated texts
+ - log_probs: Optional list of log probabilities if compute_logprob is True
+ - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0
+ """
+ try:
+ prompts = inputs.pop("prompts")
+ temperature = inputs.pop("temperature", 1.0)
+ top_k = int(inputs.pop("top_k", 1))
+ top_p = inputs.pop("top_p", 0.0)
+ num_tokens_to_generate = inputs.pop("max_tokens", 256)
+ output_logits = inputs.pop("output_logits", False)
+ output_scores = inputs.pop("output_scores", False)
+ compute_logprob = inputs.pop("compute_logprob", False)
+ n_top_logprobs = inputs.pop("n_top_logprobs", 0)
+ echo = inputs.pop("echo", False)
+
+ output_infer = self._infer_fn_ray(
+ prompts=prompts,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ num_tokens_to_generate=num_tokens_to_generate,
+ output_logits=output_logits,
+ output_scores=output_scores,
+ compute_logprob=compute_logprob,
+ n_top_logprobs=n_top_logprobs,
+ echo=echo,
+ )
+ # Code to get logprobs (required in OAI API format for eval) from the scores in output_infer.
+ if (
+ (compute_logprob or n_top_logprobs > 0)
+ and "scores" in output_infer
+ and output_infer["scores"]
+ ):
+ log_probs_list, top_logprobs_list = self._compute_logprobs(
+ prompts=prompts,
+ output_infer=output_infer,
+ compute_logprob=compute_logprob,
+ n_top_logprobs=n_top_logprobs,
+ echo=echo,
+ )
+
+ # Add to output
+ if log_probs_list is not None:
+ output_infer["log_probs"] = log_probs_list
+ if top_logprobs_list is not None:
+ # Convert to JSON strings for compatibility
+ output_infer["top_logprobs"] = [
+ json.dumps(top_logprobs) for top_logprobs in top_logprobs_list
+ ]
+
+ # Remove raw outputs that are not needed in the final response
+ output_infer.pop("scores", None)
+ output_infer.pop("sequences", None)
+ output_infer.pop("input_lengths", None)
+ return output_infer
+ except Exception as error:
+ err_msg = "An error occurred: {}".format(str(error))
+ return {"sentences": [err_msg]}
+
+ def _infer_fn_ray(
+ self,
+ prompts,
+ temperature=1.0,
+ top_k=1,
+ top_p=0.0,
+ num_tokens_to_generate=256,
+ output_logits=False,
+ output_scores=False,
+ compute_logprob=False,
+ n_top_logprobs=0,
+ echo=False,
+ cast_output_func=None,
+ ):
+ """Common internal function for inference operations.
+
+ Args:
+ prompts: List of input prompts
+ temperature: Sampling temperature
+ top_k: Number of highest probability tokens to consider
+ top_p: Cumulative probability threshold for token sampling
+ num_tokens_to_generate: Maximum number of tokens to generate
+ output_logits: Whether to output logits
+ output_scores: Whether to output scores
+ compute_logprob: Whether to compute log probabilities
+ n_top_logprobs: Number of top log probabilities to return
+ echo: Whether to echo the prompt in output
+ cast_output_func: Optional function to cast output values
+
+ Returns:
+ Dict containing inference results with raw outputs
+ """
+ # Enable return_dict if we need scores for logprobs or if output_logits/scores are requested
+ return_dict_in_generate = (
+ output_logits or output_scores or compute_logprob or n_top_logprobs > 0
+ )
+ # Enable output_scores if we need to compute logprobs. scores and logits from generate are both identical in
+ # case of greedy decoding. Hence setting output_scores to True when compute_logprob or n_top_logprobs > 0.
+ if compute_logprob or n_top_logprobs > 0:
+ output_scores = True
+
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_world_size() > 1:
+ torch.distributed.broadcast(
+ torch.tensor([0], dtype=torch.long, device="cuda"), src=0
+ )
+ broadcast_list(prompts, src=0)
+ broadcast_list(
+ data=[
+ temperature,
+ top_k,
+ top_p,
+ num_tokens_to_generate,
+ output_logits,
+ output_scores,
+ ],
+ src=0,
+ )
+
+ output = self.generate(
+ text_inputs=prompts,
+ do_sample=False, # do_sample=False for greedy decoding
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ max_new_tokens=num_tokens_to_generate,
+ output_logits=output_logits,
+ output_scores=output_scores,
+ return_dict_in_generate=return_dict_in_generate,
+ echo=echo,
+ )
+
+ if isinstance(output, dict):
+ return output
+
+ else:
+ return {"sentences": output}
diff --git a/examples/puzzletron/evaluation/lm_eval_anymodel.py b/examples/puzzletron/evaluation/lm_eval_anymodel.py
new file mode 100644
index 0000000000..7f9e07dd2b
--- /dev/null
+++ b/examples/puzzletron/evaluation/lm_eval_anymodel.py
@@ -0,0 +1,115 @@
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c
+
+# MIT License
+#
+# Copyright (c) 2020 EleutherAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+"""Run lm-eval directly on AnyModel (Puzzletron) checkpoints without a deployment server.
+
+Patches lm-eval's HFLM to wrap model loading with deci_x_patcher so AnyModel
+Puzzletron checkpoints load correctly. Model descriptor is auto-detected from the
+checkpoint's config.json model_type.
+"""
+
+from lm_eval import utils
+from lm_eval.__main__ import cli_evaluate
+from lm_eval.api.model import T
+from lm_eval.models.huggingface import HFLM
+
+# Trigger factory registration for all model descriptors
+import modelopt.torch.puzzletron.anymodel.models # noqa: F401
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
+ resolve_descriptor_from_pretrained,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
+
+def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
+ """Override HFLM.create_from_arg_obj to wrap model loading with deci_x_patcher."""
+
+ additional_config = {} if additional_config is None else additional_config
+ additional_config = {k: v for k, v in additional_config.items() if v is not None}
+
+ pretrained = arg_dict.get("pretrained")
+ descriptor = resolve_descriptor_from_pretrained(
+ pretrained, trust_remote_code=arg_dict.get("trust_remote_code", False)
+ )
+ # The patcher must be active during HFLM.__init__ because that's where
+ # AutoModelForCausalLM.from_pretrained() is called internally.
+ with deci_x_patcher(model_descriptor=descriptor):
+ model_obj = cls(**arg_dict, **additional_config)
+
+ return model_obj
+
+
+def create_from_arg_string(
+ cls: type[T], arg_string: str, additional_config: dict | None = None
+) -> T:
+ """Create an LM instance from a comma-separated argument string.
+
+ Args:
+ arg_string: Arguments as ``"key1=value1,key2=value2"``.
+ additional_config: Extra configuration merged into the parsed args.
+
+ Returns:
+ An instance of this LM subclass.
+ """
+ args = utils.simple_parse_args_string(arg_string)
+ additional_config = {} if additional_config is None else additional_config
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
+
+ pretrained = args.get("pretrained")
+ descriptor = resolve_descriptor_from_pretrained(
+ pretrained, trust_remote_code=args.get("trust_remote_code", False)
+ )
+
+ # The patcher must be active during HFLM.__init__ because that's where
+ # AutoModelForCausalLM.from_pretrained() is called internally.
+ with deci_x_patcher(model_descriptor=descriptor):
+ model_obj = cls(**args, **args2)
+
+ return model_obj
+
+
+# Monkey-patch HFLM so lm-eval uses our patched model loading
+HFLM.create_from_arg_obj = classmethod(create_from_arg_obj)
+HFLM.create_from_arg_string = classmethod(create_from_arg_string)
+
+
+if __name__ == "__main__":
+ cli_evaluate()
diff --git a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md
new file mode 100644
index 0000000000..f8b53889c6
--- /dev/null
+++ b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md
@@ -0,0 +1,70 @@
+# Evaluation with NeMo Evaluator (Alternative)
+
+> **Recommended approach:** Use lm-eval for direct evaluation without a
+> deployment server. See the main [README](../README.md#evaluation) for details.
+
+Evaluate AnyModel checkpoints by deploying a local OpenAI-compatible completions endpoint and running benchmarks against it.
+
+This flow requires Ray for serving the model and NeMo Export-Deploy (included in NeMo containers):
+
+```bash
+pip install -r examples/puzzletron/requirements.txt
+```
+
+**1. Deploy the model (2 GPUs example):**
+
+We need to patch the `hf_deployable.py` script from Export-Deploy. Best way is to do it as a mount in docker run:
+
+```bash
+export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it
+if [ ! -d "${MODELOPT_DIR}" ]; then
+ git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR}
+fi
+
+export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02
+docker run \
+ --gpus all \
+ --shm-size=16GB \
+ --net=host \
+ --ulimit memlock=-1 \
+ --rm -it \
+ -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
+ -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
+ -v ${MODELOPT_DIR}/examples/puzzletron/evaluation/hf_deployable_anymodel.py:/opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py \
+ -w /opt/Model-Optimizer/examples/megatron_bridge \
+ ${DOCKER_IMAGE} bash
+```
+
+Alternatively you can manually update the file
+
+```bash
+# Install the AnyModel-patched deployable (first time only: backs up the original)
+# /opt/Export-Deploy is the default path in NeMo containers — adjust if needed
+cp /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py.bak
+cp examples/puzzletron/evaluation/hf_deployable_anymodel.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py
+```
+
+Now start ray server and deploy the model
+
+```bash
+# Start the server (blocks while running — use a separate terminal)
+ray start --head --num-gpus 2 --port 6379 --disable-usage-stats
+python /opt/Export-Deploy/scripts/deploy/nlp/deploy_ray_hf.py \
+ --model_path path/to/checkpoint \
+ --model_id anymodel-hf \
+ --num_gpus 2 --num_gpus_per_replica 2 --num_cpus_per_replica 16 \
+ --trust_remote_code --port 8083 --device_map "auto" --cuda_visible_devices "0,1"
+```
+
+**2. Run MMLU:**
+
+```bash
+eval-factory run_eval \
+ --eval_type mmlu \
+ --model_id anymodel-hf \
+ --model_type completions \
+ --model_url http://0.0.0.0:8083/v1/completions/ \
+ --output_dir examples/puzzletron/evals/mmlu_anymodel
+```
+
+For a quick debug run, add `--overrides "config.params.limit_samples=5"`.
diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py
new file mode 100644
index 0000000000..5bb04818e5
--- /dev/null
+++ b/examples/puzzletron/main.py
@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146).
+
+This script provides three modes:
+1. Default mode: Runs the full puzzletron pipeline
+2. MIP-only mode: Runs only the MIP search and realize models phase
+3. MIP sweep mode: Runs MIP for multiple memory compression rates (enabled via config)
+
+Usage:
+ # Full puzzletron pipeline
+ torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml
+
+ # Only MIP search and realize models phase
+ torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only
+
+ # MIP sweep mode (set mip.sweep.enabled: true in config)
+ torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only
+"""
+
+import argparse
+from datetime import timedelta
+from pathlib import Path
+
+import modelopt.torch.nas as mtn
+import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
+import modelopt.torch.puzzletron.mip.sweep as sweep
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
+from modelopt.torch.puzzletron.tools.hydra_utils import (
+ initialize_hydra_config_for_dir,
+ register_hydra_resolvers,
+)
+from modelopt.torch.puzzletron.tools.logger import mprint
+
+
+def parse_args():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Compress large language models using the Puzzletron algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)"
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ required=True,
+ help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)",
+ )
+ parser.add_argument(
+ "--mip-only",
+ action="store_true",
+ help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)",
+ )
+
+ return parser.parse_args()
+
+
+def run_full_puzzletron(hydra_config_path: str):
+ """Run the full puzzletron pipeline.
+
+ Args:
+ config_path: Path to the YAML configuration file
+ """
+ mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
+ dist.setup(timeout=timedelta(10))
+
+ # Register Hydra custom resolvers (needed for config resolution)
+ register_hydra_resolvers()
+
+ hydra_config_path = Path(hydra_config_path).resolve()
+ hydra_config_dir = str(hydra_config_path.parent)
+ hydra_config_name = hydra_config_path.stem
+
+ # Load hydra config
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=hydra_config_dir,
+ config_name=hydra_config_name,
+ overrides=[],
+ )
+
+ # Convert model (convert from HF to DeciLM, score pruning activations,
+ # prune the model and save pruned checkpoints)
+ input_model = PuzzletronModel()
+ converted_model = mtn.convert(
+ input_model,
+ mode=[
+ (
+ "puzzletron",
+ {
+ "puzzle_dir": str(hydra_cfg.puzzle_dir),
+ "input_model_path": hydra_cfg.input_hf_model_path,
+ "hydra_config_dir": hydra_config_dir,
+ "hydra_config_name": hydra_config_name,
+ "dataset_path": str(hydra_cfg.dataset_path),
+ },
+ )
+ ],
+ )
+
+ # Run NAS search (build replacement library and compute stats,
+ # compute one block scores, run MIP and realize models)
+ mtn.search(
+ converted_model,
+ constraints={}, # this is not used as the search space is defined in the hydra config
+ dummy_input=None, # Not used
+ config={}, # this is not used as the search space is defined in the hydra config
+ )
+
+ dist.cleanup()
+ mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+
+
+def run_mip_only(hydra_config_path: str):
+ """Run only the MIP search and realize models phase.
+
+ This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation
+ have already been completed.
+
+ Args:
+ hydra_config_path: Path to the YAML configuration file
+ """
+ dist.setup(timeout=timedelta(10))
+
+ # Register Hydra custom resolvers (needed for config resolution)
+ register_hydra_resolvers()
+
+ hydra_config_path = Path(hydra_config_path).resolve()
+ hydra_config_dir = str(hydra_config_path.parent)
+ hydra_config_name = hydra_config_path.stem
+
+ # Load hydra config
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=hydra_config_dir,
+ config_name=hydra_config_name,
+ overrides=[],
+ )
+
+ # Check if sweep mode is enabled
+ if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
+ mprint(
+ "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+ )
+ sweep.run_mip_sweep(hydra_cfg)
+ else:
+ # mip_and_realize_models (distributed processing)
+ # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
+ mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+ mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+
+ dist.cleanup()
+ mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+
+
+def main():
+ args = parse_args()
+
+ if args.mip_only:
+ run_mip_only(hydra_config_path=args.config)
+ else:
+ run_full_puzzletron(hydra_config_path=args.config)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md
new file mode 100644
index 0000000000..f7dda866e8
--- /dev/null
+++ b/examples/puzzletron/mbridge_distillation/README.md
@@ -0,0 +1,152 @@
+# Knowledge Distillation with Megatron-Bridge
+
+This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge.
+
+## Overview
+
+1. Set up the environment with Megatron-Bridge
+2. Prepare tokenized dataset
+3. Run knowledge distillation training directly from HuggingFace checkpoints
+4. Review MMLU evaluation results (before/after distillation)
+
+## Setup
+
+**Clone Model-Optimizer repo:**
+
+The NeMo container does not include Model-Optimizer examples, so you need to clone the Model-Optimizer repo:
+
+```bash
+export MODELOPT_DIR=${PWD}/Model-Optimizer
+git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR}
+```
+
+**Start Docker container:**
+
+Use the [NeMo 26.02.01 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02.01):
+
+```bash
+# Recommended to mount a workspace directory for storing datasets and distilled models
+docker run --gpus all -it --rm \
+ -v /path/to/your/project:/workspace \
+ -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
+ -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
+ -w /opt/Model-Optimizer \
+ nvcr.io/nvidia/nemo:26.02.01 \
+ /bin/bash
+```
+
+## Dataset Preparation
+
+This section describes how to prepare datasets for knowledge distillation. We provide examples using WikiText-103, which is a small dataset that can still produce decent results (see the Qwen3-8B example below showing +10.11 percentage point improvement). For production use, larger datasets like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) are recommended.
+
+### Download and Tokenize Dataset
+
+Download and tokenize the dataset in a single step. This downloads the dataset from HuggingFace, tokenizes it, and saves it in the Megatron format (`.bin` and `.idx` files):
+
+```bash
+python -m modelopt.torch.utils.plugins.megatron_preprocess_data \
+ --hf_dataset Salesforce/wikitext \
+ --hf_name wikitext-103-v1 \
+ --hf_split train \
+ --output_dir path/to/hf_datasets/wikitext-103-v1 \
+ --tokenizer meta-llama/Llama-3.1-8B-Instruct \
+ --json_keys text \
+ --workers 32
+```
+
+This will create:
+
+- `Salesforce--wikitext_wikitext-103-v1_train_text_document.bin` - Binary tokenized data
+- `Salesforce--wikitext_wikitext-103-v1_train_text_document.idx` - Index file for the binary data
+- `Salesforce--wikitext_wikitext-103-v1_train_text_document/cache/` - Cache directory (created after running distillation)
+
+## Run Knowledge Distillation
+
+Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset:
+
+```bash
+torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \
+ --student_hf_path /path/to/student/huggingface/checkpoint \
+ --teacher_hf_path /path/to/teacher/huggingface/checkpoint \
+ --data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \
+ --output_dir /path/to/distilled/checkpoint \
+ --hf-export-path /path/to/exported/hf/model \
+ --hf-model meta-llama/Llama-3.1-8B-Instruct \
+ --seq_length 4096 \
+ --tp_size 8 \
+ --pp_size 1 \
+ --mbs 1 \
+ --gbs 4 \
+ --train_iters 100 \
+ --lr 0.0001 \
+ --min_lr 1e-05 \
+ --lr_warmup_iters 10 \
+ --eval_interval 10 \
+ --eval_iters 10 \
+ --log_interval 1
+```
+
+**Notes:**
+
+- Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code.
+- The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_`.
+- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation).
+- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices.
+
+## MMLU Evaluation Results
+
+This section presents MMLU evaluation results for knowledge distillation experiments compressing Qwen3-8B and Llama-3.1-8B-Instruct.
+
+### Successful Case: Qwen3-8B (80% of original)
+
+Distillation results for a memory-compressed Qwen3-8B checkpoint (80% of original size):
+
+| Model | MMLU | Humanities | Other | Social Sci | STEM |
+|-------|------|------------|-------|------------|------|
+| 80% pre-distillation | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 |
+| 80% post-distillation | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 |
+| Original Qwen3-8B | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 |
+
+**Key observations:**
+
+- MMLU accuracy improved from 59.10% to 69.21% (+10.11 percentage points) after distillation
+- Achieved with just 100 iterations on WikiText-103, demonstrating efficient knowledge transfer
+- Recovery of 64% of the gap to the teacher model (from 59.10% to 69.21%, closing 64% of the gap from 59.10% to 74.93%)
+- All individual category scores (Humanities, Other, Social Sciences, STEM) improved significantly
+
+### Successful Case: Llama-3.1-8B-Instruct (50% of original, 56,810 MiB)
+
+Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (50% of original size, 56,810 MiB memory constraint):
+
+| Model | MMLU | Humanities | Other | Social Sciences | STEM |
+|-------|------|------------|-------|-----------------|------|
+| Before distillation | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 |
+| After distillation | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 |
+| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
+
+**Key observations:**
+
+- MMLU accuracy (average across all categories) improved from 23.16% to 29.60% (+6.44 percentage points)
+- All individual category scores (Humanities, Other, Social Sciences, STEM) improved, demonstrating effective knowledge transfer from teacher to student
+
+### Regression Case: Llama-3.1-8B-Instruct (69% of original, 78,000 MiB)
+
+Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (approximately 69% of original size, 78,000 MiB memory constraint) showing regression due to overfitting on the small WikiText-103 dataset (evaluated with limit 100):
+
+| Model | MMLU | Humanities | Other | Social Sciences | STEM |
+|-------|------|------------|-------|-----------------|------|
+| Before distillation | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 |
+| After distillation | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 |
+| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
+
+**Key observations:**
+
+- MMLU accuracy (average across all categories) decreased from 66.26% to 64.96% (-1.30 percentage points) after distillation
+- The model overfitted to the small WikiText-103 dataset, causing performance regression
+- This demonstrates the critical importance of using larger, more diverse datasets for knowledge distillation
+
+### Recommendations
+
+- **For production distillation:** Use larger production datasets like [nvidia/Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) for better results and to avoid overfitting (see regression case above)
+- **Training duration:** Train for more iterations to ensure proper convergence
+- **See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices**
diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py
new file mode 100644
index 0000000000..ac703909c2
--- /dev/null
+++ b/examples/puzzletron/mbridge_distillation/distill_hf.py
@@ -0,0 +1,329 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Distillation script for Megatron-Bridge.
+
+Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
+to `/checkpoints` in megatron distributed checkpoint format.
+
+See `README.md` in this directory for example usage and data preparation instructions.
+"""
+
+import argparse
+import os
+import traceback
+
+import megatron.bridge.models.distillation_provider
+import torch
+from megatron.bridge import AutoBridge
+from megatron.bridge.recipes.utils.optimizer_utils import (
+ distributed_fused_adam_with_cosine_annealing,
+)
+from megatron.bridge.training.config import (
+ CheckpointConfig,
+ ConfigContainer,
+ GPTDatasetConfig,
+ LoggerConfig,
+ MockGPTDatasetConfig,
+ RNGConfig,
+ TokenizerConfig,
+ TrainingConfig,
+)
+from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
+from megatron.core.datasets.utils import get_blend_from_list
+from megatron.core.distributed import DistributedDataParallelConfig
+
+# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure
+# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers
+# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge
+# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration.
+#
+# Note: Currently, bridges are also registered when distillation_provider is imported
+# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider
+# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron.
+import modelopt.torch.puzzletron.export.mbridge # noqa: F401
+import modelopt.torch.utils.distributed as dist
+
+# Use local copy of distillation_provider with fix for heterogeneous models
+# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge
+from modelopt.torch.puzzletron.export.mbridge.distillation_provider import (
+ DistillationProvider,
+ convert_to_distillation_provider,
+)
+from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import (
+ export_to_hf_and_copy_config,
+)
+from modelopt.torch.utils import print_rank_0
+
+# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider
+# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time
+megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider
+
+# Import distill() AFTER patching so it uses the patched DistillationProvider
+from megatron.bridge.training.distill import distill # noqa: E402
+
+SEED = 1234
+
+
+def get_args():
+ """Parse command-line arguments."""
+ parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
+ # Model arguments (accepts HuggingFace input only at the moment)
+ parser.add_argument(
+ "--student_hf_path",
+ type=str,
+ required=True,
+ help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
+ )
+ parser.add_argument(
+ "--teacher_hf_path",
+ type=str,
+ required=True,
+ help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
+ )
+ parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code")
+ # Parallelism arguments
+ parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
+ parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
+ # Dataset arguments
+ parser.add_argument(
+ "--data_paths",
+ nargs="+",
+ help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)",
+ )
+ parser.add_argument(
+ "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data"
+ )
+ parser.add_argument(
+ "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices"
+ )
+ parser.add_argument(
+ "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths"
+ )
+ # Training & Eval arguments
+ parser.add_argument(
+ "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
+ )
+ parser.add_argument(
+ "--seq_length",
+ type=int,
+ default=4096,
+ help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.",
+ )
+ parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size")
+ parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size")
+ parser.add_argument(
+ "--train_iters", type=int, required=True, help="Number of training iterations"
+ )
+ parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate")
+ parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate")
+ parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps")
+ parser.add_argument(
+ "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps"
+ )
+ parser.add_argument(
+ "--eval_iters", type=int, default=32, help="Number of batches per validation stage"
+ )
+ # Logging arguments
+ parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps")
+ parser.add_argument(
+ "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)"
+ )
+ parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
+ parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
+ # Export arguments
+ parser.add_argument(
+ "--hf_export_path",
+ "--hf-export-path",
+ type=str,
+ default=None,
+ help=(
+ "Path where to save the HuggingFace export. "
+ "If provided, exports checkpoint to HF format after distillation."
+ ),
+ )
+ parser.add_argument(
+ "--hf_model",
+ "--hf-model",
+ type=str,
+ required=True,
+ help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
+ "Should match the base architecture of the student model.",
+ )
+ args = parser.parse_args()
+
+ # Sanity checks
+ if not args.use_mock_data and not args.data_paths:
+ raise ValueError("Must provide either --data_paths or set --use_mock_data.")
+
+ print_rank_0("\n==================== Arguments ====================")
+ for k, v in args.__dict__.items():
+ print_rank_0(f"{k:<35} {v}")
+ print_rank_0("===================================================\n")
+
+ return args
+
+
+def main(args: argparse.Namespace):
+ checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+ tensorboard_dir = os.path.join(args.output_dir, "tb_logs")
+
+ # Build student and teacher model providers
+ def _build_model_provider(hf_path):
+ bridge = AutoBridge.from_hf_pretrained(hf_path, trust_remote_code=args.trust_remote_code)
+ provider = bridge.to_megatron_provider(load_weights=True)
+
+ # Override parallelism / training settings
+ provider.tensor_model_parallel_size = args.tp_size
+ provider.pipeline_model_parallel_size = args.pp_size
+ provider.context_parallel_size = 1
+ provider.sequence_parallel = args.tp_size > 1
+ provider.seq_length = args.seq_length
+ provider.pipeline_dtype = torch.bfloat16
+ return provider
+
+ # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000)
+ # Still requires an HF model name or path to build provider correctly
+ student_provider = _build_model_provider(args.student_hf_path)
+ teacher_provider = _build_model_provider(args.teacher_hf_path)
+
+ # Wrap into DistillationProvider
+ kd_config = ModelOptDistillConfig()
+ distill_provider = convert_to_distillation_provider(
+ student_provider, teacher_provider, kd_config
+ )
+
+ # Build optimizer and scheduler
+ optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing(
+ lr_warmup_iters=args.lr_warmup_iters,
+ max_lr=args.lr,
+ min_lr=args.min_lr,
+ adam_beta2=0.98,
+ )
+
+ # Build dataset config
+ dataset_kwargs = {
+ "seq_length": args.seq_length,
+ "path_to_cache": args.data_path_to_cache,
+ "random_seed": SEED,
+ "reset_attention_mask": False,
+ "reset_position_ids": False,
+ "eod_mask_loss": False,
+ "num_dataset_builder_threads": 1,
+ "data_sharding": True,
+ "dataloader_type": "single",
+ "skip_getting_attention_mask_from_dataset": True,
+ }
+ if args.use_mock_data:
+ dataset_config = MockGPTDatasetConfig(**dataset_kwargs)
+ else:
+ # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format
+ blend = get_blend_from_list(args.data_paths)
+ dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs)
+
+ # Assemble ConfigContainer and run distillation
+ config = ConfigContainer(
+ model=distill_provider,
+ train=TrainingConfig(
+ train_iters=args.train_iters,
+ eval_interval=args.eval_interval,
+ eval_iters=args.eval_iters,
+ global_batch_size=args.gbs,
+ micro_batch_size=args.mbs,
+ manual_gc=True,
+ manual_gc_interval=100,
+ ),
+ # TODO: Replace validation args in train with validation config in nemo:26.04
+ # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters),
+ optimizer=optimizer_config,
+ scheduler=scheduler_config,
+ ddp=DistributedDataParallelConfig(
+ check_for_nan_in_grad=True,
+ grad_reduce_in_fp32=True,
+ overlap_grad_reduce=True,
+ overlap_param_gather=True,
+ average_in_collective=True,
+ use_distributed_optimizer=True,
+ ),
+ dataset=dataset_config,
+ logger=LoggerConfig(
+ log_interval=args.log_interval,
+ tensorboard_dir=tensorboard_dir,
+ log_timers_to_tensorboard=True,
+ # Weights & Biases logging
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity, # optional
+ wandb_exp_name=args.wandb_exp_name,
+ ),
+ tokenizer=TokenizerConfig(
+ tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size
+ ),
+ checkpoint=CheckpointConfig(
+ save_interval=args.eval_interval,
+ save=checkpoint_dir,
+ load=checkpoint_dir, # Resume from this directory (if exists)
+ most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based)
+ ckpt_format="torch_dist",
+ async_save=True,
+ fully_parallel_save=True,
+ ),
+ rng=RNGConfig(seed=SEED),
+ mixed_precision="bf16_mixed",
+ )
+
+ print_rank_0("\nStarting distillation...")
+ distill(config)
+ print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n")
+
+ # Export to HuggingFace format if hf_export_path is provided
+ if args.hf_export_path:
+ # Wait for all ranks to finish distillation before export
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+ # Save rank before destroying process group (dist.rank() won't work after destruction)
+ is_rank_0 = dist.rank() == 0
+
+ # Destroy process group on all ranks - export_ckpt will create its own temporary one
+ # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone)
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+
+ # Only rank 0 exports
+ if is_rank_0:
+ try:
+ export_to_hf_and_copy_config(
+ student_hf_path=args.student_hf_path,
+ checkpoint_dir=checkpoint_dir,
+ train_iters=args.train_iters,
+ hf_export_path=args.hf_export_path,
+ hf_model=args.hf_model,
+ trust_remote_code=args.trust_remote_code,
+ )
+ except Exception as e:
+ print(f"⚠️ Export failed: {e}")
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ dist.setup()
+ args = get_args()
+ try:
+ main(args)
+ except Exception as e:
+ print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}")
+ print_rank_0(f"Traceback:\n{traceback.format_exc()}")
+ raise
+ finally:
+ dist.cleanup()
diff --git a/examples/puzzletron/mip_sweep_example.png b/examples/puzzletron/mip_sweep_example.png
new file mode 100644
index 0000000000..4eb1089fe0
Binary files /dev/null and b/examples/puzzletron/mip_sweep_example.png differ
diff --git a/examples/puzzletron/requirements.txt b/examples/puzzletron/requirements.txt
new file mode 100644
index 0000000000..0511fb473b
--- /dev/null
+++ b/examples/puzzletron/requirements.txt
@@ -0,0 +1,3 @@
+lm-eval==0.4.10
+math-verify
+ray
diff --git a/modelopt/torch/prune/importance_hooks/__init__.py b/modelopt/torch/prune/importance_hooks/__init__.py
index 3bf30c2a46..1e86ddcf65 100644
--- a/modelopt/torch/prune/importance_hooks/__init__.py
+++ b/modelopt/torch/prune/importance_hooks/__init__.py
@@ -18,6 +18,7 @@
from .base_hooks import *
from .base_hooks_analysis import *
+from .expert_removal_hooks import *
with import_plugin("megatron_hooks"):
from .plugins.megatron_hooks import *
diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py
index 248e6ec108..44eea3bdbe 100644
--- a/modelopt/torch/prune/importance_hooks/base_hooks.py
+++ b/modelopt/torch/prune/importance_hooks/base_hooks.py
@@ -149,7 +149,8 @@ def dump_activations_logs(
torch.save(activations_log, activations_log_path)
if rank == 0:
- args.activation_hooks_kwargs.pop("model")
+ if args.activation_hooks_kwargs is not None:
+ args.activation_hooks_kwargs.pop("model", None)
json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
dist.barrier()
@@ -565,9 +566,9 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict):
assert self.optimize_for in ["latency", "memory"]
self.hidden_size = model_config.hidden_size
- self.n_heads_in_group = block_config.attention.n_heads_in_group
self.num_q_heads = model_config.num_attention_heads
- self.num_kv_heads = self.num_q_heads // self.n_heads_in_group
+ self.num_kv_heads = block_config.attention.num_key_value_heads
+ self.n_heads_in_group = self.num_q_heads // self.num_kv_heads
self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads)
self.agg_kv_head_contributions = torch.zeros(
@@ -734,7 +735,7 @@ def _save_channel_importance_results(
all_scores = []
for activation_file in activation_files:
print(f"Loading activations from {activation_file}")
- # SECURITY: weights_only=False is required because files contain dictionaries with tensors.
+ # Security NOTE: weights_only=False is required because files contain dictionaries with tensors.
# These files are generated by dump_activations_logs() in this module and contain
# hook state dictionaries. The activations_log_dir should only contain trusted files
# generated by the same codebase, not from untrusted sources.
diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py
index e692a518ae..dbb4f564d7 100644
--- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py
+++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py
@@ -52,7 +52,8 @@
python compare_module_outputs.py \
--reference output_unpruned.pt \
--compare output_l2norm.pt \
- --output-json comparison_stats.json
+ --output-json comparison_stats.json \
+ --trust-inputs
The saved file format\:
@@ -180,21 +181,26 @@ def main():
default=None,
help="Path to save comparison statistics as JSON",
)
+ parser.add_argument(
+ "--trust-inputs",
+ action="store_true",
+ help="Trust input files for loading with weights_only=False in torch.load()",
+ )
args = parser.parse_args()
# Load reference data
print(f"\nLoading reference: {args.reference}")
- # SECURITY: weights_only=False is required because files contain dictionaries with tensors.
+ # Security NOTE: weights_only=False is required because files contain dictionaries with tensors.
# These files are expected to be generated by save_multi_layer_outputs() in this module,
# not from untrusted sources. Users should only load files they generated themselves.
- ref_data = torch.load(args.reference, map_location="cpu", weights_only=False)
+ ref_data = torch.load(args.reference, map_location="cpu", weights_only=not args.trust_inputs)
# Load comparison data
print(f"Loading compare: {args.compare}")
- # SECURITY: weights_only=False is required because files contain dictionaries with tensors.
+ # Security NOTE: weights_only=False is required because files contain dictionaries with tensors.
# These files are expected to be generated by save_multi_layer_outputs() in this module,
# not from untrusted sources. Users should only load files they generated themselves.
- comp_data = torch.load(args.compare, map_location="cpu", weights_only=False)
+ comp_data = torch.load(args.compare, map_location="cpu", weights_only=not args.trust_inputs)
# Compare multi-layer outputs
compare_multi_layer(ref_data, comp_data, args.output_json)
diff --git a/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py
new file mode 100644
index 0000000000..68eaf2e711
--- /dev/null
+++ b/modelopt/torch/prune/importance_hooks/expert_removal_hooks.py
@@ -0,0 +1,387 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""MoE expert-removal and ranked-choice importance hooks (uses Puzzletron BlockConfig)."""
+
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001
+
+from .base_hooks import ForwardHook
+
+__all__ = [
+ "NemotronHRemoveExpertsIndependentHook",
+ "Qwen3VLRemoveExpertsIndependentHook",
+ "RankedChoiceVotingHook",
+ "RankedChoiceVotingHookNemotronH",
+ "RemoveExpertsIndependentHook",
+]
+
+
+class RemoveExpertsIndependentHook(ForwardHook, ABC):
+ """Base hook for measuring expert importance in Mixture-of-Experts models.
+
+ This hook measures how much removing each expert affects the model output
+ by comparing outputs with and without each expert.
+ """
+
+ def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict):
+ """Initialize the hook.
+
+ Args:
+ moe: The MoE module to analyze
+ activation_hooks_kwargs: Configuration dict containing block_config
+ """
+ self.moe = moe
+ block_config: BlockConfig = activation_hooks_kwargs["block_config"]
+ self.num_local_experts = block_config.ffn.moe.num_local_experts
+ self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok
+ # tensor of zeros of size num experts
+ self.diffs = ["mse", "cosine"]
+ some_param = next(self.moe.parameters())
+ self.diffs = {
+ k: torch.zeros(
+ size=(self.num_local_experts,), dtype=torch.float32, device=some_param.device
+ )
+ for k in self.diffs
+ }
+ self.call_count = 0
+
+ @abstractmethod
+ def get_router_logits_and_routed_experts(
+ self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Extract router logits and expert outputs for measuring expert importance.
+
+ This method is called twice per forward pass:
+ 1. First call (router_logits=None): Compute original routing and expert outputs
+ 2. Second call (router_logits provided): Re-run with modified logits (expert disabled)
+
+ Args:
+ hidden_states: Input tensor of shape (batch, seq_len, hidden_dim)
+ router_logits: Optional pre-computed router logits. If None, compute from hidden_states.
+
+ Returns:
+ tuple of (router_logits, routed_experts):
+ - router_logits: Shape (num_tokens, num_local_experts)
+ - routed_experts: Shape (num_tokens, hidden_dim)
+ """
+ raise NotImplementedError
+
+ def __call__(
+ self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+ ) -> None:
+ """Forward hook that measures expert importance."""
+ hidden_states = args[0]
+ router_logits, original_routed_out = self.get_router_logits_and_routed_experts(
+ hidden_states
+ )
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1])
+
+ _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1)
+ self.call_count += 1
+
+ for i_expert in range(self.num_local_experts):
+ expert_mask = router_indices == i_expert
+ is_token_routed_to_this_expert = expert_mask.any(dim=-1)
+
+ num_tokens_displaced = is_token_routed_to_this_expert.sum()
+ if num_tokens_displaced == 0:
+ continue
+ num_total_tokens = is_token_routed_to_this_expert.numel()
+
+ relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :]
+
+ router_logits_without_i = router_logits.clone()
+ router_logits_without_i[..., i_expert] = -float("inf") # disable expert i
+ router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :]
+ _, routed_out_without_i = self.get_router_logits_and_routed_experts(
+ relevant_hidden_states, router_logits_without_i
+ )
+
+ relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :]
+ self.diffs["mse"][i_expert] += (
+ nn.functional.mse_loss(
+ relevant_tokens_original_out, routed_out_without_i, reduction="mean"
+ )
+ * num_tokens_displaced
+ / num_total_tokens
+ )
+ self.diffs["cosine"][i_expert] += (
+ -nn.functional.cosine_similarity(
+ relevant_tokens_original_out, routed_out_without_i, dim=-1
+ ).mean()
+ * num_tokens_displaced
+ / num_total_tokens
+ )
+
+ def to_dict(self) -> dict[str, torch.Tensor]:
+ """Convert accumulated statistics to dict format."""
+ expert_ranks_mse = torch.argsort(self.diffs["mse"])
+ expert_ranks_cosine = torch.argsort(self.diffs["cosine"])
+ return {
+ "expert_ranks_mse": expert_ranks_mse.cpu(),
+ "expert_ranks_cosine": expert_ranks_cosine.cpu(),
+ "cosine_diffs": (self.diffs["cosine"] / self.call_count).cpu(),
+ "mse_diffs": (self.diffs["mse"] / self.call_count).cpu(),
+ }
+
+ def accumulate(self) -> torch.Tensor:
+ """Return accumulated expert importance scores."""
+ return self.diffs["mse"]
+
+ def state_dict(self) -> dict:
+ """Return the internal state for checkpointing."""
+ return {
+ "diffs_mse": self.diffs["mse"].cpu(),
+ "diffs_cosine": self.diffs["cosine"].cpu(),
+ "call_count": self.call_count,
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ """Load the internal state from a checkpoint."""
+ self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device)
+ self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device)
+ self.call_count = state_dict["call_count"]
+
+
+class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook):
+ """Expert removal importance hook for NemotronH models."""
+
+ def get_router_logits_and_routed_experts(
+ self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Extract router logits and expert outputs for NemotronH MoE.
+
+ Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts.
+ """
+ orig_shape = hidden_states.shape
+ # NemotronHMOE.gate forward, copied to extract router_logits
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ if router_logits is None:
+ router_logits = nn.functional.linear(
+ hidden_states.type(torch.float32), self.moe.gate.weight.type(torch.float32)
+ )
+ router_logits = router_logits.sigmoid()
+ router_logits = router_logits + self.moe.gate.e_score_correction_bias.unsqueeze(0)
+
+ topk_indices = self._get_topk_indices_without_correction_bias(router_logits)
+ topk_weights = router_logits.gather(1, topk_indices)
+ if self.moe.gate.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.moe.gate.routed_scaling_factor
+ # Routed experts forward
+ hidden_states = self.moe.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ return router_logits, hidden_states
+
+ @torch.no_grad()
+ def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor:
+ """Get topk indices without correction bias.
+
+ Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias.
+ """
+ group_scores = (
+ scores.view(
+ -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group
+ )
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.moe.gate.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(
+ -1, self.moe.gate.n_group, self.moe.gate.n_routed_experts // self.moe.gate.n_group
+ )
+ .reshape(-1, self.moe.gate.n_routed_experts)
+ )
+ scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.moe.gate.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+
+class RankedChoiceVotingHook(ForwardHook):
+ """Hook for ranking experts using ranked choice voting algorithm.
+
+ This hook tracks router decisions and uses ranked choice voting to determine
+ which experts are least important (can be pruned first).
+ """
+
+ def __init__(self, router: nn.Module, activation_hooks_kwargs: dict):
+ """Initialize the hook.
+
+ Args:
+ router: The router module (typically nn.Linear)
+ activation_hooks_kwargs: Configuration dict containing block_config
+ """
+ self.router_argsort: list[torch.Tensor] = []
+ block_config: BlockConfig = activation_hooks_kwargs["block_config"]
+ self.top_k = block_config.ffn.moe.num_experts_per_tok
+
+ def __call__(
+ self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+ ) -> None:
+ """Forward hook that records router decisions.
+
+ Args:
+ module: The router module
+ args: Tuple with one tensor entry (B, T, I)
+ output: Router logits of shape (B, T, E)
+ """
+ router_logits = output[0] if isinstance(output, tuple) else output
+ num_experts = router_logits.shape[-1]
+ router_argsort = torch.argsort(router_logits, dim=-1, descending=True)
+ router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu()
+ self.router_argsort.append(router_argsort)
+
+ def to_dict(self) -> dict[str, torch.Tensor]:
+ """Convert accumulated statistics to dict format using ranked choice voting."""
+ router_argsort = torch.concat(self.router_argsort, dim=0)
+ num_tokens, num_experts = router_argsort.shape
+
+ expert_ranks = torch.full((num_experts,), -1)
+ expert_counts_at_pruning_time = {}
+
+ expert_kept_per_iteration: list[list[int]] = []
+ expert_counts_per_iteration: list[dict[int, int]] = []
+
+ for rank in range(num_experts):
+ ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True)
+ ids = ids.tolist()
+ counts = counts.tolist()
+ expert_counts = dict(zip(ids, counts))
+
+ expert_kept_per_iteration.append(ids)
+ expert_counts_per_iteration.append(expert_counts)
+
+ least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1])
+
+ expert_ranks[least_popular_expert] = rank
+ expert_counts_at_pruning_time[least_popular_expert] = min_count
+ print(f"#{rank}: router_argsort shape = {router_argsort.shape}")
+ router_argsort = router_argsort[router_argsort != least_popular_expert].view(
+ num_tokens, -1
+ )
+
+ zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long)
+ for expert_id, expert_counts_val in expert_counts_per_iteration[0].items():
+ zero_shot_expert_counts[expert_id] = expert_counts_val
+
+ # Compute zero-shot expert ranks (double argsort converts counts to rank positions)
+ zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts))
+
+ print("Done: Returning hook metadata.")
+ return {
+ "expert_ranks": expert_ranks,
+ "zero_shot_expert_ranks": zero_shot_expert_ranks,
+ "expert_counts_at_pruning_time": expert_counts_at_pruning_time,
+ "expert_counts_per_iteration": expert_counts_per_iteration,
+ "top_k": self.top_k,
+ }
+
+ def accumulate(self) -> torch.Tensor:
+ """Return accumulated expert ranks."""
+ if not self.router_argsort:
+ return torch.tensor([])
+ router_argsort = torch.concat(self.router_argsort, dim=0)
+ return router_argsort[:, 0].float()
+
+ def state_dict(self) -> dict:
+ """Return the internal state for checkpointing."""
+ return {
+ "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort],
+ "top_k": self.top_k,
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ """Load the internal state from a checkpoint."""
+ self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]]
+ self.top_k = state_dict["top_k"]
+
+ def get_progress_info(self) -> dict:
+ """Get progress information."""
+ return {
+ "num_batches_processed": len(self.router_argsort),
+ "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort)
+ if self.router_argsort
+ else 0,
+ }
+
+
+class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook):
+ """Ranked choice voting hook for NemotronH models.
+
+ In NemotronH, router_logits is an internal temporary state that never leaves
+ the forward() function. We reconstruct router_logits from the input hidden_states.
+ """
+
+ def __call__(
+ self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+ ) -> None:
+ """Forward hook that reconstructs router logits from hidden states."""
+ hidden_states = args[0]
+ hidden_states = hidden_states.view(-1, module.config.hidden_size)
+ router_logits = nn.functional.linear(
+ hidden_states.type(torch.float32), module.weight.type(torch.float32)
+ )
+ super().__call__(module, args, router_logits)
+
+
+class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook):
+ """Expert removal importance hook for Qwen3-VL models."""
+
+ def get_router_logits_and_routed_experts(
+ self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Extract router logits and expert outputs for Qwen3-VL MoE.
+
+ Based on Qwen3VLMoeSparseMoe forward pass.
+ """
+ orig_shape = hidden_states.shape
+
+ # Flatten to (num_tokens, hidden_size) for processing
+ hidden_states_flat = hidden_states.reshape(-1, self.moe.hidden_size)
+
+ if router_logits is None:
+ router_logits = self.moe.gate(hidden_states_flat)
+
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, router_indices = torch.topk(routing_weights, self.moe.top_k, dim=-1)
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states_flat.dtype)
+ router_weights = torch.zeros_like(router_logits).scatter_(
+ 1, router_indices, routing_weights
+ )
+
+ # Reshape hidden_states for moe.experts (expects 3D: batch, seq, hidden)
+ # router_weights and router_indices remain 2D (num_tokens, num_experts)
+ batch_size = orig_shape[0] if hidden_states.ndim == 3 else 1
+ hidden_states_3d = hidden_states_flat.reshape(batch_size, -1, self.moe.hidden_size)
+
+ routed_out = self.moe.experts(hidden_states_3d, router_weights, router_indices)
+
+ # Return in same shape as input
+ routed_out = routed_out.reshape(*orig_shape)
+
+ return router_logits, routed_out
diff --git a/modelopt/torch/puzzletron/README.md b/modelopt/torch/puzzletron/README.md
new file mode 100644
index 0000000000..4c6da80e54
--- /dev/null
+++ b/modelopt/torch/puzzletron/README.md
@@ -0,0 +1,3 @@
+Experimental model compression algorithm based on a Local Neural Architecture Search.
+Based on the Puzzle paper:
+PoC for Llama 3.1 model.
diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py
new file mode 100644
index 0000000000..47f1c65a15
--- /dev/null
+++ b/modelopt/torch/puzzletron/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py
new file mode 100644
index 0000000000..47f1c65a15
--- /dev/null
+++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py
new file mode 100644
index 0000000000..ccf73f7612
--- /dev/null
+++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Provides a function to register activation hooks for a model.
+Activation hooks are used to compute activation scores for pruning."""
+
+from typing import Type
+
+import torch
+
+from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook as ActivationsHook
+from modelopt.torch.puzzletron.tools.logger import aprint
+from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock, DummyModule
+
+
+def register_activation_hooks(
+ model,
+ activation_hooks_kwargs: dict,
+ pruning_mixin,
+ hook_class: Type[ActivationsHook],
+) -> dict[str, ActivationsHook]:
+ """Register activation hooks using the pruning mixin approach.
+
+ Args:
+ model: The model to register hooks on.
+ activation_hooks_kwargs: Keyword arguments passed to hook constructors.
+ pruning_mixin: The pruning mixin that defines which modules to hook.
+ hook_class: The hook class to instantiate for each module.
+
+ Returns:
+ Dictionary mapping module names to hook instances.
+ """
+ activation_hooks_kwargs["model"] = model
+
+ if hook_class not in pruning_mixin.supported_hooks():
+ raise ValueError(
+ f"Hook class not supported for {pruning_mixin.__class__.__name__}, "
+ f"must be in {pruning_mixin.supported_hooks()}"
+ )
+
+ module_names_to_hook = pruning_mixin.get_module_names_to_hook(model)
+ activation_hooks = dict()
+ for block_idx, module_name in module_names_to_hook:
+ try:
+ module = model.get_submodule(module_name)
+ except AttributeError:
+ # Module doesn't exist on this rank's shard (e.g., in distributed setup)
+ continue
+
+ # Skip dummy modules - they don't have real activations to hook
+ if isinstance(module, (DummyModule, DummyBlock)):
+ continue
+
+ block_config = None
+ if block_idx is not None:
+ block_config = model.config.block_configs[block_idx]
+ curr_activation_hooks_kwargs = {
+ **activation_hooks_kwargs,
+ "block_config": block_config,
+ }
+
+ hook = hook_class(module, curr_activation_hooks_kwargs)
+ module.register_forward_hook(hook)
+ activation_hooks[module_name] = hook
+
+ if len(activation_hooks) == 0:
+ # In distributed mode, it's okay for a rank to have 0 hooks if it doesn't own
+ # the target modules (e.g., with hybrid patterns like "*-" where different
+ # ranks own different layer types). However, we still want to catch real bugs
+ # where no hooks are found at all.
+ is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
+ if is_distributed:
+ aprint(
+ "No hooks registered on this rank. This is expected if this rank "
+ "doesn't own any layers matching the hook pattern (e.g., in hybrid "
+ "patterns with distributed model sharding)."
+ )
+ else:
+ raise ValueError("couldn't find any hooks")
+
+ if len(activation_hooks) > 0:
+ aprint(f"Found the following hooks: {activation_hooks.keys()}")
+ return activation_hooks
diff --git a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py
new file mode 100644
index 0000000000..c043c20d5f
--- /dev/null
+++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py
@@ -0,0 +1,141 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import torch
+from omegaconf import DictConfig
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.validate_model import validate_model
+
+
+def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
+ """Determine if the activation hook method has proper checkpoint support implemented.
+
+ Args:
+ activation_hooks_kwargs: Hook configuration
+
+ Returns:
+ bool: True if the hook method has save_state/load_state implemented
+ """
+ method = activation_hooks_kwargs.get("method", "")
+
+ # Methods with implemented checkpoint support
+ supported_methods = {
+ "iterative", # IterativeChannelContributionHook: save_state/load_state implemented
+ "independent", # IndependentChannelContributionHook: save_state/load_state implemented
+ "stats", # RouterStatsHook: save_state/load_state implemented
+ "ranked_choice_voting", # RankedChoiceVotingHook: save_state/load_state implemented
+ }
+
+ return method in supported_methods
+
+
+def check_scoring_completion(activations_log_dir: str, activation_hooks_kwargs=None) -> bool:
+ """Check if scoring is already completed by looking for the expected output files.
+ Also checks if the scoring method is safe for resume.
+
+ Args:
+ activations_log_dir: Directory where activation logs should be stored
+ activation_hooks_kwargs: Hook configuration to check if resume is safe
+
+ Returns:
+ bool: True if scoring is completed (has rank files and args.json)
+ """
+ # Only check completion on main process
+ if dist.is_master():
+ log_dir = Path(activations_log_dir)
+
+ # Check if directory exists
+ if not log_dir.exists():
+ return False
+
+ # Check for rank files (at least rank_0.pth should exist)
+ rank_files = list(log_dir.glob("rank_*.pth"))
+
+ if not rank_files:
+ return False
+
+ # Check for args.json (created by main process)
+ args_file = log_dir / "args.json"
+ has_args_json = args_file.exists()
+
+ # Check for completion: if we have rank files and args.json, scoring is complete
+ if rank_files and has_args_json:
+ # Add optional completion info for debugging
+ mprint(f"Found completed scoring in {activations_log_dir}")
+ mprint(f" - Found {len(rank_files)} rank files")
+ mprint(f" - Found args.json: {has_args_json}")
+
+ return True
+
+ return False
+
+
+def should_skip_scoring_completely(cfg: DictConfig) -> bool:
+ """Determine if we should skip scoring entirely (only if 100% complete).
+ Partial progress should proceed to validate_model for proper resume.
+
+ Args:
+ cfg: Configuration object
+
+ Returns:
+ bool: True if we should skip scoring (100% completed), False if we should run/resume it
+ """
+ # Check if activations_log_dir is specified
+ if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None:
+ mprint("No activations_log_dir specified, running scoring")
+ return False
+
+ # Check for force restart flag
+ force_restart = getattr(cfg.pruning, "force_restart_scoring", False)
+ if force_restart:
+ mprint("Force restart flag set, will restart scoring regardless of existing artifacts")
+ return False
+
+ # Get hook configuration to check if resume is mathematically safe
+ activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {})
+
+ # Check if scoring is already completed
+ is_completed = check_scoring_completion(
+ cfg.pruning.activations_log_dir, activation_hooks_kwargs
+ )
+
+ # Broadcast the result to all processes in distributed mode
+ if dist.size() > 1:
+ should_skip = [is_completed] # Use list for mutable object
+ torch.distributed.broadcast_object_list(should_skip, src=0)
+ is_completed = should_skip[0]
+
+ if is_completed:
+ mprint("Scoring 100% completed, skipping...")
+
+ return is_completed
+
+
+# Old progress tracking removed - checkpoint manager handles all progress tracking
+
+
+def launch_score_activations(cfg: DictConfig):
+ # Check if we should skip scoring entirely (only if 100% complete)
+ if should_skip_scoring_completely(cfg):
+ return
+
+ mprint("Starting pruning activation scoring...")
+
+ # The checkpoint manager inside validate_model handles all progress tracking
+ validate_model(args=cfg.pruning)
diff --git a/modelopt/torch/puzzletron/anymodel/README.md b/modelopt/torch/puzzletron/anymodel/README.md
new file mode 100644
index 0000000000..291966eb7b
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/README.md
@@ -0,0 +1,204 @@
+# AnyModel Guide
+
+This guide explains how to add support for new models in the Puzzletron pipeline.
+
+## Convert model
+
+Convert a HuggingFace model to Puzzletron format.
+
+Step 1: Create Model Descriptor
+
+Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention).
+
+Key points:
+
+- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json))
+
+See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)
+
+Step 2: Create Converter
+
+Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config.
+
+Key points:
+
+- Import correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models//configuration_.py`
+
+See example: [llama_converter.py](models/llama/llama_converter.py)
+
+Step 3: Create `models//__init__.py`
+
+Export descriptor and converter classes:
+
+```python
+from models.._model_descriptor import MyModelDescriptor
+from models.._converter import MyConverter
+```
+
+Step 4: Register in `models/__init__.py`
+
+Add import to trigger factory registration:
+
+```python
+from models. import *
+```
+
+## Usage
+
+```python
+from modelopt.torch.puzzletron.anymodel import convert_model
+
+convert_model(
+ input_dir="path/to/hf_checkpoint",
+ output_dir="path/to/puzzletron_checkpoint",
+ converter="model_name",
+)
+```
+
+## Compress model
+
+Run pruning and compression on a Puzzletron model.
+
+Step 1: Implement ModelDescriptor methods for compression
+
+Add to your `ModelDescriptor`:
+
+- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support
+- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides))
+- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding))
+- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods))
+- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods))
+- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods))
+- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods))
+- `attn_no_op_post_init()` - replace attention sublayers with no-op modules
+- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules
+
+Step 2: Create FFN Layer Descriptor
+
+Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`.
+
+See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor`
+
+Step 3: Configure YAML files
+
+Update the main model config YAML:
+
+- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")`
+- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml)
+
+Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.):
+
+- Set `pruning_mixin._target_` to the appropriate mixin class
+- Set `layer_descriptor._target_` to your layer descriptor class
+- Set `hook_class` to the activation hook for scoring
+- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment
+- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/)
+
+## End-to-end example
+
+See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps.
+
+---
+
+## Advanced Topics
+
+## Pruning Configuration
+
+### Pruning YAML Structure
+
+Each pruning type has a YAML config with these key fields:
+
+```yaml
+pruning_mixin:
+ _target_: pruning._pruning_mixin.
+ layer_descriptor:
+ _target_: models..
+
+hook_class: ${get_object:utils.activation_hooks.hooks.}
+activation_hooks_kwargs:
+ method:
+ target_layer: "" # e.g., "mlp.down_proj", "self_attn.o_proj"
+```
+
+| Field | Description |
+|-------|-------------|
+| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type |
+| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks |
+| `hook_class` | Activation hook class for importance scoring |
+| `target_layer` | Layer name (relative to decoder block) where hooks attach |
+
+### Adding a New Hook Class
+
+1. **Implement the hook** under `modelopt/torch/prune/importance_hooks/` (e.g. `base_hooks.py` for generic hooks, `expert_removal_hooks.py` for MoE expert removal):
+ - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook` in `expert_removal_hooks.py`)
+ - Implement required methods (e.g., `get_router_logits_and_routed_experts`)
+
+2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`:
+
+ For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`):
+
+ ```python
+ def supported_hooks(self) -> List[Type[ActivationsHook]]:
+ return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook]
+ ```
+
+ For expert removal (`pruning/expert_removal_pruning_mixin.py`):
+
+ ```python
+ def supported_hooks(self) -> List[Type[ActivationsHook]]:
+ return [RankedChoiceVotingHook, ..., YourNewHook]
+ ```
+
+3. **Reference in YAML**:
+
+ ```yaml
+ hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook}
+ ```
+
+### Pruning Types Reference
+
+| Type | Mixin | Example Hooks |
+|------|-------|---------------|
+| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../prune/importance_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../prune/importance_hooks/base_hooks.py) |
+| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py) |
+| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../prune/importance_hooks/base_hooks.py) |
+
+## Implementing `block_config_to_layer_overrides`
+
+Maps Puzzletron's [`BlockConfig`](../decilm/deci_lm_hf_code/block_config.py) fields to HuggingFace config attribute names. Only override attributes that change during pruning:
+
+| BlockConfig Field | HuggingFace Attribute (check `config.json`) |
+|-------------------|---------------------------------------------|
+| `attention.num_key_value_heads` | `num_key_value_heads` |
+| `ffn.intermediate_size` | `intermediate_size` |
+| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) |
+| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` |
+
+**Tip**: Check the model's `config.json` for exact attribute names - they vary between models.
+
+See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py)
+
+---
+
+## Implementing path-based methods
+
+These methods return paths derived from the model's weight names:
+
+- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()`
+
+Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)).
+
+See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)
+
+---
+
+## Implementing `init_rotary_embedding`
+
+Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype.
+
+Look in `github.com/huggingface/transformers/tree/main/src/transformers/models//modeling_.py` for:
+
+- `class.*Rotary` — the rotary embedding class name and constructor arguments
+- `self.rotary_emb` — the attribute path
+
+See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)
diff --git a/modelopt/torch/puzzletron/anymodel/__init__.py b/modelopt/torch/puzzletron/anymodel/__init__.py
new file mode 100644
index 0000000000..e1755a16d8
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/__init__.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""AnyModel: Architecture-agnostic model compression for HuggingFace models.
+
+This module provides a declarative approach to model compression that works with
+any HuggingFace model without requiring custom modeling code. Instead of duplicating
+HuggingFace modeling classes, AnyModel uses ModelDescriptors that define:
+
+1. Which decoder layer class(es) to patch for heterogeneous configs
+2. How to map BlockConfig to layer-specific overrides
+3. Weight name patterns for subblock checkpointing
+
+Example usage:
+ >>> from modelopt.torch.puzzletron.anymodel import convert_model
+ >>> convert_model(
+ ... input_dir="path/to/hf_checkpoint",
+ ... output_dir="path/to/anymodel_checkpoint",
+ ... converter="llama",
+ ... )
+
+Supported models:
+ - llama: Llama 2, Llama 3, Llama 3.1, Llama 3.2
+ - (more to come: qwen2, mistral_small, etc.)
+"""
+
+# Import models to trigger factory registration
+from modelopt.torch.puzzletron.anymodel import models # noqa: F401
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory, convert_model
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer import (
+ MatchingZeros,
+ Same,
+ deci_x_patcher,
+ return_tuple_of_size,
+)
+
+__all__ = [
+ "Converter",
+ "ConverterFactory",
+ "ModelDescriptor",
+ "ModelDescriptorFactory",
+ "deci_x_patcher",
+ "MatchingZeros",
+ "Same",
+ "return_tuple_of_size",
+ "convert_model",
+]
diff --git a/modelopt/torch/puzzletron/anymodel/converter/__init__.py b/modelopt/torch/puzzletron/anymodel/converter/__init__.py
new file mode 100644
index 0000000000..02903b817d
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/converter/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Converters for transforming HuggingFace models to AnyModel format."""
+
+from .convert_any_model import *
+from .converter import *
+from .converter_factory import *
diff --git a/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py
new file mode 100644
index 0000000000..889685c001
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Convert a HuggingFace model to AnyModel format."""
+
+from pathlib import Path
+
+from modelopt.torch.puzzletron.anymodel.converter.converter import Converter
+from modelopt.torch.puzzletron.anymodel.converter.converter_factory import ConverterFactory
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
+
+__all__ = ["convert_model"]
+
+
+def convert_model(
+ input_dir: str,
+ output_dir: str,
+ converter: Converter | str,
+):
+ """Convert a HuggingFace model to AnyModel format.
+
+ This function converts a HuggingFace checkpoint to the AnyModel format used
+ for compression. The conversion process:
+
+ 1. Copies non-weight files (config, tokenizer, etc.)
+ 2. Creates block_configs for each layer
+ 3. Reorganizes weights into subblock checkpoints
+
+ Args:
+ input_dir: Path to the input HuggingFace checkpoint directory.
+ output_dir: Path to the output AnyModel checkpoint directory.
+ converter: Either a converter name (e.g., "llama") or a Converter class.
+
+ Example:
+ >>> convert_model(
+ ... input_dir="/path/to/Llama-3.1-8B-Instruct",
+ ... output_dir="/path/to/output/ckpts/teacher",
+ ... converter="llama",
+ ... )
+ """
+ input_dir = Path(input_dir)
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Get descriptor and converter from factories (they use the same name)
+ descriptor = ModelDescriptorFactory.get(converter)
+ converter = ConverterFactory.get(converter)
+
+ converter.convert(descriptor=descriptor, input_dir=input_dir, output_dir=output_dir)
+
+
+if __name__ == "__main__":
+ from fire import Fire
+
+ Fire(convert_model)
diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py
new file mode 100644
index 0000000000..eb2330b515
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py
@@ -0,0 +1,239 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import copy
+import fnmatch
+import json
+import os
+import shutil
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List
+
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from transformers import PretrainedConfig
+from transformers.integrations.mxfp4 import convert_moe_packed_tensors
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config, save_model_config
+
+__all__ = ["Converter"]
+
+
+class Converter(ABC):
+ """Base class for converting HuggingFace models to Puzzletron/AnyModel format."""
+
+ @staticmethod
+ def _get_weight_map(input_dir: Path) -> Dict[str, str]:
+ """Load weight map from checkpoint directory (supports both sharded and single-file models).
+
+ Returns a dict mapping parameter names to their safetensors filenames.
+ """
+ index_path = input_dir / "model.safetensors.index.json"
+ single_file_path = input_dir / "model.safetensors"
+
+ if index_path.exists():
+ # Sharded model
+ with open(index_path, "r") as f:
+ index = json.load(f)
+ return index["weight_map"]
+ elif single_file_path.exists():
+ # Single file model - create a synthetic weight map
+ data = load_file(single_file_path)
+ return {name: "model.safetensors" for name in data.keys()}
+ else:
+ raise FileNotFoundError(
+ f"Neither {index_path} nor {single_file_path} found. Cannot determine model format."
+ )
+
+ @classmethod
+ def convert_model_weights(
+ cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int
+ ):
+ """Convert model weights to subblock format."""
+ param_to_file = Converter._get_weight_map(input_dir)
+ all_param_names = list(param_to_file.keys())
+
+ # Reverse map: file -> set of params
+ file_to_params = defaultdict(set)
+ for name, file in param_to_file.items():
+ file_to_params[file].add(name)
+
+ # Determine subblocks needed
+ subblocks = descriptor.get_weight_groups(
+ all_param_names, num_hidden_layers=num_hidden_layers
+ )
+
+ # Output directory
+ out_dir = output_dir / "subblocks_safetensors"
+ os.makedirs(out_dir, exist_ok=True)
+
+ # New weight index
+ new_index = {"metadata": {"format": "pt"}, "weight_map": {}}
+
+ for subblock, param_names in tqdm(subblocks.items(), desc="Processing subblocks"):
+ param_files = set(param_to_file[name] for name in param_names)
+ tensors = {}
+
+ # Load only needed files for this subblock
+ for file in param_files:
+ data = load_file(os.path.join(input_dir, file))
+ for name in param_names:
+ if param_to_file[name] == file and name in data:
+ converted_name = cls.convert_weight_name(name)
+ # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b
+ if getattr(cls, "quantized", None) == "mxfp4":
+ if name.endswith("_blocks"):
+ converted_name = converted_name.replace("_blocks", "")
+ tensors[converted_name] = convert_moe_packed_tensors(
+ data[converted_name + "_blocks"],
+ data[converted_name + "_scales"],
+ )
+ elif name.endswith("_scales"):
+ continue
+ else:
+ tensors[converted_name] = data[name]
+ else:
+ tensors[converted_name] = data[name]
+
+ # Save this subblock
+ print(f"\n✅ Group: {subblock} ({len(tensors)} layers)")
+ for layer in tensors.keys():
+ print(f" - {layer}")
+
+ subblock_file = f"{subblock}.safetensors"
+ save_file(tensors, os.path.join(out_dir, subblock_file))
+
+ # Update index
+ for new_name in tensors.keys():
+ new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}"
+
+ # Save new index file
+ with (output_dir / "model.safetensors.index.json").open("w") as f:
+ json.dump(new_index, f, indent=2)
+
+ print(f"✅ Finished saving subblocks and index to {output_dir}")
+
+ @classmethod
+ def convert_configs_in_dirs(
+ cls,
+ input_dir: Path,
+ output_dir: Path,
+ trust_remote_code: bool = False,
+ ):
+ """Convert config and add block_configs."""
+ config = load_model_config(input_dir, trust_remote_code=trust_remote_code)
+
+ block_configs = cls.create_block_configs_from_main_config(config)
+ out_config = copy.deepcopy(config)
+ out_config.block_configs = block_configs
+
+ save_model_config(out_config, output_dir)
+ return out_config
+
+ @staticmethod
+ def copy_checkpoint_files(input_dir: Path, output_dir: Path):
+ """Copy checkpoint files except model weights (which will be converted)."""
+ ignore_patterns = [
+ "model-*.safetensors",
+ "model.safetensors",
+ "model.safetensors.index.json",
+ "subblocks_safetensors",
+ ]
+
+ def ignore_func(dir, files):
+ ignored = set()
+ for pattern in ignore_patterns:
+ ignored.update(fnmatch.filter(files, pattern))
+ return ignored
+
+ shutil.copytree(str(input_dir), str(output_dir), ignore=ignore_func, dirs_exist_ok=True)
+
+ @classmethod
+ def convert(
+ cls,
+ descriptor: ModelDescriptor,
+ input_dir: Path,
+ output_dir: Path,
+ ):
+ """Convert a HuggingFace model to AnyModel format.
+
+ Args:
+ descriptor: Model descriptor for the model type.
+ input_dir: Path to the input HuggingFace checkpoint.
+ output_dir: Path to the output AnyModel checkpoint.
+ """
+ cls.copy_checkpoint_files(input_dir, output_dir)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ config = cls.convert_configs_in_dirs(
+ input_dir, output_dir, trust_remote_code=trust_remote_code
+ )
+ cls.convert_model_weights(
+ input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers
+ )
+
+ @staticmethod
+ @abstractmethod
+ def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]:
+ """Create per-layer BlockConfig list from a HuggingFace model config.
+
+ This method extracts layer-specific parameters (e.g., intermediate_size,
+ num_key_value_heads) from the main model config and creates a BlockConfig
+ for each layer. These BlockConfigs enable layer-specific pruning and
+ modifications during the compression pipeline.
+
+ Args:
+ config: HuggingFace PretrainedConfig (e.g., LlamaConfig, Qwen2Config)
+
+ Returns:
+ List of BlockConfig, one per hidden layer. Each BlockConfig contains:
+ - AttentionConfig: attention settings (no_op, num_key_value_heads)
+ - FFNConfig: FFN settings (no_op, intermediate_size)
+
+ Example:
+ For a model with uniform layers (e.g., Llama):
+ return [BlockConfig(...)] * config.num_hidden_layers
+
+ For a model with heterogeneous layers (e.g., NemotronH with Mamba/Attention):
+ return [BlockConfig(...) for layer_idx in range(num_layers)]
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def convert_weight_name(name: str) -> str:
+ """
+ Convert weight names during checkpoint conversion.
+
+ This method can be overridden by subclasses to apply model-specific weight name
+ transformations when converting checkpoints from HuggingFace format to Puzzletron format.
+
+ Default implementation returns the name unchanged (identity function).
+
+ Args:
+ name: Original weight name from HuggingFace checkpoint
+
+ Returns:
+ Converted weight name for Puzzletron format
+
+ Example:
+ For Qwen2.5-VL, this converts:
+ - visual.* → model.visual.*
+ - model.* → model.language_model.*
+ """
+ return name
diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py
new file mode 100644
index 0000000000..88d490d653
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/converter/converter_factory.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import inspect
+from typing import Callable, Type
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+
+__all__ = ["ConverterFactory"]
+
+
+class ConverterFactory:
+ """Factory for registering and retrieving Converter classes."""
+
+ CLASS_MAPPING = {}
+
+ @classmethod
+ def register(cls, **entries: Type):
+ """Register converter classes.
+
+ Raises:
+ KeyError: if entry key is already in type_dict and points to a different class.
+ """
+ for cls_name, cls_type in entries.items():
+ if cls_name in cls.CLASS_MAPPING:
+ ref = cls.CLASS_MAPPING[cls_name]
+ # If ref and cls_name point to the same class ignore and don't raise an exception.
+ if cls_type == ref:
+ continue
+ raise KeyError(
+ f"Could not register `{cls_name}`: {cls_type}, "
+ f"`{cls_name}` is already registered and points to "
+ f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`"
+ )
+ cls.CLASS_MAPPING[cls_name] = cls_type
+
+ @classmethod
+ def register_decorator(cls, name: str | None) -> Callable:
+ """Set up a register decorator.
+
+ Args:
+ name: If specified, the decorated object will be registered with this name.
+
+ Returns:
+ Decorator that registers the callable.
+ """
+
+ def decorator(cls_type: Type) -> Callable:
+ """Register the decorated callable."""
+ cls_name = name if name is not None else cls_type.__name__
+ cls.register(**{cls_name: cls_type})
+ return cls_type
+
+ return decorator
+
+ @classmethod
+ def get(cls, value: str | ModelDescriptor):
+ """Get a registered converter by name or return the converter if already resolved."""
+ if isinstance(value, str):
+ if value in cls.CLASS_MAPPING:
+ return cls.CLASS_MAPPING[value]
+ return value
diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py
new file mode 100644
index 0000000000..cc8e89e34b
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Model descriptors for defining model-specific properties and layer naming conventions."""
+
+from .model_descriptor import *
+from .model_descriptor_factory import *
diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
new file mode 100644
index 0000000000..4cc4356c8e
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
@@ -0,0 +1,228 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Type
+
+import torch.nn as nn
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock
+
+__all__ = ["ModelDescriptor"]
+
+
+class ModelDescriptor(ABC):
+ @staticmethod
+ @abstractmethod
+ def decoder_layer_cls() -> Type[nn.Module] | List[Type[nn.Module]]:
+ """Decoder layer class types to patch for heterogeneous config support.
+
+ In most cases this class will hold as attributes both FFN & attention layers.
+
+ Returns:
+ nn.Module class type or a list if several class types should be patched.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]:
+ """Map between BlockConfig and layer config overrides.
+
+ These overrides are consumed by a specific decoder layer and by the whole model.
+ Usage can be seen in `deci_x_patcher` under the method `_patched_decoder_layer_init`.
+
+ Example implementation to override the FFN intermediate size of a block:
+ >>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]:
+ >>> return {"intermediate_size": block_config.ffn.intermediate_size}
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def requires_trust_remote_code() -> bool:
+ """Whether this model descriptor requires trust_remote_code=True for loading.
+
+ Models that use custom code (e.g., via auto_map in config) should override
+ this to return True.
+
+ Returns:
+ True if trust_remote_code=True is required, False otherwise.
+ """
+ return False
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: nn.Module):
+ """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op.
+
+ It is recommended to use the utils modules from `no_op.py` to replace layers to dummy
+ counterparts.
+
+ Example for replacing a layernorm layer with identity:
+
+ >>> decoder_layer.post_attention_layernorm = Same()
+
+ Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to
+ the residuals hidden_states so a no-op implementation will leave residual the same):
+
+ >>> decoder_layer.mlp = MatchingZeros()
+
+ In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`,
+ use the util method `return_tuple_of_size` to return trailing None values:
+
+ >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)()
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: nn.Module):
+ """Post-init callback to alter a decoder layer so that Attention subblock performs as no-op.
+
+ It is recommended to use the utils modules from `no_op.py` to replace layers to dummy
+ counterparts.
+
+ Example for replacing a layernorm layer with identity:
+
+ >>> decoder_layer.post_attention_layernorm = Same()
+
+ Example for replacing an attention layer with zeroes:
+
+ >>> decoder_layer.self_attn = MatchingZeros()
+
+ In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`,
+ use the util method `return_tuple_of_size` to return trailing None values:
+
+ >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def init_rotary_embedding(model, runtime):
+ """Re-initiate the rotary embeddings based on an existing model.
+
+ In puzzletron we initiate a sharded model by first creating a meta model then replacing
+ to the actual device by loading the state_dict with the real weights.
+
+ Rotary embeddings frequencies are tensor buffers that are created dynamically during init
+ and are not part of the model state_dict, so cannot be restored after a meta device
+ initialization.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def input_embedding_name():
+ """Return the name of the input embedding layer."""
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def output_embedding_name():
+ """Return the name of the output embedding layer."""
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def final_norm_name():
+ """Return the name of the final normalization layer."""
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def layer_block_name(index: int):
+ """Return the name of the decoder layer at the given index."""
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ """Return predicates for grouping model weights to support subblock checkpointing.
+
+ For every group name return a regex predicate whether a layer name is part of the group.
+
+ Returns:
+ Dictionary of group name to regex pattern predicate.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def uses_autocast() -> bool:
+ """Whether this model supports torch.autocast.
+
+ Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast.
+ Override and return False for models that do not support autocast.
+ """
+ return True
+
+ @staticmethod
+ def get_language_model_config(config):
+ """Get the language model config from a PretrainedConfig.
+
+ For regular LM models, returns the config itself.
+ For VL/multimodal models with nested configs, override to return the
+ language model portion (e.g., config.text_config for Qwen-VL).
+ """
+ return config
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
+ """Create a dummy block to replace a layer for sharded model initialization."""
+ return DummyBlock(block_index=block_index)
+
+ @classmethod
+ def mlp_no_op_supported(cls) -> bool:
+ """Check whether `mlp_no_op_post_init` is overridden for mlp no-op support."""
+ method_name = ModelDescriptor.mlp_no_op_post_init.__name__
+ return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name)
+
+ @classmethod
+ def attn_no_op_supported(cls):
+ """Check whether `attn_no_op_post_init` is overridden for attention no-op support."""
+ method_name = ModelDescriptor.attn_no_op_post_init.__name__
+ return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name)
+
+ @classmethod
+ def get_weight_groups(
+ cls, layer_names: Iterable[str], num_hidden_layers: int
+ ) -> Dict[str, List[str]]:
+ """Group model weights to support the puzzle subblock checkpointing format.
+
+ This method uses the abstract method `layer_name_predicates` by default.
+
+ Args:
+ layer_names: state_dict layer names of the model.
+ num_hidden_layers: number of decoder layers in the model.
+
+ Returns:
+ Dictionary of group names to list of layer names per group, e.g.:
+ >>> {
+ ... "embedding": ["model.embed_tokens.weight"],
+ ... "lm_head": ["lm_head.weight", "model.norm.weight"],
+ ... "block_0_ffn": ["model.layers.0.mlp.down_proj", ...],
+ ... "block_0_attention": ["model.layers.0.self_attn.q_proj", ...],
+ ... }
+ """
+ weight_groups = defaultdict(list)
+ for name in layer_names:
+ for group, pattern in cls.layer_name_predicates(num_hidden_layers).items():
+ if pattern.match(name):
+ weight_groups[group].append(name)
+ break
+ else:
+ raise ValueError(f"Couldn't find a match for {name}")
+ return weight_groups
diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py
new file mode 100644
index 0000000000..badbe2b0e3
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import inspect
+from typing import Callable, Type
+
+from transformers import AutoConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor
+
+__all__ = ["ModelDescriptorFactory"]
+
+# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name.
+# Local to this script; add entries when supporting new model types for auto-detection.
+_MODEL_TYPE_TO_DESCRIPTOR = {
+ "llama": "llama",
+ "mistral": "mistral_small",
+ "qwen2": "qwen2",
+ "qwen3": "qwen3",
+ "nemotron_h": "nemotron_h",
+ "nemotron_h_v2": "nemotron_h_v2",
+ "gpt_oss_20b": "gpt_oss_20b",
+}
+
+
+def resolve_descriptor_from_pretrained(pretrained: str, trust_remote_code: bool = False):
+ """Resolve the model descriptor by loading the checkpoint config and mapping model_type.
+
+ Args:
+ pretrained: Path to a pretrained model checkpoint or HuggingFace model identifier.
+ trust_remote_code: If True, allows execution of custom code from the model repository.
+ This is a security risk if the model source is untrusted. Only set to True if you
+ trust the source of the model. Defaults to False for security.
+
+ Returns:
+ The resolved ModelDescriptor class for the detected model type.
+
+ Raises:
+ ValueError: If pretrained is not provided or if the model type cannot be auto-detected.
+ """
+
+ config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
+ model_type = getattr(config, "model_type", None)
+
+ if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR:
+ detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type]
+ print(
+ f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'"
+ )
+ return ModelDescriptorFactory.get(detected)
+
+ known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys())
+ raise ValueError(
+ f"Cannot auto-detect descriptor for model_type='{model_type}'. "
+ f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported."
+ )
+
+
+class ModelDescriptorFactory:
+ """Factory for registering and retrieving ModelDescriptor classes."""
+
+ CLASS_MAPPING = {}
+
+ @classmethod
+ def register(cls, **entries: Type):
+ """Register model descriptor classes.
+
+ Raises:
+ KeyError: if entry key is already in type_dict and points to a different class.
+ """
+ for cls_name, cls_type in entries.items():
+ if cls_name in cls.CLASS_MAPPING:
+ ref = cls.CLASS_MAPPING[cls_name]
+ # If ref and cls_name point to the same class ignore and don't raise an exception.
+ if cls_type == ref:
+ continue
+ raise KeyError(
+ f"Could not register `{cls_name}`: {cls_type}, "
+ f"`{cls_name}` is already registered and points to "
+ f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`"
+ )
+ cls.CLASS_MAPPING[cls_name] = cls_type
+
+ @classmethod
+ def register_decorator(cls, name: str | None) -> Callable:
+ """Set up a register decorator.
+
+ Args:
+ name: If specified, the decorated object will be registered with this name.
+
+ Returns:
+ Decorator that registers the callable.
+ """
+
+ def decorator(cls_type: Type) -> Callable:
+ """Register the decorated callable."""
+ cls_name = name if name is not None else cls_type.__name__
+ cls.register(**{cls_name: cls_type})
+ return cls_type
+
+ return decorator
+
+ @classmethod
+ def get(cls, value: str | ModelDescriptor):
+ """Get a registered model descriptor by name or return the descriptor if already resolved."""
+ if isinstance(value, str):
+ if value in cls.CLASS_MAPPING:
+ return cls.CLASS_MAPPING[value]
+ return value
diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py
new file mode 100644
index 0000000000..4c68dbc823
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Import models to trigger factory registration
+from modelopt.torch.puzzletron.anymodel.models.gpt_oss import *
+from modelopt.torch.puzzletron.anymodel.models.llama import *
+from modelopt.torch.puzzletron.anymodel.models.mistral_small import *
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h import *
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import *
+from modelopt.torch.puzzletron.anymodel.models.qwen2 import *
+from modelopt.torch.puzzletron.anymodel.models.qwen3 import *
+from modelopt.torch.puzzletron.anymodel.models.qwen3_vl import *
diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py
new file mode 100644
index 0000000000..9f72b8dd78
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""GPT-OSS model support for AnyModel."""
+
+from .gpt_oss_converter import GptOssConverter
+from .gpt_oss_model_descriptor import GptOssModelDescriptor
diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py
new file mode 100644
index 0000000000..3e7371aaee
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_converter.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""GPT-OSS-20B converter for AnyModel compression."""
+
+from typing import List
+
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ MoEConfig,
+)
+
+
+@ConverterFactory.register_decorator("gpt_oss")
+class GptOssConverter(Converter):
+ """Converter for GPT-OSS models to AnyModel format.
+
+ GPT-OSS is a pure MoE model with 32/128 experts per layer and 4/16 active experts.
+ All layers use MoE FFN (no standard dense FFN layers).
+ """
+
+ quantized = "mxfp4"
+
+ @staticmethod
+ def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]:
+ """Create block configs for GPT-OSS layers.
+
+ GPT-OSS uses MoE for all FFN layers with:
+ - 32/128 local experts (num_local_experts)
+ - 4/16 active experts per token (experts_per_token)
+ - No dense/standard FFN layers
+ """
+ num_hidden_layers = config.num_hidden_layers
+ num_local_experts = config.num_local_experts
+ experts_per_token = config.experts_per_token
+ intermediate_size = config.intermediate_size
+
+ block_configs = []
+ for layer_idx in range(num_hidden_layers):
+ block_config = BlockConfig(
+ attention=AttentionConfig(
+ no_op=False, num_key_value_heads=config.num_key_value_heads
+ ),
+ ffn=FFNConfig(
+ no_op=False,
+ intermediate_size=None, # MoE doesn't use this field
+ moe=MoEConfig(
+ num_local_experts=num_local_experts,
+ num_experts_per_tok=experts_per_token,
+ expert_intermediate_dim=intermediate_size,
+ ),
+ ),
+ ).to_dict()
+ block_configs.append(block_config)
+
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py
new file mode 100644
index 0000000000..c77a4547f0
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py
@@ -0,0 +1,236 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""GPT-OSS model descriptor for AnyModel compression."""
+
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Type
+
+import torch.nn as nn
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssRotaryEmbedding
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import (
+ ExpertRemovalLayerDescriptor,
+ ExpertRemovalPruningMixIn,
+)
+
+# Expert removal is supported for unquantized models (test models).
+# Production models use MXFP4 quantized MoE with combined tensors
+# (gate_up_proj_blocks, down_proj_blocks), which is not yet supported.
+from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn
+from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock
+
+
+@ModelDescriptorFactory.register_decorator("gpt_oss")
+class GptOssModelDescriptor(ModelDescriptor):
+ """Model descriptor for GPT-OSS (pure MoE model)."""
+
+ _DECODER_LAYER_CLS: Type[nn.Module] = None
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module:
+ dummy_block = DummyBlock(block_index=block_index)
+ # Required by `GptOssModel.forward`.
+ dummy_block.attention_type = original_layer.attention_type
+ return dummy_block
+
+ @staticmethod
+ def decoder_layer_cls():
+ """Get the decoder layer class for GPT-OSS models.
+
+ GPT-OSS is a standard transformers model in recent versions.
+ Import directly from transformers.models.gpt_oss.modeling_gpt_oss.
+ """
+ return GptOssDecoderLayer
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ """Map BlockConfig to layer constructor overrides."""
+ override_kwargs = {}
+
+ if block_config.attention.num_key_value_heads is not None:
+ override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads
+
+ if block_config.ffn.moe is not None:
+ override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
+ override_kwargs["num_local_experts"] = block_config.ffn.moe.num_local_experts
+ override_kwargs["num_experts_per_tok"] = block_config.ffn.moe.num_experts_per_tok
+
+ return override_kwargs
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer):
+ """Replace attention sublayers with no-op modules."""
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer):
+ """Replace MLP sublayers with no-op modules.
+
+ Note: GPT-OSS MoE layers return (hidden_states, router_scores), so we need
+ to return a tuple of 2 values.
+ """
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def init_rotary_embedding(model, runtime):
+ """Initialize rotary embeddings on the correct device."""
+ # GPT-OSS uses RoPE with YARN scaling
+
+ model.model.rotary_emb = GptOssRotaryEmbedding(
+ config=model.config,
+ device=runtime.device,
+ )
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ """Define regex patterns for grouping weights into subblocks."""
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ """FFN is MoE in GPT-OSS with MXFP4 quantization."""
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.layers\.{layer_idx}\."
+ r"(post_attention_layernorm\.weight"
+ r"|mlp\.router\.weight"
+ r"|mlp\.router\.bias"
+ r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.layers\.{layer_idx}\."
+ r"(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.q_proj\.bias"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.k_proj\.bias"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.v_proj\.bias"
+ r"|self_attn\.o_proj\.weight"
+ r"|self_attn\.o_proj\.bias"
+ r"|self_attn\.sinks)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(
+ **build_ffn_predicates(),
+ **build_attention_predicates(),
+ )
+
+ return layer_name_patterns
+
+ @staticmethod
+ def pruning_mixins() -> Dict[str, PruningMixIn]:
+ """Return available pruning mixins for GPT-OSS.
+
+ Note: Expert removal works for unquantized models (test models).
+ Production models use MXFP4 quantization which is not yet supported.
+ """
+ return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())}
+
+
+@dataclass
+class GptOssExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
+ """
+ GPT-OSS MoE layer descriptor for expert removal.
+
+ Note: This only works for unquantized models (e.g., test models).
+ Production GPT-OSS models use MXFP4 quantization with fused experts
+ (_blocks, _scales, _bias), which requires a different approach.
+
+ Structure:
+ - Router: mlp.router with .weight and .bias
+ - Experts: mlp.experts.{idx}.{gate_up_proj,down_proj} with .weight and .bias
+ """
+
+ target_name: str = "mlp"
+ moe_prefix_name: str = "model.layers.{layer_idx}.mlp"
+ expert_prefix_name: str = "experts"
+
+ # Router has both weight and bias
+ router_weights: List[str] = field(default_factory=lambda: ["router.weight"])
+ router_biases: List[str] = field(default_factory=lambda: ["router.bias"])
+
+ # Fused format: experts stored as single tensors
+ is_fused_experts: bool = True
+
+ # Fused format: single tensors containing all experts (test models)
+ fused_expert_weights: List[str] = field(
+ default_factory=lambda: [
+ "experts.gate_up_proj",
+ "experts.gate_up_proj_bias",
+ "experts.down_proj",
+ "experts.down_proj_bias",
+ ]
+ )
+
+ # Not used for fused format, but kept for compatibility
+ expert_weights: List[str] = field(default_factory=lambda: ["gate_up_proj", "down_proj"])
+ expert_biases: List[str] = field(
+ default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"]
+ )
+
+ def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
+ target_class_name = "GptOssTopKRouter"
+
+ module_names_to_hook = []
+ for module_name, module in model.named_modules():
+ if (
+ module_name.endswith(self.target_name)
+ and module.__class__.__name__ == target_class_name
+ ):
+ module_names_to_hook.append(
+ (self.block_idx_from_module_name(module_name), module_name)
+ )
+ return module_names_to_hook
diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py
new file mode 100644
index 0000000000..64d18921fd
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py
@@ -0,0 +1,524 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model.
+
+This script:
+1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.)
+2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format
+3. Deduces expert mappings by comparing weights
+4. Outputs a new pruned (heterogeneous) checkpoint with PACKED MXFP4 expert weights
+"""
+
+import argparse
+import json
+import os
+import shutil
+from typing import Any, Dict, List, Optional, TextIO, Tuple
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+from tqdm import tqdm
+from transformers.integrations.mxfp4 import convert_moe_packed_tensors
+
+
+def deduce_experts_for_layer(
+ layer: int,
+ original_path: str,
+ original_index: Dict,
+ student_path: str,
+) -> Tuple[List[int], int, int]:
+ """
+ Deduce which original experts match the student experts by comparing weights.
+
+ Compares dequantized MXFP4 weights from the original model against the student
+ model's BF16 weights using L2 distance. Finds the best 1-to-1 matching.
+
+ Args:
+ layer: Layer index
+ original_path: Path to original model
+ original_index: Original model's safetensors index
+ student_path: Path to student model
+ num_student_experts: Number of experts in student model (if None, auto-detect)
+
+ Returns:
+ Tuple of (expert_indices, num_student_experts, num_original_experts)
+ """
+ # Load original tensors
+ orig_tensors = load_layer_tensors(original_path, layer, original_index)
+ mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"]
+ mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"]
+ mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"]
+ mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"]
+
+ num_original_experts = mlp1_blocks.shape[0]
+
+ # Load student tensors
+ student_subblocks = os.path.join(student_path, "subblocks_safetensors")
+ student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors")
+ if not os.path.exists(student_ffn):
+ print(f"FFN file not found at {student_ffn} - fallback to no_op")
+ return [], 0, num_original_experts
+
+ student_experts = {}
+ with safe_open(student_ffn, framework="pt") as f:
+ for key in f.keys():
+ if "experts" in key or "router" in key:
+ student_experts[key] = f.get_tensor(key)
+
+ # Auto-detect number of student experts
+ num_student_experts = student_experts[f"model.layers.{layer}.mlp.experts.gate_up_proj"].size(0)
+ print(
+ f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts"
+ )
+
+ # Pre-dequantize all original experts once (optimization)
+ print(f" Pre-dequantizing {num_original_experts} original experts...")
+ deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu()
+ deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu()
+ original_experts_dequant = []
+ for orig_idx in range(num_original_experts):
+ original_experts_dequant.append(
+ {"up": deqexpert_mlp1[orig_idx], "down": deqexpert_mlp2[orig_idx]}
+ )
+
+ # For each student expert, find best matching original expert
+ experts_to_keep = []
+ used_original_indices = set()
+
+ # Number of values to use for quick comparison (tune this)
+ quick_compare_size = 8
+ # Number of candidates to keep for full comparison
+ top_k_candidates = min(10, num_original_experts)
+
+ for student_idx in range(num_student_experts):
+ # Get student expert weights
+ prefix = f"model.layers.{layer}.mlp"
+ student_up = student_experts.get(f"{prefix}.experts.gate_up_proj")[student_idx] # type: ignore[index]
+ student_down = student_experts.get(f"{prefix}.experts.down_proj")[student_idx] # type: ignore[index]
+
+ # if student_gate is None or student_up is None or student_down is None:
+ if student_up is None or student_down is None:
+ raise ValueError(
+ f"Missing student expert weights for layer {layer} expert {student_idx}"
+ )
+
+ # Step 1: Quick filtering using first N values
+ candidate_scores = []
+ for orig_idx in range(num_original_experts):
+ if orig_idx in used_original_indices:
+ continue
+
+ orig_expert = original_experts_dequant[orig_idx]
+
+ up_quick = (
+ (
+ orig_expert["up"].flatten()[:quick_compare_size]
+ - student_up.float().flatten()[:quick_compare_size]
+ )
+ .pow(2)
+ .mean()
+ .sqrt()
+ )
+ down_quick = (
+ (
+ orig_expert["down"].flatten()[:quick_compare_size]
+ - student_down.float().flatten()[:quick_compare_size]
+ )
+ .pow(2)
+ .mean()
+ .sqrt()
+ )
+
+ quick_score = (up_quick + down_quick) / 2.0
+ candidate_scores.append((orig_idx, quick_score.item()))
+
+ # Step 2: Get top-k candidates based on quick comparison
+ candidate_scores.sort(key=lambda x: x[1])
+ top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]]
+
+ # Step 3: Full comparison only on top candidates
+ best_match_idx = None
+ best_match_score = float("inf")
+
+ for orig_idx in top_candidates:
+ orig_expert = original_experts_dequant[orig_idx]
+
+ # Full comparison across all values
+ up_diff = (orig_expert["up"] - student_up.float()).pow(2).mean().sqrt()
+ down_diff = (orig_expert["down"] - student_down.float()).pow(2).mean().sqrt()
+
+ score = (up_diff + down_diff) / 2.0
+
+ if score < best_match_score:
+ best_match_score = score
+ best_match_idx = orig_idx
+
+ if best_match_idx is None:
+ raise ValueError(
+ f"Could not find match for student expert {student_idx} in layer {layer}"
+ )
+
+ experts_to_keep.append(best_match_idx)
+ used_original_indices.add(best_match_idx)
+ print(
+ f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})"
+ )
+
+ return experts_to_keep, num_student_experts, num_original_experts
+
+
+def load_original_index(path: str) -> Dict[str, Any]:
+ """Load the original model's safetensors index."""
+ with open(path, "r") as f:
+ return json.load(f)
+
+
+def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]:
+ """Load all MoE-related tensors for a layer, potentially from multiple files."""
+ keys_to_load = [
+ f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks",
+ f"model.layers.{layer}.mlp.experts.gate_up_proj_scales",
+ f"model.layers.{layer}.mlp.experts.gate_up_proj_bias",
+ f"model.layers.{layer}.mlp.experts.down_proj_blocks",
+ f"model.layers.{layer}.mlp.experts.down_proj_scales",
+ f"model.layers.{layer}.mlp.experts.down_proj_bias",
+ f"model.layers.{layer}.mlp.router.weight", # Router weight
+ f"model.layers.{layer}.mlp.router.bias", # Router bias
+ ]
+
+ # Group by file
+ file_to_keys = {}
+ for key in keys_to_load:
+ if key in index["weight_map"]:
+ filename = index["weight_map"][key]
+ if filename not in file_to_keys:
+ file_to_keys[filename] = []
+ file_to_keys[filename].append(key)
+
+ # Load from each file
+ tensors = {}
+ for filename, keys in file_to_keys.items():
+ filepath = os.path.join(original_path, filename)
+ with safe_open(filepath, framework="pt") as f:
+ for key in keys:
+ tensors[key] = f.get_tensor(key)
+
+ return tensors
+
+
+def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]:
+ """
+ Copy non-MoE weights from student model.
+ Returns weight_map for the new index.
+ """
+ weight_map = {}
+ subblocks_dir = os.path.join(output_path, "subblocks_safetensors")
+ os.makedirs(subblocks_dir, exist_ok=True)
+
+ student_subblocks = os.path.join(student_path, "subblocks_safetensors")
+
+ # Copy embeddings
+ src_emb = os.path.join(student_subblocks, "embeddings.safetensors")
+ dst_emb = os.path.join(subblocks_dir, "embeddings.safetensors")
+ shutil.copy2(src_emb, dst_emb)
+ with safe_open(src_emb, framework="pt") as f:
+ for key in f.keys():
+ weight_map[key] = "subblocks_safetensors/embeddings.safetensors"
+
+ # Copy lm_head
+ src_head = os.path.join(student_subblocks, "lm_head.safetensors")
+ dst_head = os.path.join(subblocks_dir, "lm_head.safetensors")
+ shutil.copy2(src_head, dst_head)
+ with safe_open(src_head, framework="pt") as f:
+ for key in f.keys():
+ weight_map[key] = "subblocks_safetensors/lm_head.safetensors"
+
+ # Copy attention blocks
+ for layer in range(num_layers):
+ src_attn = os.path.join(student_subblocks, f"block_{layer}_attention.safetensors")
+ dst_attn = os.path.join(subblocks_dir, f"block_{layer}_attention.safetensors")
+ shutil.copy2(src_attn, dst_attn)
+ with safe_open(src_attn, framework="pt") as f:
+ for key in f.keys():
+ weight_map[key] = f"subblocks_safetensors/block_{layer}_attention.safetensors"
+
+ return weight_map
+
+
+def process_single_layer(
+ layer: int,
+ original_path: str,
+ original_index: Dict,
+ student_path: str,
+ output_path: str,
+ experts_to_keep: List[int],
+) -> Tuple[Dict[str, str], List[str]]:
+ """
+ Process a single layer - loads tensors from potentially multiple files.
+ Returns (weight_map, verification_errors).
+ """
+ weight_map = {}
+ verification_errors = []
+ subblocks_dir = os.path.join(output_path, "subblocks_safetensors")
+ student_subblocks = os.path.join(student_path, "subblocks_safetensors")
+
+ # Load all tensors for this layer (may come from multiple files)
+ orig_tensors = load_layer_tensors(original_path, layer, original_index)
+
+ # Load student FFN file
+ student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors")
+
+ tensors_to_save = {}
+ student_tensors = {}
+
+ with safe_open(student_ffn, framework="pt") as f:
+ for key in f.keys():
+ tensor = f.get_tensor(key)
+ if "experts" not in key and "router" not in key:
+ # Copy norm weights
+ tensors_to_save[key] = tensor
+
+ # Get router from original model, sliced to kept experts
+ orig_router_weight = orig_tensors[f"model.layers.{layer}.mlp.router.weight"]
+ orig_router_bias = orig_tensors[f"model.layers.{layer}.mlp.router.bias"]
+
+ kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long)
+ sliced_router_weight = orig_router_weight[kept_indices_tensor]
+ sliced_router_bias = orig_router_bias[kept_indices_tensor]
+
+ tensors_to_save[f"model.layers.{layer}.mlp.router.weight"] = sliced_router_weight
+ tensors_to_save[f"model.layers.{layer}.mlp.router.bias"] = sliced_router_bias
+
+ # Get MoE tensors
+ mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"]
+ mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"]
+ mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"]
+ mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"]
+ mlp1_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"]
+ mlp2_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_bias"]
+
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] = mlp1_blocks[
+ kept_indices_tensor
+ ]
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] = mlp1_scales[
+ kept_indices_tensor
+ ]
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] = mlp1_bias[
+ kept_indices_tensor
+ ]
+
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] = mlp2_blocks[
+ kept_indices_tensor
+ ]
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_scales"] = mlp2_scales[
+ kept_indices_tensor
+ ]
+ tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_bias"] = mlp2_bias[
+ kept_indices_tensor
+ ]
+
+ # Save the FFN file
+ output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors")
+ save_file(tensors_to_save, output_file)
+
+ # Build weight map
+ for key in tensors_to_save.keys():
+ weight_map[key] = f"subblocks_safetensors/block_{layer}_ffn.safetensors"
+
+ return weight_map, verification_errors
+
+
+def copy_config_files(student_path: str, output_path: str):
+ """Copy configuration files from student model and update config.json."""
+ files_to_copy = [
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "special_tokens_map.json",
+ "chat_template.jinja",
+ ]
+
+ # Also copy transformers compatibility files
+ if os.path.exists(student_path):
+ for f in os.listdir(student_path):
+ if f.startswith("transformers_"):
+ files_to_copy.append(f)
+
+ for filename in files_to_copy:
+ src = os.path.join(student_path, filename)
+ dst = os.path.join(output_path, filename)
+
+ # Try student path first
+ if os.path.exists(src):
+ try:
+ shutil.copy2(src, dst)
+ continue
+ except PermissionError:
+ pass
+
+ # If we get here, file doesn't exist or permission denied
+ if not os.path.exists(dst):
+ print(f" Warning: Could not copy {filename}")
+
+ # Update config.json for DeciGptOssForCausalLM with MXFP4
+ src_config = os.path.join(student_path, "config.json")
+ if not os.path.exists(src_config):
+ raise FileNotFoundError(f"config.json not found at {src_config}")
+
+ with open(src_config, "r") as f:
+ config = json.load(f) # type: ignore[arg-type]
+
+ # Set architecture to DeciGptOssForCausalLM for MXFP4 support
+ config["architectures"] = ["DeciGptOssForCausalLM"]
+
+ # Add quantization_config so vllm calls _load_weights_mxfp4
+ config["quantization_config"] = {
+ "quant_method": "mxfp4",
+ "modules_to_not_convert": [
+ "model.layers.*.self_attn",
+ "model.layers.*.mlp.router",
+ "model.embed_tokens",
+ "lm_head",
+ ],
+ }
+
+ dst_config = os.path.join(output_path, "config.json")
+ with open(dst_config, "w") as f:
+ json.dump(config, f, indent=2) # type: ignore[arg-type]
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model")
+ parser.add_argument(
+ "--student-path", type=str, required=True, help="Path to student model checkpoint"
+ )
+ parser.add_argument(
+ "--original-path",
+ type=str,
+ required=True,
+ help="Path to original gpt-oss-120b model with MXFP4 weights",
+ )
+ parser.add_argument(
+ "--output-path", type=str, required=True, help="Output path for the new checkpoint"
+ )
+ parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers")
+ args = parser.parse_args()
+
+ print(f"Creating MXFP4 checkpoint...")
+ print(f" Student model: {args.student_path}")
+ print(f" Original model: {args.original_path}")
+ print(f" Output: {args.output_path}")
+
+ # Load original model index
+ original_index = load_original_index(
+ os.path.join(args.original_path, "model.safetensors.index.json")
+ )
+
+ print("\nDeducing expert mappings by comparing weights...")
+ experts_to_keep = []
+ layer_statistics = [] # Store (num_student, num_original) for each layer
+
+ for layer in range(args.num_layers):
+ layer_experts, num_student, num_original = deduce_experts_for_layer(
+ layer,
+ args.original_path,
+ original_index,
+ args.student_path,
+ )
+ experts_to_keep.append(layer_experts)
+ layer_statistics.append((num_student, num_original))
+
+ # Print statistics
+ print(f"\n{'=' * 70}")
+ print("EXPERT DEDUCTION STATISTICS")
+ print(f"{'=' * 70}")
+ print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}")
+ print(f"{'-' * 70}")
+
+ total_student = 0
+ total_original = 0
+ for layer, (num_student, num_original) in enumerate(layer_statistics):
+ percentage = (num_student / num_original * 100) if num_original > 0 else 0
+ print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}")
+ total_student += num_student
+ total_original += num_original
+
+ print(f"{'-' * 70}")
+ avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0
+ print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}")
+ print(f"{'=' * 70}")
+ print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers")
+
+ # Create output directory
+ os.makedirs(args.output_path, exist_ok=True)
+ os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True)
+
+ # Copy config files
+ print("Copying configuration files...")
+ copy_config_files(args.student_path, args.output_path)
+
+ # Save experts_to_keep.json
+ experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json")
+ with open(experts_to_keep_output, "w") as f:
+ json.dump(experts_to_keep, f, indent=2)
+ print(f" Saved experts_to_keep mapping to {experts_to_keep_output}")
+
+ # Copy non-MoE weights (embeddings, attention, lm_head)
+ print("Copying non-MoE weights...")
+ weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers)
+
+ # Load weights per layer (handles multi-file loading)
+ print(f"Processing {args.num_layers} layers...")
+
+ all_verification_errors = []
+
+ # Process each layer
+ for layer in tqdm(range(args.num_layers), desc="Processing layers"):
+ if len(experts_to_keep[layer]) == 0:
+ print(f"Layer {layer} has no experts to keep - ffn->no_op")
+ continue
+ layer_weight_map, layer_errors = process_single_layer(
+ layer,
+ args.original_path,
+ original_index,
+ args.student_path,
+ args.output_path,
+ experts_to_keep[layer],
+ )
+ weight_map.update(layer_weight_map)
+ all_verification_errors.extend(layer_errors)
+
+ # Calculate total size
+ total_size = 0
+ subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors")
+ for filename in os.listdir(subblocks_dir):
+ filepath = os.path.join(subblocks_dir, filename)
+ total_size += os.path.getsize(filepath)
+
+ # Create model.safetensors.index.json
+ index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
+
+ index_path = os.path.join(args.output_path, "model.safetensors.index.json")
+ with open(index_path, "w") as f:
+ json.dump(index, f, indent=2)
+
+ print(f"\nCheckpoint created successfully at: {args.output_path}")
+ print(f"Total size: {total_size / 1e9:.2f} GB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py
new file mode 100644
index 0000000000..a0be9f919e
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/llama/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.llama.llama_converter import LlamaConverter
+from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
+ LlamaModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py
new file mode 100644
index 0000000000..5a0686ecc8
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_converter.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Llama converter for AnyModel compression."""
+
+from typing import List
+
+from transformers import LlamaConfig
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+
+
+@ConverterFactory.register_decorator("llama")
+class LlamaConverter(Converter):
+ """Converter for Llama models to AnyModel format."""
+
+ @staticmethod
+ def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConfig]:
+ """Create uniform block configs for all Llama layers.
+
+ Llama models have uniform architecture across all layers, so we create
+ the same BlockConfig for each layer.
+ """
+ num_hidden_layers = config.num_hidden_layers
+
+ block_configs = [
+ BlockConfig(
+ attention=AttentionConfig(
+ no_op=False, num_key_value_heads=config.num_key_value_heads
+ ),
+ ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
+ ).to_dict()
+ for _ in range(num_hidden_layers)
+ ]
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py
new file mode 100644
index 0000000000..082e5da599
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/llama/llama_model_descriptor.py
@@ -0,0 +1,141 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Llama model descriptor for AnyModel compression."""
+
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from transformers.models.llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaRotaryEmbedding,
+)
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediateLayerDescriptor,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
+
+
+@ModelDescriptorFactory.register_decorator("llama")
+class LlamaModelDescriptor(ModelDescriptor):
+ """Model descriptor for Llama models (Llama 2, Llama 3, Llama 3.1, Llama 3.2)."""
+
+ @staticmethod
+ def decoder_layer_cls():
+ return LlamaDecoderLayer
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ return {
+ "intermediate_size": block_config.ffn.intermediate_size,
+ "num_key_value_heads": block_config.attention.num_key_value_heads,
+ }
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: LlamaDecoderLayer):
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: LlamaDecoderLayer):
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = MatchingZeros()
+
+ @staticmethod
+ def init_rotary_embedding(model: LlamaForCausalLM, runtime):
+ model.model.rotary_emb = LlamaRotaryEmbedding(model.config, runtime.device)
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
+ r"|mlp\.up_proj\.weight"
+ r"|mlp\.gate_proj\.weight"
+ r"|mlp\.down_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.o_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates())
+ return layer_name_patterns
+
+
+@dataclass
+class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
+ """Layer descriptor for Llama FFN intermediate pruning."""
+
+ down_proj_name: str = "mlp.down_proj"
+ ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
+ linear_weight_names: List[str] = field(
+ default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
+ )
+
+
+@dataclass
+class LlamaKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
+ o_proj_name: str = "self_attn.o_proj"
+ attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
+ qkvo_weight_names: List[str] = field(
+ default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ )
diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py
new file mode 100644
index 0000000000..821be47e9d
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_converter import (
+ MistralSmallConverter,
+)
+from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor import (
+ MistralSmallModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py
new file mode 100644
index 0000000000..ddc8151dc9
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+from typing import List
+
+from transformers import MistralConfig
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+
+
+@ConverterFactory.register_decorator("mistral_small")
+class MistralSmallConverter(Converter):
+ @staticmethod
+ def create_block_configs_from_main_config(config: MistralConfig) -> List[BlockConfig]:
+ num_hidden_layers = config.num_hidden_layers
+
+ block_config = BlockConfig(
+ attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads),
+ ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
+ ).to_dict()
+
+ block_configs = [block_config] * num_hidden_layers
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py
new file mode 100644
index 0000000000..1ac2bd7072
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from transformers.models.mistral.modeling_mistral import (
+ MistralDecoderLayer,
+ MistralForCausalLM,
+ MistralRotaryEmbedding,
+)
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediateLayerDescriptor,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
+
+
+@ModelDescriptorFactory.register_decorator("mistral_small")
+class MistralSmallModelDescriptor(ModelDescriptor):
+ @staticmethod
+ def decoder_layer_cls():
+ return MistralDecoderLayer
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ return {
+ "intermediate_size": block_config.ffn.intermediate_size,
+ "num_key_value_heads": block_config.attention.num_key_value_heads,
+ }
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: MistralDecoderLayer):
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: MistralDecoderLayer):
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = MatchingZeros()
+
+ @staticmethod
+ def init_rotary_embedding(model: MistralForCausalLM, runtime):
+ model.model.rotary_emb = MistralRotaryEmbedding(model.config, runtime.device)
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
+ r"|mlp\.up_proj\.weight"
+ r"|mlp\.gate_proj\.weight"
+ r"|mlp\.down_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.o_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates())
+ return layer_name_patterns
+
+
+@dataclass
+class MistralFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
+ down_proj_name: str = "mlp.down_proj"
+ ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
+ linear_weight_names: List[str] = field(
+ default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
+ )
+
+
+@dataclass
+class MistralKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
+ o_proj_name: str = "self_attn.o_proj"
+ attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
+ qkvo_weight_names: List[str] = field(
+ default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ )
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py
new file mode 100644
index 0000000000..a2140f118e
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_converter import (
+ NemotronHConverter,
+)
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor import (
+ NemotronHModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py
new file mode 100644
index 0000000000..16d9e3c73d
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ MambaConfig,
+ MoEConfig,
+)
+
+
+@ConverterFactory.register_decorator("nemotron_h")
+class NemotronHConverter(Converter):
+ @staticmethod
+ def create_block_configs_from_main_config(config) -> List[BlockConfig]:
+ # Create block configs for each layer based on the hybrid_override_pattern
+ block_configs = []
+
+ # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
+ pattern = config.hybrid_override_pattern
+ print(f"Parsing hybrid pattern: {pattern}")
+
+ for i, char in enumerate(pattern):
+ if char == "M":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(
+ mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats.
+ state_dim=config.ssm_state_size,
+ num_heads=config.mamba_num_heads,
+ head_dim=config.mamba_head_dim,
+ num_groups=config.n_groups,
+ )
+ ),
+ ffn=FFNConfig(no_op=True),
+ )
+
+ elif char == "-":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(no_op=True),
+ ffn=FFNConfig(intermediate_size=config.intermediate_size),
+ )
+
+ elif char == "*":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads),
+ ffn=FFNConfig(no_op=True),
+ )
+
+ elif char == "E":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(no_op=True),
+ ffn=FFNConfig(
+ moe=MoEConfig(
+ num_local_experts=config.n_routed_experts,
+ expert_intermediate_dim=config.moe_intermediate_size,
+ num_experts_per_tok=config.num_experts_per_tok,
+ )
+ ),
+ )
+ else:
+ raise ValueError(
+ f"Unknown character '{char}' in hybrid_override_pattern at position {i}"
+ )
+
+ block_configs.append(_block_config)
+
+ print(f"Created {len(block_configs)} block configs from pattern")
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
new file mode 100644
index 0000000000..7687d57c83
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
@@ -0,0 +1,255 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import importlib
+import inspect
+import pkgutil
+import re
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, List, Tuple, Type
+
+import torch.nn as nn
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import (
+ ExpertRemovalLayerDescriptor,
+ ExpertRemovalPruningMixIn,
+)
+from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn
+
+
+def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
+ import transformers_modules
+
+ matches = []
+ for finder, modname, ispkg in pkgutil.walk_packages(
+ transformers_modules.__path__, transformers_modules.__name__ + "."
+ ):
+ module = importlib.import_module(modname)
+ for _, obj in inspect.getmembers(module, inspect.isclass):
+ if obj.__name__ == module_cls_str:
+ matches.append(obj)
+
+ return matches
+
+
+@dataclass
+class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
+ target_name: str = "mixer.gate"
+ moe_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
+ expert_prefix_name: str = "experts.{expert_idx}"
+ router_weights: List[str] = field(default_factory=lambda: ["gate.weight"])
+ router_biases: List[str] = field(default_factory=lambda: ["gate.e_score_correction_bias"])
+ expert_weights: List[str] = field(
+ default_factory=lambda: ["up_proj.weight", "down_proj.weight"]
+ )
+
+ def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
+ if self.target_name != "mixer":
+ return super().get_modules_names_to_hook(model)
+
+ # when target is `mixer` we'll target moe layers of class type: `NemotronHMOE`, as NemotronH models use auto-map we'll check for class name instead of class type.
+ target_class_name = "NemotronHMOE"
+
+ module_names_to_hook = []
+ for module_name, module in model.named_modules():
+ # restrict to attributes called "mixer" and with the desired class name
+ if (
+ module_name.endswith(self.target_name)
+ and module.__class__.__name__ == target_class_name
+ ):
+ module_names_to_hook.append(
+ (self.block_idx_from_module_name(module_name), module_name)
+ )
+ return module_names_to_hook
+
+
+@ModelDescriptorFactory.register_decorator("nemotron_h")
+class NemotronHModelDescriptor(ModelDescriptor):
+ _DECODER_LAYER_CLS: Type[nn.Module] = None
+
+ @staticmethod
+ def decoder_layer_cls():
+ decoder_cls_list = get_dynamic_modules("NemotronHBlock")
+ if not decoder_cls_list:
+ raise AssertionError(
+ "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`"
+ )
+ return decoder_cls_list
+
+ @staticmethod
+ def requires_trust_remote_code() -> bool:
+ return True
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ override_kwargs = {}
+ if block_config.ffn.intermediate_size is not None:
+ override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size
+
+ if block_config.attention.num_key_value_heads is not None:
+ override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads
+
+ if block_config.ffn.moe is not None:
+ override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
+ override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts
+
+ return override_kwargs
+
+ @staticmethod
+ def _block_no_op_post_init(decoder_layer):
+ """
+ Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True.
+ """
+ block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx]
+ if block_config.ffn.no_op and block_config.attention.no_op:
+ decoder_layer.norm = Same()
+ decoder_layer.mixer = MatchingZeros()
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer):
+ NemotronHModelDescriptor._block_no_op_post_init(decoder_layer)
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer):
+ NemotronHModelDescriptor._block_no_op_post_init(decoder_layer)
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
+ dummy_block = super().create_dummy_block(original_layer, block_index)
+ # Required by `NemotronHModel.forward`.
+ dummy_block.block_type = original_layer.block_type
+ # Preserve layer_idx if it exists (used by _block_no_op_post_init)
+ if hasattr(original_layer, "layer_idx"):
+ dummy_block.layer_idx = original_layer.layer_idx
+ # Preserve config if it exists (used by _block_no_op_post_init to access block_configs)
+ if hasattr(original_layer, "config"):
+ dummy_block.config = original_layer.config
+ return dummy_block
+
+ @staticmethod
+ def init_rotary_embedding(model, runtime):
+ """
+ NemotronH has no positional embeddings
+ """
+
+ @staticmethod
+ def input_embedding_name():
+ return "backbone.embeddings"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "backbone.norm_f"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"backbone.layers.{index}"
+
+ @classmethod
+ def get_weight_groups(
+ cls, layer_names: Iterable[str], num_hidden_layers: int
+ ) -> Dict[str, List[str]]:
+ """
+ Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed.
+ """
+ weight_groups = defaultdict(list)
+ for name in layer_names:
+ is_matched = False
+ for group, pattern in cls.layer_name_predicates(num_hidden_layers).items():
+ if pattern.match(name):
+ weight_groups[group].append(name)
+ is_matched = True
+ if not is_matched:
+ raise ValueError(f"Couldn't find a match for {name}")
+
+ valid_weight_groups = {}
+ for group, names in weight_groups.items():
+ if len(names) == 1:
+ only_name = names[0]
+ if only_name.endswith("norm.weight") and "layers" in only_name:
+ # Skip and don't append this group to valid_weight_groups
+ continue
+ valid_weight_groups[group] = names
+
+ return valid_weight_groups
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(
+ r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$"
+ ),
+ "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^backbone\.layers\.{layer_idx}\."
+ r"(norm\.weight|" # ← INCLUDED IN FFN
+ r"mixer\.(gate\.e_score_correction_bias"
+ r"|gate\.weight"
+ r"|experts\.\d+\.up_proj\.weight"
+ r"|experts\.\d+\.down_proj\.weight"
+ r"|shared_experts\.up_proj\.weight"
+ r"|shared_experts\.down_proj\.weight))$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^backbone\.layers\.{layer_idx}\."
+ r"(norm\.weight|" # ← INCLUDED IN ATTENTION
+ r"mixer\.(norm\.weight"
+ r"|A_log"
+ r"|D"
+ r"|conv1d\.weight"
+ r"|conv1d\.bias"
+ r"|dt_bias"
+ r"|in_proj\.weight"
+ r"|out_proj\.weight"
+ r"|q_proj\.weight"
+ r"|k_proj\.weight"
+ r"|v_proj\.weight"
+ r"|o_proj\.weight))$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(
+ **build_ffn_predicates(),
+ **build_attention_predicates(),
+ )
+
+ return layer_name_patterns
+
+ @staticmethod
+ def pruning_mixins() -> Dict[str, PruningMixIn]:
+ return {
+ "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()),
+ }
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py
new file mode 100644
index 0000000000..4b17785ace
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_converter import (
+ NemotronHV2Converter,
+)
+from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor import (
+ NemotronHV2ModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py
new file mode 100644
index 0000000000..2c54388325
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ MambaConfig,
+ MoEConfig,
+)
+
+
+@ConverterFactory.register_decorator("nemotron_h_v2")
+class NemotronHV2Converter(Converter):
+ @staticmethod
+ def create_block_configs_from_main_config(config) -> List[BlockConfig]:
+ # Create block configs for each layer based on the hybrid_override_pattern
+ block_configs = []
+
+ # Parse the hybrid_override_pattern: "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
+ pattern = config.hybrid_override_pattern
+ print(f"Parsing hybrid pattern: {pattern}")
+
+ for i, char in enumerate(pattern):
+ if char == "M":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(
+ mamba=MambaConfig( # Those parameters are currently used only for calc_block_stats.
+ state_dim=config.ssm_state_size,
+ num_heads=config.mamba_num_heads,
+ head_dim=config.mamba_head_dim,
+ num_groups=config.n_groups,
+ )
+ ),
+ ffn=FFNConfig(no_op=True),
+ )
+
+ elif char == "-":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(no_op=True),
+ ffn=FFNConfig(intermediate_size=config.intermediate_size),
+ )
+
+ elif char == "*":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads),
+ ffn=FFNConfig(no_op=True),
+ )
+
+ elif char == "E":
+ _block_config = BlockConfig(
+ attention=AttentionConfig(no_op=True),
+ ffn=FFNConfig(
+ moe=MoEConfig(
+ num_local_experts=config.n_routed_experts,
+ expert_intermediate_dim=config.moe_intermediate_size,
+ num_experts_per_tok=config.num_experts_per_tok,
+ )
+ ),
+ )
+ else:
+ raise ValueError(
+ f"Unknown character '{char}' in hybrid_override_pattern at position {i}"
+ )
+
+ block_configs.append(_block_config)
+
+ print(f"Created {len(block_configs)} block configs from pattern")
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py
new file mode 100644
index 0000000000..c8c89658bf
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import pkgutil
+import re
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, List, Type
+
+import torch.nn as nn
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediateLayerDescriptor,
+ FFNIntermediatePruningMixIn,
+)
+from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn
+
+
+def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
+ import transformers_modules
+
+ matches = []
+ for finder, modname, ispkg in pkgutil.walk_packages(
+ transformers_modules.__path__, transformers_modules.__name__ + "."
+ ):
+ module = importlib.import_module(modname)
+ for _, obj in inspect.getmembers(module, inspect.isclass):
+ if obj.__name__ == module_cls_str:
+ matches.append(obj)
+
+ return matches
+
+
+@dataclass
+class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
+ down_proj_name: str = "mixer.down_proj"
+ ffn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
+ linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"])
+
+
+@ModelDescriptorFactory.register_decorator("nemotron_h_v2")
+class NemotronHV2ModelDescriptor(ModelDescriptor):
+ _DECODER_LAYER_CLS: Type[nn.Module] = None
+
+ @staticmethod
+ def decoder_layer_cls():
+ decoder_cls_list = get_dynamic_modules("NemotronHBlock")
+ if not decoder_cls_list:
+ raise AssertionError(
+ "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`"
+ )
+ return decoder_cls_list
+
+ @staticmethod
+ def requires_trust_remote_code() -> bool:
+ return True
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ override_kwargs = {}
+ if block_config.ffn is not None and block_config.ffn.intermediate_size is not None:
+ override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size
+
+ if (
+ block_config.attention is not None
+ and block_config.attention.num_key_value_heads is not None
+ ):
+ override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads
+
+ if block_config.ffn is not None and block_config.ffn.moe is not None:
+ override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
+ override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts
+
+ return override_kwargs
+
+ @staticmethod
+ def _block_no_op_post_init(decoder_layer):
+ """
+ Due to the subblock structure of NemotronH always one of the subblock is set to no-op, for a real no-op both attention & ffn no-op should be set to True.
+ """
+ block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx]
+ ffn_no_op = block_config.ffn is not None and block_config.ffn.no_op
+ attn_no_op = block_config.attention is not None and block_config.attention.no_op
+ if ffn_no_op and attn_no_op:
+ decoder_layer.norm = Same()
+ decoder_layer.mixer = MatchingZeros()
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer):
+ NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer)
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer):
+ NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer)
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
+ dummy_block = super().create_dummy_block(original_layer, block_index)
+ # Required by `NemotronHModel.forward`.
+ dummy_block.block_type = original_layer.block_type
+ # Preserve layer_idx if it exists (used by _block_no_op_post_init)
+ if hasattr(original_layer, "layer_idx"):
+ dummy_block.layer_idx = original_layer.layer_idx
+ # Preserve config if it exists (used by _block_no_op_post_init to access block_configs)
+ if hasattr(original_layer, "config"):
+ dummy_block.config = original_layer.config
+ return dummy_block
+
+ @staticmethod
+ def init_rotary_embedding(model, runtime):
+ """
+ NemotronH has no positional embeddings
+ """
+
+ @staticmethod
+ def input_embedding_name():
+ return "backbone.embeddings"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "backbone.norm_f"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"backbone.layers.{index}"
+
+ @classmethod
+ def get_weight_groups(
+ cls, layer_names: Iterable[str], num_hidden_layers: int
+ ) -> Dict[str, List[str]]:
+ """
+ Problem with NemotronH is that `norm.weight` can be in both block_{i}_ffn and block_{i}_attention. duplicate groups with `norm.weight` should be removed.
+ """
+ weight_groups = defaultdict(list)
+ for name in layer_names:
+ is_matched = False
+ for group, pattern in cls.layer_name_predicates(num_hidden_layers).items():
+ if pattern.match(name):
+ weight_groups[group].append(name)
+ is_matched = True
+ if not is_matched:
+ raise ValueError(f"Couldn't find a match for {name}")
+
+ valid_weight_groups = {}
+ for group, names in weight_groups.items():
+ if len(names) == 1:
+ only_name = names[0]
+ if only_name.endswith("norm.weight") and "layers" in only_name:
+ # Skip and don't append this group to valid_weight_groups
+ continue
+ valid_weight_groups[group] = names
+
+ return valid_weight_groups
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(
+ r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$"
+ ),
+ "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^backbone\.layers\.{layer_idx}\."
+ r"(norm\.weight|" # ← INCLUDED IN FFN
+ r"mixer\.(gate\.e_score_correction_bias"
+ r"|gate\.weight"
+ r"|experts\.\d+\.up_proj\.weight"
+ r"|experts\.\d+\.down_proj\.weight"
+ r"|shared_experts\.up_proj\.weight"
+ r"|shared_experts\.down_proj\.weight"
+ r"|up_proj\.weight" # Simple MLP (non-MoE)
+ r"|down_proj\.weight))$" # Simple MLP (non-MoE)
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^backbone\.layers\.{layer_idx}\."
+ r"(norm\.weight|" # ← INCLUDED IN ATTENTION
+ r"mixer\.(norm\.weight"
+ r"|A_log"
+ r"|D"
+ r"|conv1d\.weight"
+ r"|conv1d\.bias"
+ r"|dt_bias"
+ r"|in_proj\.weight"
+ r"|out_proj\.weight"
+ r"|q_proj\.weight"
+ r"|k_proj\.weight"
+ r"|v_proj\.weight"
+ r"|o_proj\.weight))$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(
+ **build_ffn_predicates(),
+ **build_attention_predicates(),
+ )
+
+ return layer_name_patterns
+
+ @staticmethod
+ def pruning_mixins() -> Dict[str, PruningMixIn]:
+ return {
+ "ffn_intermediate": FFNIntermediatePruningMixIn(
+ NemotronHV2FFNIntermediateLayerDescriptor()
+ ),
+ # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated
+ }
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py
new file mode 100644
index 0000000000..c193fc0d6d
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_converter import Qwen2Converter
+from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor import (
+ Qwen2ModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py
new file mode 100644
index 0000000000..878cfd64dc
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Qwen2 converter for AnyModel compression."""
+
+from typing import List
+
+from transformers import Qwen2Config
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+
+
+@ConverterFactory.register_decorator("qwen2")
+class Qwen2Converter(Converter):
+ """Converter for Qwen2 models to AnyModel format."""
+
+ @staticmethod
+ def create_block_configs_from_main_config(config: Qwen2Config) -> List[BlockConfig]:
+ """Create uniform block configs for all Qwen2 layers.
+
+ Qwen2 models have uniform architecture across all layers, so we create
+ the same BlockConfig for each layer.
+ """
+ num_hidden_layers = config.num_hidden_layers
+
+ block_config = BlockConfig(
+ attention=AttentionConfig(no_op=False, num_key_value_heads=config.num_key_value_heads),
+ ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
+ ).to_dict()
+
+ block_configs = [block_config] * num_hidden_layers
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py
new file mode 100644
index 0000000000..c2bbeed7a9
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py
@@ -0,0 +1,146 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Qwen2 model descriptor for AnyModel compression."""
+
+import re
+from dataclasses import dataclass
+from typing import Dict
+
+from torch import nn
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2DecoderLayer,
+ Qwen2ForCausalLM,
+ Qwen2RotaryEmbedding,
+)
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
+ LlamaFFNIntermediateLayerDescriptor,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock
+
+
+@ModelDescriptorFactory.register_decorator("qwen2")
+class Qwen2ModelDescriptor(ModelDescriptor):
+ """Model descriptor for Qwen2 models."""
+
+ @staticmethod
+ def decoder_layer_cls():
+ return Qwen2DecoderLayer
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
+ """Create a dummy block that preserves Qwen2-specific attributes like attention_type.
+
+ Qwen2's forward pass accesses decoder_layer.attention_type for attention mask selection.
+ """
+ dummy = DummyBlock(block_index=block_index)
+ # Copy attention_type from original layer (required by Qwen2's forward pass)
+ if hasattr(original_layer, "attention_type"):
+ dummy.attention_type = original_layer.attention_type
+ return dummy
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ return {
+ "intermediate_size": block_config.ffn.intermediate_size,
+ "num_key_value_heads": block_config.attention.num_key_value_heads,
+ }
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: Qwen2DecoderLayer):
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: Qwen2DecoderLayer):
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = MatchingZeros()
+
+ @staticmethod
+ def init_rotary_embedding(model: Qwen2ForCausalLM, runtime):
+ model.model.rotary_emb = Qwen2RotaryEmbedding(config=model.config, device=runtime.device)
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
+ r"|mlp\.up_proj\.weight"
+ r"|mlp\.gate_proj\.weight"
+ r"|mlp\.down_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ # Qwen2 has biases on attention projections
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.q_proj\.bias"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.k_proj\.bias"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.v_proj\.bias"
+ r"|self_attn\.o_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates())
+ return layer_name_patterns
+
+
+@dataclass
+class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor):
+ """Layer descriptor for Qwen2 FFN intermediate pruning.
+
+ Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj).
+ """
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py
new file mode 100644
index 0000000000..cf28475718
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_converter import Qwen3Converter
+from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor import (
+ Qwen3ModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py
new file mode 100644
index 0000000000..bad9bb47d6
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+from typing import List
+
+from transformers import Qwen3Config
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+
+
+@ConverterFactory.register_decorator("qwen3")
+class Qwen3Converter(Converter):
+ @staticmethod
+ def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]:
+ num_hidden_layers = config.num_hidden_layers
+
+ block_configs = [
+ BlockConfig(
+ attention=AttentionConfig(
+ no_op=False, num_key_value_heads=config.num_key_value_heads
+ ),
+ ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
+ ).to_dict()
+ for _ in range(num_hidden_layers)
+ ]
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py
new file mode 100644
index 0000000000..ae70d96617
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from torch import nn
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3DecoderLayer,
+ Qwen3ForCausalLM,
+ Qwen3RotaryEmbedding,
+)
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediateLayerDescriptor,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
+from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock
+
+
+@ModelDescriptorFactory.register_decorator("qwen3")
+class Qwen3ModelDescriptor(ModelDescriptor):
+ @staticmethod
+ def decoder_layer_cls():
+ return Qwen3DecoderLayer
+
+ @classmethod
+ def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
+ """Create a dummy block that preserves Qwen3-specific attributes like attention_type.
+
+ Qwen3's forward pass accesses decoder_layer.attention_type for attention mask selection.
+ """
+ dummy = DummyBlock(block_index=block_index)
+ # Copy attention_type from original layer (required by Qwen3's forward pass)
+ if hasattr(original_layer, "attention_type"):
+ dummy.attention_type = original_layer.attention_type
+ return dummy
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ return {
+ "intermediate_size": block_config.ffn.intermediate_size,
+ "num_key_value_heads": block_config.attention.num_key_value_heads,
+ }
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: Qwen3DecoderLayer):
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: Qwen3DecoderLayer):
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = MatchingZeros()
+
+ @staticmethod
+ def init_rotary_embedding(model: Qwen3ForCausalLM, runtime):
+ model.model.rotary_emb = Qwen3RotaryEmbedding(model.config, runtime.device)
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
+ r"|mlp\.up_proj\.weight"
+ r"|mlp\.gate_proj\.weight"
+ r"|mlp\.down_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.o_proj\.weight"
+ r"|self_attn\.q_norm\.weight"
+ r"|self_attn\.k_norm\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates())
+ return layer_name_patterns
+
+
+@dataclass
+class Qwen3FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
+ down_proj_name: str = "mlp.down_proj"
+ ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
+ linear_weight_names: List[str] = field(
+ default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
+ )
+
+
+@dataclass
+class Qwen3KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
+ o_proj_name: str = "self_attn.o_proj"
+ attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
+ qkvo_weight_names: List[str] = field(
+ default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ )
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py
new file mode 100644
index 0000000000..48dbd2de8a
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_converter import Qwen3VLConverter
+from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor import (
+ Qwen3VLModelDescriptor,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py
new file mode 100644
index 0000000000..82e51b7b80
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+from typing import List
+
+from transformers import Qwen3VLMoeConfig
+
+from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ MoEConfig,
+)
+
+
+@ConverterFactory.register_decorator("qwen3_vl")
+class Qwen3VLConverter(Converter):
+ @staticmethod
+ def create_block_configs_from_main_config(config: Qwen3VLMoeConfig) -> List[BlockConfig]:
+ # Qwen3-VL MoE has nested text_config
+ text_config = config.text_config if hasattr(config, "text_config") else config
+
+ num_hidden_layers = text_config.num_hidden_layers
+ decoder_sparse_step = getattr(text_config, "decoder_sparse_step", 1)
+ mlp_only_layers = getattr(text_config, "mlp_only_layers", [])
+
+ block_configs = []
+ for layer_idx in range(num_hidden_layers):
+ # Check if this layer is MoE or dense
+ is_moe_layer = (layer_idx % decoder_sparse_step == 0) and (
+ layer_idx not in mlp_only_layers
+ )
+
+ if is_moe_layer:
+ # MoE layer
+ block_config = BlockConfig(
+ attention=AttentionConfig(
+ no_op=False, num_key_value_heads=text_config.num_key_value_heads
+ ),
+ ffn=FFNConfig(
+ moe=MoEConfig(
+ num_local_experts=text_config.num_experts,
+ expert_intermediate_dim=text_config.moe_intermediate_size,
+ num_experts_per_tok=text_config.num_experts_per_tok,
+ )
+ ),
+ )
+ else:
+ # Dense layer
+ block_config = BlockConfig(
+ attention=AttentionConfig(
+ no_op=False, num_key_value_heads=text_config.num_key_value_heads
+ ),
+ ffn=FFNConfig(no_op=False, intermediate_size=text_config.intermediate_size),
+ )
+
+ block_configs.append(block_config)
+
+ print(
+ f"Created {len(block_configs)} block configs for Qwen3-VL MoE (decoder_sparse_step={decoder_sparse_step})"
+ )
+ return block_configs
diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
new file mode 100644
index 0000000000..8c182c8968
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
@@ -0,0 +1,211 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+ Qwen3VLMoeTextDecoderLayer,
+ Qwen3VLMoeTextRotaryEmbedding,
+ Qwen3VLMoeVisionRotaryEmbedding,
+)
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import (
+ ExpertRemovalLayerDescriptor,
+)
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediateLayerDescriptor,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
+
+
+@ModelDescriptorFactory.register_decorator("qwen3_vl")
+class Qwen3VLModelDescriptor(ModelDescriptor):
+ @staticmethod
+ def uses_autocast() -> bool:
+ """
+ Qwen3-VL MoE has a dtype bug in HuggingFace transformers under torch.autocast:
+ scatter() in MoE routing fails with dtype mismatch. Use native bfloat16 instead.
+ See: https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct (recommended approach)
+ """
+ return False
+
+ @staticmethod
+ def get_language_model_config(config):
+ """Qwen3-VL has nested text_config for language model parameters."""
+ return config.text_config if hasattr(config, "text_config") else config
+
+ @staticmethod
+ def decoder_layer_cls():
+ return Qwen3VLMoeTextDecoderLayer
+
+ @staticmethod
+ def block_config_to_layer_overrides(block_config: BlockConfig):
+ override_kwargs = {"num_key_value_heads": block_config.attention.num_key_value_heads}
+
+ if block_config.ffn.moe:
+ override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
+ override_kwargs["num_experts"] = block_config.ffn.moe.num_local_experts
+ else:
+ override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size
+
+ return override_kwargs
+
+ @staticmethod
+ def attn_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer):
+ decoder_layer.input_layernorm = Same()
+ decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+
+ @staticmethod
+ def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer):
+ decoder_layer.post_attention_layernorm = Same()
+ decoder_layer.mlp = MatchingZeros()
+
+ @staticmethod
+ def init_rotary_embedding(model, runtime):
+ # Re-initialize text rotary embedding on correct device and dtype
+ text_config = Qwen3VLModelDescriptor.get_language_model_config(model.config)
+ model.model.language_model.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(
+ config=text_config
+ ).to(device=runtime.device, dtype=runtime.dtype)
+ # Re-initialize vision rotary embedding on correct device and dtype
+ vision_config = (
+ model.config.vision_config if hasattr(model.config, "vision_config") else None
+ )
+ if vision_config is not None:
+ head_dim = vision_config.hidden_size // vision_config.num_heads
+ model.model.visual.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2).to(
+ device=runtime.device, dtype=runtime.dtype
+ )
+
+ @staticmethod
+ def input_embedding_name():
+ return "model.language_model.embed_tokens"
+
+ @staticmethod
+ def output_embedding_name():
+ return "lm_head"
+
+ @staticmethod
+ def final_norm_name():
+ return "model.language_model.norm"
+
+ @staticmethod
+ def layer_block_name(index: int):
+ return f"model.language_model.layers.{index}"
+
+ @staticmethod
+ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
+ # Qwen3-VL has text model under model.language_model.* prefix
+ layer_name_patterns = {
+ "embeddings": re.compile(r"^model\.language_model\.embed_tokens\.weight$"),
+ "lm_head": re.compile(r"^(model\.language_model\.norm\.weight|lm_head\.weight)$"),
+ # Vision encoder (includes merger under model.visual.deepstack_merger_list.*)
+ "vision_encoding": re.compile(r"^model\.visual\..*"),
+ }
+
+ def build_ffn_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_ffn": re.compile(
+ rf"^model\.language_model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
+ # MoE router
+ r"|mlp\.gate\.weight"
+ # MoE experts - fused format (gate_up_proj, down_proj without .weight suffix)
+ r"|mlp\.experts\.gate_up_proj"
+ r"|mlp\.experts\.down_proj"
+ # Shared expert (if present)
+ r"|mlp\.shared_expert\.up_proj\.weight"
+ r"|mlp\.shared_expert\.gate_proj\.weight"
+ r"|mlp\.shared_expert\.down_proj\.weight"
+ r"|mlp\.shared_expert_gate\.weight"
+ # Dense MLP fallback (for non-MoE layers)
+ r"|mlp\.up_proj\.weight"
+ r"|mlp\.gate_proj\.weight"
+ r"|mlp\.down_proj\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ def build_attention_predicates() -> Dict[str, re.Pattern]:
+ return {
+ f"block_{layer_idx}_attention": re.compile(
+ rf"^model\.language_model\.layers\.{layer_idx}\.(input_layernorm\.weight"
+ r"|self_attn\.q_proj\.weight"
+ r"|self_attn\.k_proj\.weight"
+ r"|self_attn\.v_proj\.weight"
+ r"|self_attn\.o_proj\.weight"
+ r"|self_attn\.q_norm\.weight"
+ r"|self_attn\.k_norm\.weight)$"
+ )
+ for layer_idx in range(num_layers)
+ }
+
+ layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates())
+ return layer_name_patterns
+
+
+@dataclass
+class Qwen3VLFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
+ down_proj_name: str = "mlp.down_proj"
+ ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
+ linear_weight_names: List[str] = field(
+ default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
+ )
+
+
+@dataclass
+class Qwen3VLKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
+ o_proj_name: str = "self_attn.o_proj"
+ attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn"
+ qkvo_weight_names: List[str] = field(
+ default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ )
+
+
+@dataclass
+class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
+ """
+ Qwen3-VL MoE layer descriptor.
+
+ Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+ - Qwen3VLMoeTextSparseMoeBlock: MoE block with .gate (router) and .experts
+ - Qwen3VLMoeTextTopKRouter: Router with .weight (no bias)
+ - Qwen3VLMoeTextExperts: Fused experts with .gate_up_proj and .down_proj tensors
+ """
+
+ target_name: str = "mlp"
+ moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
+ # Router: Qwen3VLMoeTextTopKRouter has self.weight, no bias
+ router_weights: List[str] = field(default_factory=lambda: ["gate.weight"])
+ router_biases: List[str] = field(default_factory=list)
+ # Fused expert format: Qwen3VLMoeTextExperts stores all experts in single tensors
+ # with shape [num_experts, ...] instead of separate tensors per expert.
+ is_fused_experts: bool = True
+ fused_expert_weights: List[str] = field(
+ default_factory=lambda: ["experts.gate_up_proj", "experts.down_proj"]
+ )
diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py
new file mode 100644
index 0000000000..3af98d57fe
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/puzzformer/__init__.py
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for patching and transforming HuggingFace models to work with AnyModel.
+
+Provides no-op modules for layer replacement and patching utilities for heterogeneous
+per-layer configurations.
+"""
+
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
+ MatchingZeros,
+ Same,
+ return_tuple_of_size,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.utils import (
+ deci_x_patcher,
+ override_config_with_block_configs,
+)
diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py
new file mode 100644
index 0000000000..9b3a9a2190
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/puzzformer/no_op.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""No-op modules for replacing layers during pruning."""
+
+from functools import cache
+
+import torch
+import torch.nn as nn
+
+
+@cache
+def return_tuple_of_size(cls: type[nn.Module], size: int) -> type[nn.Module]:
+ """Create a wrapper class that returns a tuple of the given size.
+
+ Useful for replacing modules that return multiple outputs (e.g., attention layers
+ that return (hidden_states, attn_weights)).
+
+ Args:
+ cls: The base module class to wrap.
+ size: The size of the tuple to return.
+
+ Returns:
+ A new class that wraps the base class and returns a tuple of the given size.
+
+ Example:
+ >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
+ """
+
+ class Wrapped(cls):
+ def forward(self, *args, **kwargs):
+ result = super().forward(*args, **kwargs)
+ outputs = [None] * size
+ outputs[0] = result if isinstance(result, torch.Tensor) else result[0]
+ return tuple(outputs)
+
+ def extra_repr(self) -> str:
+ return f"[{cls.__name__}]"
+
+ return Wrapped
+
+
+class MatchingZeros(nn.Module):
+ """Module that returns zeros matching the input shape.
+
+ Used to replace MLP or attention layers with no-ops. Returns zeros because
+ the hidden_states are added to the residuals, so a no-op implementation
+ should leave the residual unchanged.
+ """
+
+ def forward(self, hidden_states, *args, **kwargs):
+ return torch.zeros_like(hidden_states)
+
+
+class Same(nn.Module):
+ """Module that returns the input unchanged.
+
+ Used to replace normalization layers with identity operations.
+ """
+
+ def forward(self, hidden_states, *args, **kwargs):
+ return hidden_states
+
+ @property
+ def weight(self):
+ """Support NemotronH with scoring_activations, when lm_head is called `self.lm_head.weight.dtype`."""
+ return torch.empty(0)
diff --git a/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
new file mode 100644
index 0000000000..93913b8e2b
--- /dev/null
+++ b/modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import copy
+import inspect
+from contextlib import ExitStack, contextmanager
+from functools import wraps
+from typing import Any, Dict, List
+
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ BlockConfig,
+ maybe_cast_block_configs,
+)
+
+
+def _get_variable_from_stack(names: list[str]) -> Any:
+ """Search the call stack for a variable with one of the given names."""
+ f = inspect.currentframe().f_back
+ while f:
+ for name in names:
+ if name in f.f_locals:
+ return f.f_locals[name]
+ f = f.f_back
+ raise RuntimeError(f"{names} not found in caller stack")
+
+
+@contextmanager
+def deci_x_patcher(
+ model_descriptor: ModelDescriptor,
+ block_configs: List[BlockConfig | dict] | None = None,
+):
+ """Context manager that patches decoder layer __init__ for heterogeneous per-layer configs.
+
+ This is the core mechanism that enables AnyModel to work with any HuggingFace model.
+ It patches the decoder layer class(es) to read per-layer block_configs and apply
+ layer-specific overrides (e.g., different intermediate_size per layer).
+
+ Args:
+ model_descriptor: The model descriptor that defines which classes to patch
+ and how to map block_configs to layer overrides.
+ block_configs: Optional list of BlockConfig (one per layer). If not provided,
+ will try to read from config.block_configs during model initialization.
+
+ Example:
+ >>> with deci_x_patcher(LlamaModelDescriptor, block_configs):
+ ... model = AutoModelForCausalLM.from_config(config)
+ """
+ decoder_layer_classes = model_descriptor.decoder_layer_cls() # Now a list of classes
+ if not isinstance(decoder_layer_classes, list):
+ decoder_layer_classes = [decoder_layer_classes]
+
+ orig_inits = []
+ for cls in decoder_layer_classes:
+ orig_inits.append(cls.__init__)
+
+ block_configs = maybe_cast_block_configs(block_configs)
+
+ @wraps(orig_inits[0])
+ def _patched_decoder_layer_init(self, config, *args, **kwargs):
+ _block_configs = block_configs or getattr(config, "block_configs", None)
+ if _block_configs is None:
+ return orig_inits[decoder_layer_classes.index(self.__class__)](
+ self, config, *args, **kwargs
+ )
+
+ _block_configs = maybe_cast_block_configs(_block_configs)
+ layer_idx = _get_variable_from_stack(["layer_idx", "idx"])
+ _block_config = _block_configs[layer_idx]
+ override_block_config = model_descriptor.block_config_to_layer_overrides(_block_config)
+ _config = override_config_with_block_configs(config, override_block_config)
+ orig_inits[decoder_layer_classes.index(self.__class__)](self, _config, *args, **kwargs)
+
+ # Apply no-op post-init
+ if _block_config.attention.no_op:
+ if not model_descriptor.attn_no_op_supported():
+ raise NotImplementedError(
+ f"attn no-op not supported for `{model_descriptor.__class__.__name__}`, "
+ "please implement the method: `attn_no_op_post_init()`"
+ )
+ model_descriptor.attn_no_op_post_init(decoder_layer=self)
+
+ if _block_config.ffn.no_op:
+ if not model_descriptor.mlp_no_op_supported():
+ raise NotImplementedError(
+ f"mlp no-op not supported for `{model_descriptor.__class__.__name__}`, "
+ "please implement the method: `mlp_no_op_post_init()`"
+ )
+ model_descriptor.mlp_no_op_post_init(decoder_layer=self)
+
+ with ExitStack() as stack:
+ # Patch every decoder layer class
+ for orig_init, cls in zip(orig_inits, decoder_layer_classes):
+ stack.callback(setattr, cls, "__init__", orig_init) # Restore on exit
+ cls.__init__ = _patched_decoder_layer_init
+ yield
+
+
+def override_config_with_block_configs(
+ config: PretrainedConfig, block_configs: Dict[str, Any]
+) -> PretrainedConfig:
+ """Create a copy of config with block_config overrides applied."""
+ _config = copy.deepcopy(config)
+ # Model initialization requires fails with None in case of no-ops
+ _config_overrides = {k: v for k, v in block_configs.items() if v is not None}
+ _config.update(_config_overrides)
+ return _config
diff --git a/modelopt/torch/puzzletron/build_library_and_stats.py b/modelopt/torch/puzzletron/build_library_and_stats.py
new file mode 100644
index 0000000000..31cebdf6be
--- /dev/null
+++ b/modelopt/torch/puzzletron/build_library_and_stats.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unified command that runs build_replacement_library followed by calc_subblock_stats.
+
+This script combines the functionality of both commands into a single workflow:
+1. First, it builds the replacement library for the puzzle
+2. Then, it calculates subblock statistics
+
+Usage:
+
+ python modelopt.torch.puzzletron.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset
+
+The script uses the same Hydra configuration as the individual commands and supports
+all the same configuration parameters for both build_replacement_library and calc_subblock_stats.
+"""
+
+import hydra
+from omegaconf import DictConfig
+
+from modelopt.torch.puzzletron.replacement_library.build_replacement_library import (
+ launch_build_replacement_library,
+)
+from modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats import launch_calc_subblock_stats
+from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.utils.parsing import format_global_config
+
+
+def launch_build_library_and_stats(cfg: DictConfig) -> None:
+ """
+ Launch both build_replacement_library and calc_subblock_stats in sequence.
+
+ Args:
+ cfg: Hydra configuration containing settings for both commands
+ """
+ mprint("=" * 80)
+ mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW")
+ mprint("=" * 80)
+
+ # Step 1: Build replacement library
+ mprint("=" * 50)
+ mprint("STEP 1: Building Replacement Library")
+ mprint("=" * 50)
+
+ try:
+ launch_build_replacement_library(cfg)
+ mprint("✅ Replacement library built successfully!")
+ except Exception as e:
+ mprint(f"❌ Failed to build replacement library: {e}")
+ raise
+
+ # Step 2: Calculate subblock statistics
+ mprint("=" * 50)
+ mprint("STEP 2: Calculating Subblock Statistics")
+ mprint("=" * 50)
+
+ try:
+ launch_calc_subblock_stats(cfg)
+ mprint("✅ Subblock statistics calculated successfully!")
+ except Exception as e:
+ mprint(f"❌ Failed to calculate subblock statistics: {e}")
+ raise
+
+ mprint("=" * 80)
+ mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 🎉")
+ mprint("=" * 80)
+
+ mprint("Generated files:")
+ mprint(f" - {cfg.puzzle_dir}/block_library.json")
+ mprint(f" - {cfg.puzzle_dir}/subblock_library.json")
+ mprint(f" - {cfg.puzzle_dir}/replacement_library.json")
+ mprint(f" - {cfg.puzzle_dir}/single_sequence_replacement_solutions.json")
+ mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}")
+ if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"):
+ mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}")
diff --git a/modelopt/torch/puzzletron/dataset/__init__.py b/modelopt/torch/puzzletron/dataset/__init__.py
new file mode 100644
index 0000000000..47f1c65a15
--- /dev/null
+++ b/modelopt/torch/puzzletron/dataset/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modelopt/torch/puzzletron/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py
new file mode 100644
index 0000000000..6f1749697c
--- /dev/null
+++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import datasets
+import fire
+import numpy as np
+
+from modelopt.torch.puzzletron.tools.logger import mprint
+
+
+def process_and_save_dataset(
+ dataset_name: str,
+ output_dir: str,
+ split: tuple = ("code", "math", "stem", "chat"),
+ overwrite: bool = False,
+):
+ # Check if output_dir contains an existing dataset
+ dataset_dict_path = os.path.join(output_dir, "dataset_dict.json")
+ if os.path.exists(output_dir) and os.path.exists(dataset_dict_path):
+ if not overwrite:
+ mprint(
+ f"Output directory '{output_dir}' already contains a dataset. "
+ "Use '--overwrite True' to overwrite existing data."
+ )
+ return
+
+ ds = datasets.load_dataset(dataset_name, split=split)
+ ds = datasets.concatenate_datasets(ds)
+ # Filter out samples with reasoning = on
+ ds = ds.filter(lambda x: x["reasoning"] == "off")
+ # Hardcoded for dynamically create a deterministic train-val split
+ seed = 408
+ generator = np.random.RandomState(seed=seed)
+ ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator)
+ # Rename dataset names to follow previous conventions
+ ds_dict = datasets.DatasetDict(
+ {
+ "train": ds_split["train"],
+ "valid": ds_split["test"],
+ }
+ )
+ # Save locally
+ os.makedirs(output_dir, exist_ok=True)
+ ds_dict.save_to_disk(output_dir)
+
+ mprint(f"Dataset splits:\n{ds_dict}")
+ mprint(f"Saved processed datasets to {output_dir}")
+
+
+if __name__ == "__main__":
+ fire.Fire(process_and_save_dataset)
diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/__init__.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/__init__.py
new file mode 100644
index 0000000000..47f1c65a15
--- /dev/null
+++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py
new file mode 100644
index 0000000000..fb630335c6
--- /dev/null
+++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py
@@ -0,0 +1,277 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+import dataclasses
+import inspect
+import warnings
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Any, List, Optional, Type, Union, get_args, get_origin
+
+
+@dataclass(frozen=True, kw_only=True)
+class BaseDataclass:
+ """
+ A dataclass base class with several utilities:
+ 1. Comparison via string representation.
+ 2. Initialization of dataclasses fields from dicts.
+ 3. Setting attributes even though it's frozen (but only inside __post_init__!)
+ """
+
+ def __eq__(self, other: "BaseDataclass") -> bool:
+ return str(self) == str(other)
+
+ def __hash__(self) -> int:
+ return hash(str(self))
+
+ def __lt__(self, other: "BaseDataclass") -> bool:
+ return str(self) < str(other)
+
+ def _force_setattr(self, name: str, value: Any) -> None:
+ """
+ Set an attribute even in frozen dataclasses.
+ Use only inside __post_init__!
+ """
+ assert _is_called_from_post_init(), (
+ "_force_setattr should only be called from __post_init__, "
+ "if you need to change an attribute use dataclasses.replace "
+ "or create a new instance :)"
+ )
+ object.__setattr__(self, name, value)
+
+ def __post_init__(self):
+ """
+ Init dataclass fields from dicts
+ """
+ for field in dataclasses.fields(self):
+ field_dict = getattr(self, field.name)
+ if isinstance(field_dict, dict) and _is_dataclass_type(field.type):
+ dataclass_cls = _get_dataclass_type(field.type)
+ sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)]
+ unsupported_fields = [
+ field_name for field_name in field_dict.keys() if field_name not in sub_fields
+ ]
+ if len(unsupported_fields) > 0:
+ warnings.warn(
+ f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}"
+ )
+
+ field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields}
+ self._force_setattr(field.name, dataclass_cls(**field_dict))
+
+
+def _is_called_from_post_init() -> bool:
+ frame = inspect.currentframe()
+ while frame:
+ if frame.f_code.co_name == "__post_init__":
+ return True
+ frame = frame.f_back
+ return False
+
+
+def _is_dataclass_type(tp: Type) -> bool:
+ """
+ Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type
+ """
+ try:
+ _get_dataclass_type(tp)
+ return True
+ except:
+ return False
+
+
+def _get_dataclass_type(tp: Type) -> dataclass:
+ """
+ If the given type is a dataclass, the function returns it.
+ If it is a Union[] or Optional[], the function extracts the first dataclass type.
+ If no dataclass type is found, the function raises a ValueError.
+ """
+ origin = get_origin(tp)
+ if origin is Union:
+ for type_in_union in get_args(tp):
+ if dataclasses.is_dataclass(type_in_union):
+ return type_in_union
+ if dataclasses.is_dataclass(tp):
+ return tp
+ raise ValueError("Not a dataclass")
+
+
+@dataclass(frozen=True, kw_only=True)
+class SubblockConfig(BaseDataclass):
+ no_op: bool = False
+ replace_with_linear: bool = False
+ sparsify: Optional[list[str]] = None
+ weights_precision: Optional[str] = "bf16"
+
+ def __post_init__(self):
+ super().__post_init__()
+ assert not (self.no_op and self.replace_with_linear)
+ if self.no_op:
+ self._force_setattr("sparsify", None)
+
+ @abstractmethod
+ def to_blockconfig(self) -> "BlockConfig":
+ """ "
+ Convert to a block including this subblock only.
+ """
+ ...
+
+
+@dataclass(frozen=True, kw_only=True)
+class MoEConfig(BaseDataclass):
+ """
+ Configuration class for Mixture of Experts parameters.
+ """
+
+ num_local_experts: int = 8
+ num_experts_per_tok: int = 1
+ expert_intermediate_dim: int = 8192
+ shared_expert_intermediate_dim: int = 8192
+ # router_aux_loss_coef: float = 0.01
+ # router_z_loss_coef: float = 0.0 # Optional z-loss coefficient
+
+ def __post_init__(self):
+ # Validate the configuration
+ if self.num_local_experts <= 0:
+ raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}")
+ if self.num_experts_per_tok <= 0:
+ raise ValueError(f"top_k must be positive, got {self.top_k}")
+ if self.num_experts_per_tok > self.num_local_experts:
+ raise ValueError(
+ f"top_k ({self.top_k}) cannot be greater than num_local_experts ({self.num_local_experts})"
+ )
+ # if self.router_aux_loss_coef < 0:
+ # raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}")
+
+
+@dataclass(frozen=True, kw_only=True)
+class MambaConfig(BaseDataclass):
+ state_dim: int
+ num_heads: int
+ head_dim: int
+ num_groups: int
+
+
+@dataclass(frozen=True, kw_only=True)
+class Llama4AttentionConfig(BaseDataclass):
+ attention_chunk_size: Optional[int] = None
+ use_rope: Optional[bool] = None
+ use_qk_norm: Optional[bool] = None
+ attn_scale: Optional[float] = None
+ floor_scale: Optional[float] = None
+ attn_temperature_tuning: Optional[bool] = None
+ attention_dropout: Optional[float] = None
+
+
+@dataclass(frozen=True, kw_only=True)
+class AttentionConfig(SubblockConfig):
+ num_key_value_heads: Optional[int] = None
+ llama4: Optional[Llama4AttentionConfig] = None
+ mamba: Optional[MambaConfig] = None
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.no_op:
+ assert not self.is_mamba
+ assert not self.is_llama4
+
+ if self.no_op or self.is_mamba:
+ for irrelevant_att in [
+ "num_key_value_heads",
+ ]:
+ self._force_setattr(irrelevant_att, None)
+ else:
+ assert self.num_key_value_heads is not None
+
+ def to_blockconfig(self) -> "BlockConfig":
+ return BlockConfig(attention=self, ffn=FFNConfig(no_op=True))
+
+ @property
+ def is_llama4(self) -> bool:
+ return self.llama4 is not None
+
+ @property
+ def is_mamba(self) -> bool:
+ return self.mamba is not None
+
+
+@dataclass(frozen=True, kw_only=True)
+class FFNConfig(SubblockConfig):
+ moe: Optional[MoEConfig] = None
+ intermediate_size: Optional[int] = None
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.no_op:
+ self._force_setattr("moe", None)
+ self._force_setattr("intermediate_size", None)
+ elif self.is_moe:
+ self._force_setattr("intermediate_size", None)
+ else:
+ assert self.intermediate_size is not None, (
+ "Intermediate size must be provided for an FFN block"
+ )
+
+ def to_blockconfig(self) -> "BlockConfig":
+ return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self)
+
+ @property
+ def is_moe(self) -> bool:
+ return self.moe is not None
+
+
+SUBBLOCK_CLS_DICT = {
+ "attention": AttentionConfig,
+ "ffn": FFNConfig,
+}
+
+
+@dataclass(frozen=True, kw_only=True)
+class BlockConfig(BaseDataclass):
+ attention: Optional[AttentionConfig] = None
+ ffn: Optional[FFNConfig] = None
+ parallel_blocks: Optional[list["BlockConfig"]] = None
+
+ def __post_init__(self):
+ super().__post_init__()
+ if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict):
+ initialized_block_configs = [
+ BlockConfig(**block_config) for block_config in self.parallel_blocks
+ ]
+ self._force_setattr("parallel_blocks", initialized_block_configs)
+
+ def to_dict(self) -> dict:
+ """Convert BlockConfig to a dictionary."""
+ return dataclasses.asdict(self)
+
+
+def maybe_cast_block_configs(
+ block_configs: List[BlockConfig | dict] | None,
+) -> List[BlockConfig] | None:
+ """Cast a list of dicts to BlockConfig objects if needed.
+
+ Args:
+ block_configs: List of BlockConfig or dict objects, or None.
+
+ Returns:
+ List of BlockConfig objects, or None if input is None/empty.
+ """
+ if not block_configs:
+ return block_configs
+ if isinstance(block_configs[0], dict):
+ return [BlockConfig(**conf) for conf in block_configs]
+ return block_configs
diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py
new file mode 100644
index 0000000000..0a0f8ab1ef
--- /dev/null
+++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved.
+#
+# Small nn helpers for puzzletron pipeline code. Model configs come from HuggingFace ``AutoConfig`` (AnyModel).
+# ``LMHead`` is a distinct ``nn.Linear`` subclass so pipeline / FSDP code can target it explicitly
+# (see ``validate_runtime_pipeline``).
+# mypy: ignore-errors
+
+from torch import nn
+
+
+class LMHead(nn.Linear):
+ """
+ Special class to allow FSDP wrapping without affecting other Linear layers in the model.
+ """
diff --git a/modelopt/torch/puzzletron/export/mbridge/__init__.py b/modelopt/torch/puzzletron/export/mbridge/__init__.py
new file mode 100644
index 0000000000..471e68984b
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/__init__.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints.
+
+This module provides bridges for converting Puzzletron AnyModel checkpoints
+(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge.
+"""
+
+# Import to register bridges (side effect)
+from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin
+from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401
+ PuzzletronLlamaAnyModelBridge,
+)
+from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401
+ PuzzletronQwen3AnyModelBridge,
+)
+
+__all__ = [
+ "HeterogeneousBridgeMixin",
+ "PuzzletronLlamaAnyModelBridge",
+ "PuzzletronQwen3AnyModelBridge",
+]
diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py
new file mode 100644
index 0000000000..13ea6612af
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/base.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Mixin class for bridges that support heterogeneous layer architectures.
+
+This module provides a mixin class for converting models with block_configs
+(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge.
+"""
+
+import dataclasses
+import json
+from collections.abc import Callable
+from dataclasses import dataclass, fields
+
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig
+from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
+ get_gpt_heterogeneous_layer_spec,
+)
+from megatron.core.transformer.spec_utils import ModuleSpec
+
+
+def heterogeneous_layer_spec(config) -> ModuleSpec:
+ """Get GPT heterogeneous layer spec using Transformer Engine."""
+ return get_gpt_heterogeneous_layer_spec(config, use_te=True)
+
+
+@dataclass
+class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig):
+ """Generic provider for AnyModel checkpoints with block_configs."""
+
+ # Heterogeneous configuration fields
+ heterogeneous_layers_config_path: str | None = None
+ heterogeneous_layers_config_encoded_json: str = ""
+ transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec
+
+ def __getattr__(self, name: str):
+ """Handle missing attributes for OmegaConf compatibility.
+
+ Returns empty list for per_block_parameters if not yet initialized (before finalize()).
+ This allows OmegaConf to serialize/deserialize configs without errors. Actual usage
+ should call finalize() first to set per_block_parameters as a real attribute.
+ """
+ if name == "per_block_parameters":
+ # Return existing attribute if set, otherwise [] for OmegaConf compatibility
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ return []
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+
+class HeterogeneousBridgeMixin:
+ """Mixin for bridges supporting heterogeneous layer architectures (block_configs).
+
+ Must be used with multiple inheritance alongside a model-specific bridge.
+ Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge)
+ """
+
+ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
+ """Convert HF AnyModel config to Megatron GPTModelProvider.
+
+ This method:
+ 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all
+ model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.)
+ 2. Converts the provider to a dict and filters to only fields accepted by
+ GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid
+ GPTModelProvider fields are preserved)
+ 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider
+
+ All parameters from the parent bridge (e.g., LlamaBridge) are maintained because
+ GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all
+ the fields that the parent bridge sets.
+ """
+
+ parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc]
+
+ provider_kwargs = dataclasses.asdict(parent_provider)
+
+ # Filter to only fields that GenericHeterogeneousProvider accepts.
+ # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all
+ # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider,
+ # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits
+ # from GPTModelProvider, not from model-specific subclasses.
+ #
+ # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they
+ # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g.,
+ # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a
+ # model-specific heterogeneous provider that inherits from the model-specific provider.
+ valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)}
+
+ # Only keep kwargs that are valid fields
+ provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields}
+
+ provider_kwargs["heterogeneous_layers_config_encoded_json"] = (
+ self._build_heterogeneous_config_json(hf_pretrained.config)
+ )
+ return GenericHeterogeneousProvider(**provider_kwargs)
+
+ def _build_heterogeneous_config_json(self, hf_config) -> str:
+ """Build heterogeneous layers config JSON from HF config."""
+
+ hf_config_dict = json.loads(hf_config.to_json_string())
+
+ mcore_block_configs = [
+ self._convert_block_config(block) for block in hf_config_dict["block_configs"]
+ ]
+ return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False)
+
+ def _convert_block_config(self, block: dict) -> dict:
+ """Convert a single block config from HF format to MCore format."""
+ return {
+ "attention": self._convert_attention_config(block["attention"]),
+ "ffn": self._convert_ffn_config(block["ffn"]),
+ }
+
+ def _convert_attention_config(self, attention_config: dict) -> dict:
+ """Convert attention config from HF format to MCore format."""
+ attention_config = attention_config.copy()
+ attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads")
+ return attention_config
+
+ def _convert_ffn_config(self, ffn_config: dict) -> dict:
+ """Convert FFN/MLP config from HF format to MCore format."""
+ ffn_config = ffn_config.copy()
+ ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size")
+ return ffn_config
diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py
new file mode 100644
index 0000000000..fa49dc29c5
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py
@@ -0,0 +1,190 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO: Upstream this fix to Megatron-Bridge and remove this local copy.
+
+import logging
+from dataclasses import dataclass, fields
+from typing import TYPE_CHECKING, Any, Optional
+
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
+from megatron.bridge.models.transformer_config import TransformerConfig
+from megatron.core.models.gpt import GPTModel as MCoreGPTModel
+
+import modelopt.torch.distill as mtd
+import modelopt.torch.distill.plugins.megatron as mtd_mcore
+
+if TYPE_CHECKING:
+ from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DistillationProvider(TransformerConfig):
+ """Provider for Megatron Core GPT models in distillation mode.
+
+ Please use `convert_to_distillation_provider()` to create an instance of this class.
+ """
+
+ teacher: Optional[GPTModelProvider | MambaModelProvider] = None
+ kd_config: Optional["ModelOptDistillConfig"] = None
+
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Use `convert_to_distillation_provider()` to create an instance of this class."
+ )
+
+ def __post_init__(self):
+ assert getattr(self, "teacher", None) is not None, "Teacher model must be provided."
+
+ shared_attrs = [
+ "tensor_model_parallel_size",
+ "pipeline_model_parallel_size",
+ "context_parallel_size",
+ "seq_length",
+ "pipeline_dtype",
+ ]
+ for attr in shared_attrs:
+ if getattr(self, attr) != getattr(self.teacher, attr):
+ raise ValueError(f"Student and teacher providers must have the same {attr}.")
+
+ # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version.
+ self.cross_entropy_fusion_impl = "native"
+
+ # Hack to dynamically subclass other providers and still use their methods
+ self._super_class = self.__class__.__bases__[0]
+
+ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
+ """Configure and instantiate a ModelOpt DistillationModel based on this configuration.
+
+ Args:
+ pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage
+ post_process: Whether to include post-processing in the model, defaults to last pipeline stage
+ vp_stage: Virtual pipeline stage
+
+ Returns:
+ MCoreGPTModel: Configured ModelOpt DistillationModel instance
+ """
+ if vp_stage is not None:
+ raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.")
+
+ assert self.teacher is not None, "Teacher model must be provided."
+ student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined]
+
+ # Finalize teacher provider before creating model (required for heterogeneous models).
+ #
+ # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in
+ # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during
+ # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent
+ # with current parallelism settings and distributed context. Student model creation (above)
+ # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter
+ # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341)
+ # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to
+ # determine layer architecture. Without finalize() in this context, architecture expectations
+ # don't match checkpoint weights, causing:
+ # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0
+ # (expected (2880, 4096), got (3584, 4096))
+ #
+ # Note: This explanation needs to be confirmed yet.
+ self.teacher.finalize()
+
+ # Hack to get teacher's pre-wrap hooks called to potentially load HF weights
+ teacher_model = self.teacher.provide_distributed_model(
+ wrap_with_ddp=False, mixed_precision_wrapper=None
+ )[0]
+
+ kd_cfg = mtd_mcore.setup_distillation_config(
+ self.kd_config, student_model.config, teacher_model.config
+ )
+ modelopt_cfg = {
+ "teacher_model": teacher_model,
+ "criterion": kd_cfg.criterion,
+ "loss_balancer": kd_cfg.loss_balancer,
+ }
+ kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)])
+ mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg)
+
+ return kd_model
+
+ def to_cfg_dict(self) -> dict[str, Any]:
+ """Custom method to save equivalent to the original provider class.
+
+ Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML.
+ There is no need to restore a `DistillationProvider` from the run config file, as
+ it can always be re-converted using the original student provider.
+
+ Returns:
+ Dictionary representation of this provider class
+ """
+ from megatron.bridge.training.utils.config_utils import _ConfigContainerBase
+
+ result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}
+
+ # Include all fields from the original provider class (self._super_class), not just DistillationProvider
+ # This ensures fields like heterogeneous_layers_config_encoded_json are preserved
+ excluded_fields = {"teacher", "kd_config"}
+ for field in fields(self._super_class):
+ if field.name.startswith("_") or field.name in excluded_fields:
+ continue
+ # Only include if the field exists on this instance (it should, since we converted from the original provider)
+ if hasattr(self, field.name):
+ result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+ getattr(self, field.name)
+ )
+
+ # Also include any additional fields from DistillationProvider itself (if any)
+ for field in fields(self):
+ if field.name.startswith("_") or field.name in excluded_fields:
+ continue
+ # Skip if already included from _super_class
+ if field.name not in result:
+ result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+ getattr(self, field.name)
+ )
+
+ return result
+
+ def __setattr__(self, name, value):
+ super().__setattr__(name, value)
+ # Mirror to teacher if it has that attribute
+ if hasattr(self.teacher, name):
+ setattr(self.teacher, name, value)
+
+
+def convert_to_distillation_provider(
+ student_provider: GPTModelProvider | MambaModelProvider,
+ teacher_provider: GPTModelProvider | MambaModelProvider,
+ kd_config: Optional["ModelOptDistillConfig"] = None,
+) -> "DistillationProvider":
+ """Convert a given model provider to a DistillationProvider."""
+
+ assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), (
+ "Student provider must be a subclass of GPTModelProvider or MambaModelProvider."
+ )
+ assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), (
+ "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider."
+ )
+
+ DistillationProvider.__bases__ = (type(student_provider),)
+ student_provider.__class__ = DistillationProvider
+
+ student_provider.teacher = teacher_provider
+ student_provider.kd_config = kd_config
+ student_provider.__post_init__()
+
+ return student_provider
diff --git a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py
new file mode 100644
index 0000000000..0ab6083f77
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Export utilities for Megatron-Bridge checkpoints."""
+
+import shutil
+from pathlib import Path
+
+from megatron.bridge import AutoBridge
+
+from modelopt.torch.utils import print_rank_0
+
+
+def export_to_hf_and_copy_config(
+ student_hf_path: str,
+ checkpoint_dir: str,
+ train_iters: int,
+ hf_export_path: str,
+ hf_model: str,
+ trust_remote_code: bool = False,
+) -> None:
+ """
+ Export Megatron checkpoint to HuggingFace format and copy config.json from student model.
+
+ TODO: This script should not be needed (manually copying config.json from
+ student model to exported model). Remove it once export_to_hf() in AutoBridge
+ supports copying/preserving config.json from student model.
+
+
+ Args:
+ student_hf_path: Path to the original student HuggingFace model (source of config.json)
+ checkpoint_dir: Base directory where Megatron checkpoints are stored
+ train_iters: Number of training iterations (used to construct final checkpoint path)
+ hf_export_path: Directory path where the HuggingFace model will be saved
+ hf_model: HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct)
+ trust_remote_code: Whether to trust remote modeling code when loading the HF template model
+ """
+ print_rank_0(f"\n{'=' * 80}")
+ print_rank_0("Exporting to HuggingFace format...")
+ print_rank_0(f"{'=' * 80}\n")
+
+ # Construct path to final checkpoint iteration (format: iter_0000100 for 100 iterations)
+ final_iter_dir = Path(checkpoint_dir) / f"iter_{train_iters:07d}"
+ print_rank_0(f"📂 Using final checkpoint: {final_iter_dir}")
+
+ # Use the final iteration directory for export (export_ckpt will validate it exists)
+ megatron_path = str(final_iter_dir)
+
+ # Create bridge using standard model ID (not AnyModel checkpoint) to avoid sharding structure issues
+ print_rank_0("🌉 Creating bridge...")
+ print_rank_0(f" Using model ID: {hf_model}")
+ bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=trust_remote_code)
+
+ print_rank_0("📤 Exporting to HuggingFace format...")
+ bridge.export_ckpt(
+ megatron_path=megatron_path, hf_path=hf_export_path, show_progress=True, strict=True
+ )
+
+ print_rank_0(f"✅ Successfully exported model to: {hf_export_path}")
+
+ # Copy config.json from student model to exported model (preserves block_configs)
+ student_config_path = Path(student_hf_path) / "config.json"
+ exported_config_path = Path(hf_export_path) / "config.json"
+
+ print_rank_0(f"📋 Copying config.json from student model: {student_config_path}")
+ shutil.copy(student_config_path, exported_config_path)
+ print_rank_0(f"✅ Copied config.json to: {exported_config_path}")
+
+ print_rank_0(f"\n{'=' * 80}")
+ print_rank_0("Export complete!")
diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py
new file mode 100644
index 0000000000..b802215298
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/llama.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints."""
+
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.llama.llama_bridge import LlamaBridge
+from megatron.core.models.gpt.gpt_model import GPTModel
+from transformers import LlamaForCausalLM
+
+from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin
+
+
+@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel)
+class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge):
+ """
+ Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints.
+
+ Extends LlamaBridge with support for heterogeneous layer architectures (block_configs).
+ All Llama-specific settings are inherited from LlamaBridge.
+ """
+
+ # provider_bridge() is inherited from HeterogeneousBridgeMixin
+ # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config
+ # mapping_registry() is inherited from LlamaBridge
diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py
new file mode 100644
index 0000000000..ace20fbf89
--- /dev/null
+++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints."""
+
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge
+from megatron.core.models.gpt.gpt_model import GPTModel
+from transformers import Qwen3ForCausalLM
+
+from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin
+
+
+@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel)
+class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge):
+ """
+ Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints.
+
+ Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs).
+ All Qwen3-specific settings are inherited from Qwen3Bridge.
+ """
+
+ # provider_bridge() is inherited from HeterogeneousBridgeMixin
+ # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config
+ # mapping_registry() is inherited from Qwen3Bridge
diff --git a/modelopt/torch/puzzletron/mip/mip_and_realize_models.py b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py
new file mode 100644
index 0000000000..17d8e4a2db
--- /dev/null
+++ b/modelopt/torch/puzzletron/mip/mip_and_realize_models.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs MIP (Mixed Integer Programming) optimization and realizes the resulting model solutions."""
+
+# mypy: ignore-errors
+from pathlib import Path
+
+import torch
+from omegaconf import DictConfig
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.mip.run_puzzle import run_puzzle
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.validate_puzzle_with_multi_replacements import (
+ validate_puzzle_solutions,
+)
+
+
+def launch_mip(cfg: DictConfig) -> list[str]:
+ solution_paths = run_puzzle(args=cfg.mip)
+ return solution_paths
+
+
+def launch_realize_model(cfg: DictConfig):
+ validate_puzzle_solutions(args=cfg.realize_model)
+
+
+def launch_mip_and_realize_model(cfg: DictConfig) -> list[str]:
+ # Determine device for distributed operations (NCCL requires CUDA tensors)
+ device = "cpu"
+ if dist.size() > 1:
+ if torch.distributed.get_backend() == "nccl":
+ device = torch.cuda.current_device()
+
+ if dist.is_master():
+ solution_paths = launch_mip(cfg)
+ length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device)
+ else:
+ solution_paths = None
+ length_tensor = torch.tensor([0], dtype=torch.long, device=device)
+
+ if not cfg.skip_realize_model:
+ if dist.size() > 1:
+ torch.distributed.broadcast(length_tensor, src=0)
+
+ list_length = length_tensor.item()
+
+ if not dist.is_master():
+ solution_paths = [None] * list_length
+
+ if dist.size() > 1:
+ torch.distributed.broadcast_object_list(solution_paths, src=0)
+
+ for solution_path in solution_paths:
+ mprint(f"Realize model for the solution: {solution_path}")
+ cfg.realize_model.solutions_path = Path(solution_path)
+ launch_realize_model(cfg)
+ dist.barrier()
+
+ return solution_paths
diff --git a/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py
new file mode 100644
index 0000000000..5b4eccbc15
--- /dev/null
+++ b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py
@@ -0,0 +1,203 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Solves multi-layer replacement optimization using Mixed Integer Programming."""
+
+# mypy: ignore-errors
+import math
+import warnings
+from collections import defaultdict
+from collections.abc import Hashable, Iterable
+from copy import deepcopy
+from random import random
+from typing import Any, TypeAlias
+
+from mip import BINARY, Model, maximize, minimize, xsum
+
+from modelopt.torch.puzzletron.mip.utils import (
+ consecutive_ngrams,
+ get_nested_key,
+ sort_replacements,
+)
+
+ReplacementID: TypeAlias = Hashable
+Replacement: TypeAlias = dict[str, Any]
+ChosenReplacements: TypeAlias = list[Replacement]
+
+
+def run_mip(
+ replacements: dict[ReplacementID, Replacement],
+ objective: str,
+ constraints: dict[str, float],
+ bigger_is_better: bool,
+ max_seconds_per_solution: float | None = None,
+) -> tuple[ChosenReplacements, float, dict[str, float]]:
+ orig_num_replacements = len(replacements)
+ replacements = {
+ replacement_id: deepcopy(replacement)
+ for replacement_id, replacement in replacements.items()
+ if math.isfinite(get_nested_key(replacement, objective))
+ }
+ if len(replacements) < orig_num_replacements:
+ print("\n\n\n")
+ warnings.warn(
+ f"mip: removed {orig_num_replacements - len(replacements)} replacements with NaN/inf objective value"
+ )
+ print("\n\n\n")
+
+ mip_model = Model()
+
+ objective_vars = []
+ constraint_vars = {constraint_key: [] for constraint_key in constraints}
+ choice_indicators_by_layer = defaultdict(list)
+ for replacement_id, replacement in replacements.items():
+ is_chosen = mip_model.add_var(var_type=BINARY)
+ replacement["is_chosen"] = is_chosen
+
+ for parent_layer_idx in replacement["parent_layer_indices"]:
+ choice_indicators_by_layer[parent_layer_idx].append(is_chosen)
+
+ objective_vars.append(is_chosen * get_nested_key(replacement, objective))
+
+ for constraint_key in constraints:
+ constraint_vars[constraint_key].append(
+ is_chosen * get_nested_key(replacement, constraint_key)
+ )
+
+ # MIP constraints: each parent layer must come from exactly one chosen replacement
+ for parent_layer_idx, curr_choice_indicators in choice_indicators_by_layer.items():
+ mip_model += xsum(curr_choice_indicators) == 1
+
+ # MIP constraints: the sum of chosen replacement costs must be lower than the max cost
+ for constraint_key, max_cost in constraints.items():
+ min_cost = None
+ if isinstance(max_cost, Iterable):
+ min_cost, max_cost = max_cost
+
+ if max_cost is not None:
+ mip_model += xsum(constraint_vars[constraint_key]) <= max_cost
+ if min_cost is not None:
+ mip_model += xsum(constraint_vars[constraint_key]) >= min_cost
+
+ # MIP objective
+ mip_model.objective = (
+ maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars))
+ )
+
+ if max_seconds_per_solution is not None:
+ mip_model.max_seconds = max_seconds_per_solution
+
+ mip_model.optimize()
+
+ if is_chosen.x is None:
+ return []
+ # raise InfeasibleError()
+
+ # Trust But Verify: calculate total value and costs, and check that all the constraints are filled
+ total_value = 0.0
+ total_costs = dict.fromkeys(constraints.keys(), 0)
+ chosen_replacements: ChosenReplacements = []
+ chosen_layers = []
+ for replacement_id, replacement in replacements.items():
+ is_chosen = replacement["is_chosen"].x >= 0.99
+ if is_chosen:
+ assert replacement not in chosen_replacements
+ chosen_replacements.append(replacement)
+ total_value += get_nested_key(replacement, objective)
+ for constraint_key in constraints:
+ total_costs[constraint_key] += get_nested_key(replacement, constraint_key)
+ for parent_layer_idx in replacement["parent_layer_indices"]:
+ assert parent_layer_idx not in chosen_layers
+ chosen_layers.append(parent_layer_idx)
+
+ missing_layers = set(choice_indicators_by_layer.keys()) - set(chosen_layers)
+ assert len(missing_layers) == 0, (
+ f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}"
+ )
+
+ for constraint_key, max_cost in constraints.items():
+ min_cost = None
+ if isinstance(max_cost, Iterable):
+ min_cost, max_cost = max_cost
+
+ if max_cost is not None:
+ assert total_costs[constraint_key] < max_cost or math.isclose(
+ total_costs[constraint_key], max_cost, rel_tol=1e-9
+ ), (
+ f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}"
+ )
+ if min_cost is not None:
+ assert total_costs[constraint_key] > min_cost or math.isclose(
+ total_costs[constraint_key], min_cost, rel_tol=1e-9
+ ), (
+ f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}"
+ )
+
+ chosen_replacements = sort_replacements(chosen_replacements)
+ for cr in chosen_replacements:
+ del cr["is_chosen"] # not copyable, will cause errors in deep copy
+ if "block_config" in cr:
+ cr["child_block_configs"] = cr["block_config"]
+ # del cr['block_config'] for now the dump includes both keys (duplicated values) # we might wanna either delete one of them or keep both
+ # I prefer keeping block_config and deleting 'child_block_configs' from previous puzzle steps
+
+ return [
+ {
+ "chosen_replacements": chosen_replacements,
+ "total_value": total_value,
+ "total_costs": total_costs,
+ }
+ ]
+
+
+def usage_example():
+ num_layers = 32
+ num_options_per_parent_replacement = 5
+
+ replacements = dict()
+ for num_layers_in_replacement in (1, 2, 3):
+ for i_option in range(num_options_per_parent_replacement):
+ for parent_layer_indices in consecutive_ngrams(num_layers, num_layers_in_replacement):
+ replacement_id = f"parent layers {parent_layer_indices} child config {i_option}"
+ replacement = {
+ "parent_layer_indices": parent_layer_indices,
+ "metrics": {"loss": random()},
+ "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10},
+ "replacement_id": replacement_id,
+ }
+ replacements[replacement_id] = replacement
+
+ constraints = {"stats.memory_mib": num_layers * 15.0, "stats.runtime_ms": num_layers * 1.5}
+ (result,) = run_mip(
+ replacements,
+ objective="metrics.loss",
+ constraints=constraints,
+ bigger_is_better=False,
+ )
+ chosen_replacements = result["chosen_replacements"]
+ total_value = result["total_value"]
+ total_costs = result["total_costs"]
+
+ print()
+ print()
+ print(f"{total_value=}")
+ print(f"{total_costs=}")
+ print(f"{constraints=}")
+ print("chosen_replacements=")
+ print("\n".join([rep["replacement_id"] for rep in chosen_replacements]))
+
+
+if __name__ == "__main__":
+ usage_example()
diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py
new file mode 100644
index 0000000000..803fd83db3
--- /dev/null
+++ b/modelopt/torch/puzzletron/mip/run_puzzle.py
@@ -0,0 +1,760 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main entry point for running the puzzle optimization to find optimal layer configurations."""
+
+# mypy: ignore-errors
+import argparse
+import dataclasses
+import enum
+import json
+import sys
+from collections.abc import Hashable, Iterable
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Literal, TypeAlias
+
+import numpy as np
+import yaml
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+from modelopt.torch.puzzletron.mip.mip_with_multi_layer_replacements import (
+ run_mip as run_multi_layer_replacement_mip,
+)
+from modelopt.torch.puzzletron.replacement_library.replacement_utils import (
+ extract_block_configs_and_locations,
+ parse_layer_replacement,
+ replacement_is_teacher,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.robust_json import json_dump
+from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_json, parse_path
+from modelopt.torch.puzzletron.utils.utils import block_config_to_str, solution_to_str
+
+"""
+Usage:
+Must specify either --single_block_replacement_validation_dir and --subblock_stats_path (in which case the metrics will
+be gathered from the validation output files) or --gathered_metrics_path (in which case the metrics will be read from
+this json file).
+
+Constraints can be specified either as 'mip_constraints' (the actual constraints that go into the MIP, e.g. 'stats.memory_mib', 'stats.runtime_ms'),
+or as "human constraints" (e.g. 'target_memory', 'target_throughput', for the full list see PuzzleConstraints._ALLOWED_HUMAN_CONSTRAINTS).
+
+"""
+
+PuzzleMetrics: TypeAlias = dict[Hashable, dict[Hashable, dict[str, float]]]
+MultiLayerPuzzleMetrics: TypeAlias = dict[str, dict[str, Hashable]]
+
+
+@dataclasses.dataclass
+class PuzzleConstraints:
+ """A set of puzzle constraints can be expressed either directly as the mip constraints (e.g. 'stats.memory_mib') or as human constraints (e.g. 'target_throughput')"""
+
+ class Type(enum.Enum):
+ MIP = "mip"
+ HUMAN = "human"
+
+ _ALLOWED_HUMAN_CONSTRAINTS = {
+ "target_memory",
+ "target_throughput",
+ "target_latency",
+ "target_time_to_first_token",
+ "num_params",
+ "stats.has_attention",
+ }
+ type: Type
+ name: str = dataclasses.field(init=False)
+ constraints: dict[str, Any]
+
+ @staticmethod
+ def sizeof_fmt(num, suffix=""):
+ for unit in ("", "K", "M", "G", "T"):
+ if abs(num) < 1000.0:
+ return f"{num:g}{unit}{suffix}"
+ num /= 1000.0
+ return f"{num:.1f}P{suffix}"
+
+ def _validate_human_constraints(self):
+ illegal_constraints = set(self.constraints.keys()) - self._ALLOWED_HUMAN_CONSTRAINTS
+ if illegal_constraints:
+ raise ValueError(
+ f"The following human_constraints are illegal: {','.join(illegal_constraints)}"
+ )
+
+ def format_num_params_to_float(self, num_params):
+ if isinstance(num_params, list):
+ return [self.format_num_params_to_float(x) for x in num_params]
+ if isinstance(num_params, str):
+ # we only deal with Billions of params scale
+ return float(num_params.replace("B", "")) * 1e9
+ return num_params
+
+ def format_num_params_to_str(self, num_params):
+ if isinstance(num_params, list):
+ return [self.format_num_params_to_str(x) for x in num_params]
+ if isinstance(num_params, float) or isinstance(num_params, int):
+ return f"{num_params / 1e9}B"
+ return num_params
+
+ def __post_init__(self):
+ if self.type == PuzzleConstraints.Type.HUMAN:
+ self._validate_human_constraints()
+
+ if "stats.active_params" in self.constraints:
+ self.constraints["stats.active_params"] = self.format_num_params_to_float(
+ self.constraints["stats.active_params"]
+ )
+
+ # Set self.name
+ constraints = deepcopy(self.constraints) # going to override with "human readable" versions
+ if "stats.active_params" in constraints:
+ constraints["stats.active_params"] = self.format_num_params_to_str(
+ constraints["stats.active_params"]
+ )
+
+ if self.type == PuzzleConstraints.Type.HUMAN:
+ # change values to a more human string form
+ if "target_memory" in constraints:
+ constraints["target_memory"] = str(constraints["target_memory"]) + "MiB"
+ if "num_params" in constraints:
+ constraints["num_params"] = self.sizeof_fmt(constraints["num_params"])
+
+ def build_constraint_name(constraint_name, constraint_value):
+ if isinstance(constraint_value, Iterable) and not isinstance(constraint_value, str):
+ return "-".join(f"{constraint_name}_{x}" for x in constraint_value)
+ else:
+ return f"{constraint_name}_{constraint_value}"
+
+ self.name = "-".join(build_constraint_name(k, v) for k, v in constraints.items()).replace(
+ ".", "_"
+ )
+
+ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]:
+ if self.type == PuzzleConstraints.Type.MIP:
+ return self.constraints
+
+ assert all(key in subblock_stats_args for key in ("batch_size", "generation_seq_len")), (
+ "Can't realize human constraints without 'batch_size' and 'generation_seq_len' in subblock_stats_args."
+ )
+ batch_size = subblock_stats_args["batch_size"]
+ generation_seq_len = subblock_stats_args["generation_seq_len"]
+
+ mip_constraints = {}
+
+ # Memory constraints
+ if "target_memory" in self.constraints:
+ mip_constraints["stats.memory_mib"] = self.constraints["target_memory"]
+
+ # Throughput constraints
+ throughput_constraints = []
+ if "target_throughput" in self.constraints:
+ throughput_constraints.append(
+ batch_size * generation_seq_len / self.constraints["target_throughput"]
+ )
+ if "target_latency" in self.constraints:
+ throughput_constraints.append(self.constraints["target_latency"])
+ if throughput_constraints:
+ mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints)
+
+ # Prefill runtime constraint
+ if "target_time_to_first_token" in self.constraints:
+ mip_constraints["stats.prefill_runtime_ms"] = (
+ 1000 * self.constraints["target_time_to_first_token"]
+ )
+
+ # Num params
+ if "num_params" in self.constraints:
+ mip_constraints["stats.num_params"] = self.constraints["num_params"]
+ if "stats.has_attention" in self.constraints:
+ mip_constraints["stats.has_attention"] = self.constraints["stats.has_attention"]
+ return mip_constraints
+
+
+def parse_args() -> DictConfig:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--puzzle_profile", type=parse_path)
+
+ parser.add_argument("--single_block_replacement_validation_dir", type=parse_path, default=None)
+ parser.add_argument(
+ "--gathered_metrics_path",
+ type=parse_path,
+ default=None,
+ help="Can be given explicitly instead of --single_block_replacement_validation_dir",
+ )
+
+ parser.add_argument("--subblock_stats_path", type=parse_path)
+ parser.add_argument("--subblock_stats_args", type=parse_json)
+
+ parser.add_argument("--objective", type=str)
+ parser.add_argument("--mip_constraints", type=parse_json)
+ parser.add_argument("--human_constraints", type=parse_json)
+ parser.add_argument("--report_additional_costs", type=str, action="append", default=[])
+
+ parser.add_argument(
+ "--output_path",
+ type=parse_path,
+ help="The main folder under which all results will be stored.",
+ )
+
+ parser.add_argument("--max_seconds_per_solution", type=float, default=60.0)
+ parser.add_argument("--metric_overrides", type=parse_json, default=None)
+ parser.add_argument(
+ "--bigger_is_better",
+ action="store_true",
+ help="Set this if using accuracy objective, don't set if using loss objective",
+ )
+
+ args = parser.parse_args()
+ return DictConfig(vars(args))
+
+
+def run_single_puzzle_config(
+ args: DictConfig,
+ gathered_metrics: dict,
+ subblock_stats: dict,
+ subblock_stats_args: dict,
+ constraints: PuzzleConstraints,
+ output_folder,
+) -> None:
+ # we override the constraints and subblock_stats_args for this run to keep reporting out the same way.
+ args = deepcopy(args)
+
+ subblock_stats = filter_subblock_stats_by_args(subblock_stats, subblock_stats_args)
+ _add_block_stats_to_gathered_metrics(gathered_metrics, subblock_stats)
+
+ output_folder.mkdir(parents=True, exist_ok=True)
+ _dump_gathered_metrics(gathered_metrics, output_folder)
+
+ non_block_stats = {"stats": _get_block_stats(subblock_stats, "non_block")}
+ batch_size = subblock_stats["args"]["batch_size"]
+ generation_seq_len = subblock_stats["args"]["generation_seq_len"]
+
+ mip_constraints = constraints.to_mip_constraints(subblock_stats["args"])
+ orig_mip_constraints = deepcopy(mip_constraints)
+ mprint(f"Solving for the following MIP constraints: {mip_constraints}")
+ args.mip_constraints = orig_mip_constraints
+ args.human_constraints = (
+ constraints.constraints if constraints.type == PuzzleConstraints.Type.HUMAN else None
+ )
+ args.subblock_stats_args = subblock_stats_args
+
+ for stat_name, max_cost in mip_constraints.items():
+ try:
+ non_block_cost = get_nested_key(non_block_stats, stat_name)
+ except KeyError:
+ non_block_cost = 0
+
+ is_min_max = isinstance(max_cost, Iterable)
+ min_cost = None
+ if is_min_max:
+ min_cost, max_cost = max_cost
+
+ min_cost = min_cost - non_block_cost if (min_cost is not None) else None
+ max_cost = max_cost - non_block_cost if (max_cost is not None) else None
+
+ if is_min_max:
+ mip_constraints[stat_name] = (min_cost, max_cost)
+ else:
+ mip_constraints[stat_name] = max_cost
+
+ # If there's an additional cost that is not a constraint - set it to "inf" so MIP report the actual value of it.
+ for cost in set(args.report_additional_costs) - set(orig_mip_constraints.keys()):
+ mip_constraints[cost] = np.inf
+
+ mprint(f"After non-block adjustments: {mip_constraints=}")
+
+ solutions = run_multi_layer_replacement_mip(
+ replacements=gathered_metrics,
+ objective=args.objective,
+ constraints=mip_constraints,
+ bigger_is_better=args.bigger_is_better,
+ max_seconds_per_solution=args.max_seconds_per_solution,
+ )
+
+ for solution in solutions:
+ for stat_name in set([*orig_mip_constraints.keys(), *args.report_additional_costs]):
+ try:
+ non_block_cost = get_nested_key(non_block_stats, stat_name)
+ except KeyError:
+ non_block_cost = 0
+ solution["total_costs"][stat_name] += non_block_cost
+
+ # Calculate throughput from runtime_ms
+ if "stats.runtime_ms" in solution["total_costs"]:
+ total_runtime = solution["total_costs"]["stats.runtime_ms"]
+ solution["total_costs"]["throughput"] = (
+ 1000 * batch_size * generation_seq_len / total_runtime
+ )
+
+ solution["total_value"] = {args.objective: solution["total_value"]}
+ solution["puzzle_args"] = (
+ OmegaConf.to_container(args, resolve=True)
+ if isinstance(args, DictConfig)
+ else vars(args)
+ )
+ solution["subblock_stats"] = subblock_stats["args"]
+ chosen_block_configs, _ = extract_block_configs_and_locations(
+ solution["chosen_replacements"]
+ )
+ solution["chosen_block_configs"] = chosen_block_configs
+ solution["solution_repr"] = solution_to_str(chosen_block_configs)
+
+ if len(solutions) > 0:
+ solution_repr_0 = solutions[0]["solution_repr"]
+ mprint(f"\n{solution_repr_0}")
+ mprint(f"Total costs: {solutions[0]['total_costs']}")
+ (output_folder / "solution_repr_0.txt").write_text(solution_repr_0)
+
+ solutions_file = output_folder / "solutions.json"
+ json_dump(solutions, solutions_file)
+ mprint(solutions_file)
+ return solutions_file
+
+
+def _dump_gathered_metrics(gathered_metrics: PuzzleMetrics, output_folder: Path) -> None:
+ for replacement_id, replacement_info in gathered_metrics.items():
+ replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"])
+ gathered_metrics_for_dump = gathered_metrics
+
+ json_dump(gathered_metrics_for_dump, output_folder / "replacement_metrics_and_stats.json")
+
+
+def _load_all_constraints(args, puzzle_profile):
+ def parse_constraints(constraints, constraints_type: PuzzleConstraints.Type):
+ if isinstance(constraints, (list, ListConfig)):
+ return [PuzzleConstraints(type=constraints_type, constraints=c) for c in constraints]
+ elif isinstance(constraints, (dict, DictConfig)):
+ return [PuzzleConstraints(type=constraints_type, constraints=constraints)]
+ raise TypeError(f"Invalid constraints type: {constraints_type}")
+
+ # Constraints can be given explicitely
+ if args.mip_constraints is not None:
+ return parse_constraints(args.mip_constraints, PuzzleConstraints.Type.MIP)
+
+ if args.human_constraints is not None:
+ return parse_constraints(args.human_constraints, PuzzleConstraints.Type.HUMAN)
+
+ # Or through the puzzle_profile
+ if "mip_constraints" in puzzle_profile:
+ return parse_constraints(puzzle_profile["mip_constraints"], PuzzleConstraints.Type.MIP)
+
+ if "human_constraints" in puzzle_profile:
+ return parse_constraints(puzzle_profile["human_constraints"], PuzzleConstraints.Type.HUMAN)
+
+ raise ValueError(
+ "Constraints must be given either explicitely by --mip_constraints or --human_constraints arguments, or through the puzzle_profile."
+ )
+
+
+def _load_all_subblock_stats_args(args, puzzle_profile):
+ # If given explicitely in args
+ if args.subblock_stats_args is not None:
+ if isinstance(args.subblock_stats_args, dict):
+ return [args.subblock_stats_args]
+ else:
+ return args.subblock_stats_args
+
+ # Or can be given inside puzzle_profile
+ if "subblock_stats_args" in puzzle_profile:
+ return puzzle_profile["subblock_stats_args"]
+
+ raise ValueError(
+ "subblock_stats_args must be given either explicitely by the --subblock_stats_args argument, or through the puzzle_profile."
+ )
+
+
+def _override_args_from_profile(args, puzzle_profile):
+ for arg_name in vars(args):
+ if arg_name in puzzle_profile:
+ if arg_name not in ("mip_constraints", "human_constraints", "subblock_stats_args"):
+ setattr(args, arg_name, puzzle_profile[arg_name])
+
+
+def _assert_valid_config(args, puzzle_profile):
+ required_args = (
+ "subblock_stats_path",
+ "objective",
+ "output_path",
+ )
+ missing_args = [arg for arg in required_args if arg not in args or getattr(args, arg) is None]
+ if missing_args:
+ mprint(f"error: The following arguments are required: {', '.join(missing_args)}")
+ sys.exit(1)
+
+ # Make sure we have specified subblock_stats_args
+ if "subblock_stats_args" not in args and "subblock_stats_args" not in puzzle_profile:
+ mprint(
+ "error: Must specify `subblock_stats_args` in either puzzle_profile or as a commandline arg."
+ )
+ sys.exit(1)
+
+ # Make sure we have specified constraints
+ if (
+ "mip_constraints" not in args
+ and "human_constraints" not in args
+ and "mip_constraints" not in puzzle_profile
+ and "human_constraints" not in puzzle_profile
+ ):
+ mprint(
+ "error: Must specify either `mip_constraints` or `human_constraints` in one of puzzle_profile or as a commandline argument."
+ )
+ sys.exit(1)
+
+
+def _get_minimal_unique_names(dicts: list[dict]) -> list[str]:
+ all_keys = set(k for d in dicts for k in d.keys())
+ all_values = {k: set(d[k] for d in dicts if k in d) for k in all_keys}
+ non_common_keys = [k for k, values in all_values.items() if len(values) > 1]
+
+ return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts]
+
+
+def run_puzzle(args: DictConfig) -> list[str]:
+ # Loads config from args/puzzle_profile
+ if args.puzzle_profile is not None:
+ with open(args.puzzle_profile) as f:
+ puzzle_profile = yaml.safe_load(f)
+ _override_args_from_profile(args, puzzle_profile)
+ mprint(f"Loaded Puzzle profile from {args.puzzle_profile}")
+ else:
+ puzzle_profile = {}
+ _assert_valid_config(args, puzzle_profile)
+
+ # Read Metrics and Stats
+ if args.gathered_metrics_path is not None:
+ gathered_metrics = json.loads(args.gathered_metrics_path.read_text())
+ else:
+ gathered_metrics = gather_multi_layer_puzle_metrics(
+ args.single_block_replacement_validation_dir
+ )
+
+ if args.metric_overrides is not None:
+ gathered_metrics = {**gathered_metrics, **args.metric_overrides}
+
+ subblock_stats = json.loads(args.subblock_stats_path.read_text())
+
+ all_subblock_args = _load_all_subblock_stats_args(args, puzzle_profile)
+ all_subblock_output_folders = [
+ args.output_path / unique_name
+ for unique_name in _get_minimal_unique_names(all_subblock_args)
+ ]
+ all_constraints = _load_all_constraints(args, puzzle_profile)
+
+ # Running all puzzles
+ solution_paths = []
+ for subblock_stats_args, subblock_stats_output_folder in zip(
+ all_subblock_args, all_subblock_output_folders
+ ):
+ for constraints in all_constraints:
+ output_folder = subblock_stats_output_folder / constraints.name
+ _solution_path = run_single_puzzle_config(
+ args,
+ gathered_metrics,
+ subblock_stats,
+ subblock_stats_args,
+ constraints,
+ output_folder,
+ )
+ solution_paths.append(_solution_path)
+ return solution_paths
+
+
+def gather_puzzle_metrics(
+ single_block_replacement_validation_dir: Path,
+) -> PuzzleMetrics:
+ single_block_metrics = [
+ _parse_single_block_replacement_metrics(metrics_path)
+ for metrics_path in single_block_replacement_validation_dir.glob("*solution*.json")
+ ]
+ all_metric_names = tuple(single_block_metrics[0]["metrics"].keys())
+ teacher_metrics = _parse_teacher_block_metrics(
+ single_block_replacement_validation_dir, all_metric_names
+ )
+
+ n_layer = len(teacher_metrics)
+ gathered_metrics = {f"block_{block_idx}": dict() for block_idx in range(n_layer)}
+ for variant_metrics in single_block_metrics + teacher_metrics:
+ block_config = variant_metrics["block_config"]
+ block_name = f"block_{variant_metrics['block_idx']}"
+ # if we explicitly measure teacher's blocks don't override them
+ gathered_metrics[block_name][block_config] = variant_metrics
+ # if not gathered_metrics[block_name].get(block_config):
+ # gathered_metrics[block_name][block_config] = variant_metrics
+ return gathered_metrics
+
+
+def gather_multi_layer_puzle_metrics(
+ single_replacement_validation_dir: Path,
+) -> MultiLayerPuzzleMetrics:
+ single_sequence_metrics = [
+ _parse_single_sequence_replacement_metrics(metrics_path)
+ for metrics_path in single_replacement_validation_dir.glob("*solution*.json")
+ ]
+ all_metric_names = tuple(single_sequence_metrics[0]["metrics"].keys())
+ teacher_metrics = _parse_teacher_block_metrics(
+ single_replacement_validation_dir, all_metric_names
+ )
+
+ gathered_metrics = {
+ f"replacement_{replacement_id}": replacement_metrics
+ for replacement_id, replacement_metrics in enumerate(
+ single_sequence_metrics + teacher_metrics
+ )
+ }
+
+ return gathered_metrics
+
+
+def _parse_single_block_replacement_metrics(metrics_path: Path) -> dict:
+ raw_metrics = json.loads(metrics_path.read_text())
+ single_block_replacement = raw_metrics["puzzle_solution"]["single_block_replacement"]
+ variant_metrics = {
+ "block_config": BlockConfig(**single_block_replacement["block_config"]),
+ "block_idx": single_block_replacement["block_idx"],
+ "metrics": _extract_average_metrics(raw_metrics),
+ }
+ return variant_metrics
+
+
+def _parse_single_sequence_replacement_metrics(metrics_path: Path) -> dict:
+ raw_metrics = json.loads(metrics_path.read_text())
+ single_sequence_replacement = raw_metrics["puzzle_solution"]["single_sequence_replacement"]
+ if len(single_sequence_replacement["child_block_configs"]) > 1:
+ raise NotImplementedError(
+ "Currently we only support many-to-1 replacements, but we can support many-to-many! "
+ )
+ variant_metrics = {
+ "block_config": BlockConfig(**single_sequence_replacement["child_block_configs"][0]),
+ # is there cases where child_block_configs has more than one entry?
+ "parent_layer_indices": single_sequence_replacement["parent_layer_indices"],
+ "metrics": _extract_average_metrics(raw_metrics),
+ "layer_replacement": parse_layer_replacement(single_sequence_replacement),
+ "is_teacher": False,
+ }
+ return variant_metrics
+
+
+def _parse_teacher_block_metrics(
+ single_block_replacement_validation_dir: Path,
+ all_metric_names: Iterable[str] = ("kl_div_loss",),
+) -> list[dict]:
+ raw_metrics = json.loads((single_block_replacement_validation_dir / "teacher.json").read_text())
+ teacher_checkpoint_dir = Path(raw_metrics["args"]["teacher_dir"]).resolve()
+ descriptor_name = raw_metrics["args"]["descriptor"]
+ descriptor = ModelDescriptorFactory.get(descriptor_name)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ teacher_model_config = load_model_config(
+ teacher_checkpoint_dir, trust_remote_code=trust_remote_code
+ )
+
+ teacher_replacements = None
+ replacement_library_path = raw_metrics["args"].get("replacement_library_path")
+ if replacement_library_path is not None:
+ teacher_replacements = dict()
+ all_layer_replacements = json.loads(Path(replacement_library_path).read_text())
+ for layer_replacement in all_layer_replacements:
+ layer_replacement = parse_layer_replacement(layer_replacement)
+ if replacement_is_teacher(
+ layer_replacement, teacher_model_config, teacher_checkpoint_dir
+ ):
+ block_idx = layer_replacement["parent_layer_indices"][0]
+ teacher_replacements[block_idx] = layer_replacement
+
+ teacher_metrics = [
+ {
+ "block_config": block_config,
+ "block_idx": block_idx,
+ "parent_layer_indices": [block_idx],
+ "metrics": {
+ **dict.fromkeys(all_metric_names, 0.0), # default value 0. for teacher
+ **_extract_average_metrics(raw_metrics), # override with real value if exists
+ },
+ **(
+ {"layer_replacement": teacher_replacements[block_idx]}
+ if teacher_replacements is not None
+ else {}
+ ),
+ "is_teacher": True,
+ }
+ for block_idx, block_config in enumerate(teacher_model_config.block_configs)
+ ]
+ return teacher_metrics
+
+
+def _extract_average_metrics(raw_metrics: dict[str, dict]) -> dict[str, float]:
+ average_metrics = dict()
+ for metric_name in raw_metrics:
+ metric_dict = raw_metrics[metric_name]
+ if isinstance(metric_dict, dict) and ("avg" in metric_dict.keys()):
+ metric_value = raw_metrics[metric_name]["avg"]
+ average_metrics[metric_name] = metric_value
+ average_metrics[f"one_minus_{metric_name}"] = 1 - metric_value
+ return average_metrics
+
+
+def filter_subblock_stats_by_args(
+ all_subblock_stats: list[dict],
+ subblock_stats_args: dict[str, Any],
+ convert_dicts_to_dataclasses: bool = True,
+) -> dict[str, dict]:
+ matching_subblock_stats = [
+ subblock_stats
+ for subblock_stats in all_subblock_stats
+ if _dict_is_subset(subblock_stats_args, subblock_stats["args"])
+ ]
+ assert len(matching_subblock_stats) == 1, (
+ "The provided subblock_stats_args should match exactly one measurement "
+ f"scenario, instead matched {len(matching_subblock_stats)}:\n"
+ f"{[m['args'] for m in matching_subblock_stats]}"
+ )
+ subblock_stats = deepcopy(matching_subblock_stats[0])
+
+ if convert_dicts_to_dataclasses:
+ class_name_to_class = {klass.__name__: klass for klass in [AttentionConfig, FFNConfig]}
+ subblocks_dict = dict()
+ for substats in subblock_stats["subblocks"]:
+ subblock_config_class = class_name_to_class[substats.pop("subblock_config_class")]
+ subblock_config = subblock_config_class(**substats.pop("subblock_config"))
+ dict_key = (subblock_config, None)
+ if "parent_layer_index" in substats:
+ dict_key = (subblock_config, substats["parent_layer_index"])
+ subblocks_dict[dict_key] = substats
+ subblock_stats["subblocks"] = subblocks_dict
+ return subblock_stats
+
+
+def _dict_is_subset(dict1: dict, dict2: dict) -> bool:
+ return all(item in dict2.items() for item in dict1.items())
+
+
+def _add_block_stats_to_gathered_metrics(
+ gathered_metrics: PuzzleMetrics, subblock_stats: dict
+) -> None:
+ for block_name, block_variants in gathered_metrics.items():
+ parent_layer_index = None
+ if "parent_layer_indices" in block_variants:
+ parent_layer_index = block_variants["parent_layer_indices"][0]
+
+ if "metrics" in block_variants:
+ # this is a sequence stats object for multi-layer puzzle
+ block_variants["stats"] = _get_block_stats(
+ subblock_stats, block_variants["block_config"], parent_layer_index
+ )
+ else:
+ for block_config, variant_metrics in block_variants.items():
+ variant_metrics["stats"] = _get_block_stats(subblock_stats, block_config)
+
+
+def _get_block_stats(
+ subblock_stats: dict,
+ block_config: BlockConfig | Literal["non_block"],
+ parent_layer_index: int = None,
+) -> dict[str, float]:
+ if block_config == "non_block":
+ return subblock_stats["non_block"]
+
+ if block_config.parallel_blocks is None:
+ attention_key = (block_config.attention, parent_layer_index)
+ ffn_key = (block_config.ffn, parent_layer_index)
+ attention_stats = subblock_stats["subblocks"][attention_key]
+ ffn_stats = subblock_stats["subblocks"][ffn_key]
+ assert set(attention_stats.keys()) == set(ffn_stats.keys())
+
+ block_stats = dict()
+ for k in attention_stats.keys():
+ block_stats[k] = _none_add(attention_stats[k], ffn_stats[k])
+ block_stats[f"attention_{k}"] = attention_stats[k]
+ block_stats[f"ffn_{k}"] = ffn_stats[k]
+
+ block_stats["has_attention"] = int(
+ not block_config.attention.no_op and block_config.attention.mamba is None
+ )
+ block_stats["has_ffn"] = int(not block_config.ffn.no_op)
+ block_stats["has_moe"] = int(block_config.ffn.moe is not None)
+ block_stats["not_no_op"] = int(
+ not (block_config.attention.no_op and block_config.ffn.no_op)
+ )
+ block_stats["num_kv_heads"] = (
+ block_config.attention.num_key_value_heads if block_stats["has_attention"] else 0
+ )
+ block_stats["num_local_experts"] = (
+ block_config.ffn.moe.num_local_experts if block_stats["has_moe"] else 0
+ )
+
+ return block_stats
+
+ # this is a parallel block
+ ADDITIVE_METRICS = ("memory_mib", "num_params", "kv_cache_memory_mib")
+ ADDITIVE_METRICS = [
+ f"{prefix}{metric}" for prefix in ("", "attention_", "ffn_") for metric in ADDITIVE_METRICS
+ ]
+ block_stats = [
+ _get_block_stats(subblock_stats, sub_parallel)
+ for sub_parallel in block_config.parallel_blocks
+ ]
+ block_stats = {
+ k: _none_add_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats])
+ if k in ADDITIVE_METRICS
+ else _none_max_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats])
+ for k in block_stats[0].keys()
+ }
+
+ return block_stats
+
+
+def _none_add(a: float | int | None, b: float | int | None) -> float | int | None:
+ if a is None or b is None:
+ return None
+ return a + b
+
+
+def _none_max(a: float | int | None, b: float | int | None) -> float | int | None:
+ if a is None or b is None:
+ return None
+ return max(a, b)
+
+
+def _none_add_list(l) -> float | int | None:
+ r = l[0]
+ if len(l) == 1:
+ return r
+ for e in l[1:]:
+ r = _none_add(r, e)
+ return r
+
+
+def _none_max_list(l) -> float | int | None:
+ r = l[0]
+ if len(l) == 1:
+ return r
+ for e in l[1:]:
+ r = _none_max(r, e)
+ return r
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ run_puzzle(args)
diff --git a/modelopt/torch/puzzletron/mip/sweep.py b/modelopt/torch/puzzletron/mip/sweep.py
new file mode 100644
index 0000000000..82d9b11e12
--- /dev/null
+++ b/modelopt/torch/puzzletron/mip/sweep.py
@@ -0,0 +1,286 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MIP sweep functionality for exploring multiple memory compression rates."""
+
+import json
+from pathlib import Path
+from typing import Any
+
+from omegaconf import DictConfig, OmegaConf
+from transformers import PretrainedConfig
+
+import modelopt.torch.puzzletron.anymodel.models # noqa: F401 — register ModelDescriptorFactory entries
+import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.mip.run_puzzle import _get_block_stats, filter_subblock_stats_by_args
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config
+from modelopt.torch.puzzletron.tools.logger import mprint
+
+
+def _load_teacher_subblock_stats(hydra_cfg: DictConfig) -> tuple[dict[str, Any], PretrainedConfig]:
+ """Load filtered subblock_stats and teacher ``model_config`` for the current MIP scenario."""
+ puzzle_dir = Path(hydra_cfg.puzzle_dir)
+ teacher_dir = Path(hydra_cfg.teacher_dir)
+
+ descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code)
+ lm_config = descriptor.get_language_model_config(model_config)
+ hidden_size = lm_config.hidden_size
+
+ mip_subblock_args = hydra_cfg.mip.subblock_stats_args[0]
+ subblock_stats_args = OmegaConf.to_container(mip_subblock_args, resolve=True)
+ # Subblock_stats.json can list multiple runs that share batch/dtypes but differ by hidden size;
+ # filter_subblock_stats_by_args needs n_embd so exactly one row matches the teacher.
+ subblock_stats_args = {**subblock_stats_args, "n_embd": hidden_size}
+
+ batch_size = subblock_stats_args["batch_size"]
+ weights_dtype = str(subblock_stats_args["weights_dtype"])
+ activations_dtype = str(subblock_stats_args["activations_dtype"])
+ kv_cache_dtype = str(subblock_stats_args["kv_cache_dtype"])
+
+ subblock_stats_path = puzzle_dir / "subblock_stats.json"
+ if not subblock_stats_path.exists():
+ raise FileNotFoundError(
+ f"subblock_stats.json not found at {subblock_stats_path}. "
+ "Please run the full pipeline first without --mip-only flag."
+ )
+
+ with open(subblock_stats_path) as f:
+ subblock_stats_list = json.load(f)
+
+ try:
+ subblock_stats = filter_subblock_stats_by_args(subblock_stats_list, subblock_stats_args)
+ except AssertionError as e:
+ raise ValueError(
+ f"No unique subblock_stats entry for batch_size={batch_size}, "
+ f"dtypes=({weights_dtype}, {activations_dtype}, {kv_cache_dtype}), "
+ f"n_embd={hidden_size}"
+ ) from e
+
+ return subblock_stats, model_config
+
+
+def get_teacher_memory_from_subblock_stats(hydra_cfg: DictConfig) -> float:
+ """Calculate teacher model memory from subblock_stats.json.
+
+ Sums ``non_block`` and per-layer ``_get_block_stats(subblock_stats, block_config, layer_index)``
+ over ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`.
+
+ Args:
+ hydra_cfg: Hydra configuration object
+
+ Returns:
+ Total teacher memory in MiB
+ """
+ subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg)
+
+ total_memory = subblock_stats.get("non_block", {}).get("memory_mib", 0.0)
+
+ for layer_idx, block_config in enumerate(model_config.block_configs):
+ block_stats = _get_block_stats(subblock_stats, block_config, layer_idx)
+ total_memory += block_stats["memory_mib"]
+
+ return total_memory
+
+
+def get_teacher_num_params_from_subblock_stats(hydra_cfg: DictConfig) -> int:
+ """Calculate total teacher parameter count from subblock_stats.json.
+
+ Sums ``non_block`` and per-layer ``_get_block_stats(...)["num_params"]`` over
+ ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`.
+
+ Args:
+ hydra_cfg: Hydra configuration object
+
+ Returns:
+ Total teacher parameter count (same units as subblock_stats JSON).
+ """
+ subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg)
+
+ total_params = subblock_stats.get("non_block", {}).get("num_params", 0)
+
+ for layer_idx, block_config in enumerate(model_config.block_configs):
+ block_stats = _get_block_stats(subblock_stats, block_config, layer_idx)
+ total_params += block_stats["num_params"]
+
+ return int(total_params)
+
+
+def extract_solution_results(
+ solution_path: Path,
+ target_memory_mib: float,
+ teacher_memory_mib: float,
+ compression_rate: float,
+) -> dict:
+ """Extract results from a completed MIP solution.
+
+ Args:
+ solution_path: Path to the solutions.json file (not the directory!)
+ target_memory_mib: Target memory constraint used for MIP
+ teacher_memory_mib: Teacher model memory in MiB
+ compression_rate: Compression rate applied
+
+ Returns:
+ Dictionary containing extracted metrics
+ """
+ result = {
+ "compression_rate": compression_rate,
+ "target_memory_mib": target_memory_mib,
+ "teacher_memory_mib": teacher_memory_mib,
+ }
+
+ # solution_path is the path to solutions.json file, get parent directory
+ solution_dir = solution_path.parent
+
+ # Load solutions.json for actual memory and parameters
+ solutions_file = solution_dir / "solutions.json"
+ with open(solutions_file) as f:
+ solutions_data = json.load(f)
+ solution = solutions_data[0] # First solution
+ total_costs = solution.get("total_costs", {})
+ result["actual_memory_mib"] = total_costs.get("stats.memory_mib", None)
+ result["num_params"] = total_costs.get("stats.num_params", None)
+
+ # Load solution_0.json for accuracy metrics
+ validation_dir = solution_dir / "solutions--validation"
+ # TODO: There could be multiple solutions, but we only need the first one. Is it the best solution?
+ solution_0_file = validation_dir / "solution_0.json"
+
+ with open(solution_0_file) as f:
+ validation_data = json.load(f)
+ result["lm_loss"] = validation_data.get("lm_loss", {}).get("avg", None)
+ result["token_accuracy_top_1"] = validation_data.get("token_accuracy_top_1", {}).get(
+ "avg", None
+ )
+ result["token_accuracy_top_5"] = validation_data.get("token_accuracy_top_5", {}).get(
+ "avg", None
+ )
+ result["token_accuracy_top_10"] = validation_data.get("token_accuracy_top_10", {}).get(
+ "avg", None
+ )
+
+ return result
+
+
+def write_results_to_csv(results: list, output_csv: str):
+ """Write sweep results to CSV file.
+
+ Args:
+ results: List of result dictionaries
+ output_csv: Path to output CSV file
+ """
+ import csv
+
+ # Define CSV columns in desired order
+ columns = [
+ "compression_rate",
+ "target_memory_mib",
+ "actual_memory_mib",
+ "teacher_memory_mib",
+ "num_params",
+ "lm_loss",
+ "token_accuracy_top_1",
+ "token_accuracy_top_5",
+ "token_accuracy_top_10",
+ ]
+
+ # Write CSV
+ output_path = Path(output_csv)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=columns)
+ writer.writeheader()
+ writer.writerows(results)
+
+ mprint(f"Results written to: {output_path}")
+
+
+def run_mip_sweep(hydra_cfg):
+ """Run MIP for multiple memory compression rates and generate CSV with results.
+
+ This function is called when mip.sweep.enabled is True in the config.
+
+ Args:
+ hydra_cfg: Hydra configuration object with mip.sweep settings
+ """
+ mprint("=" * 80)
+ mprint("MIP Sweep Mode Enabled")
+ mprint("=" * 80)
+
+ # Get sweep configuration
+ sweep_cfg = hydra_cfg.mip.sweep
+ compression_rates = sweep_cfg.memory_compression_rates
+ output_csv = sweep_cfg.output_csv
+ puzzle_dir = Path(hydra_cfg.puzzle_dir)
+
+ mprint(f"Compression rates: {compression_rates}")
+ mprint(f"Output CSV: {output_csv}")
+ mprint(f"Puzzle directory: {puzzle_dir}")
+
+ # Calculate teacher memory from subblock_stats
+ teacher_memory = get_teacher_memory_from_subblock_stats(hydra_cfg)
+ mprint(
+ f"Teacher memory (from subblock_stats): {teacher_memory:.1f} MiB ({teacher_memory / 1024:.1f} GiB)"
+ )
+
+ # Collect results
+ all_results = []
+
+ # Run MIP for each compression rate
+ for compression_rate in compression_rates:
+ target_memory_mib = teacher_memory * compression_rate
+ mprint("\n" + "=" * 80)
+ mprint(
+ f"Running MIP for compression_rate={compression_rate:.2f} "
+ f"(target={target_memory_mib:.1f} MiB = {target_memory_mib / 1024:.1f} GiB)"
+ )
+ mprint("=" * 80)
+
+ # Modify config dynamically
+ hydra_cfg.mip.human_constraints.target_memory = target_memory_mib
+
+ # Run MIP and realize models (reuse existing distributed logic!)
+ solution_paths = mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+
+ # Extract results (only on master rank)
+ if dist.is_master():
+ for solution_path in solution_paths:
+ result = extract_solution_results(
+ solution_path=Path(solution_path),
+ target_memory_mib=target_memory_mib,
+ teacher_memory_mib=teacher_memory,
+ compression_rate=compression_rate,
+ )
+ all_results.append(result)
+
+ mprint(
+ f"✓ Results: actual_memory={result['actual_memory_mib']:.1f} MiB, "
+ f"lm_loss={result['lm_loss']:.4f}"
+ )
+
+ # Write results to CSV (only on master rank)
+ if dist.is_master():
+ mprint("\n" + "=" * 80)
+ mprint("MIP Sweep Complete - Writing Results")
+ mprint("=" * 80)
+ write_results_to_csv(all_results, output_csv)
+ mprint(f"Completed {len(all_results)} sweep runs")
+ mprint("=" * 80)
diff --git a/modelopt/torch/puzzletron/mip/utils.py b/modelopt/torch/puzzletron/mip/utils.py
new file mode 100644
index 0000000000..b276ff33b1
--- /dev/null
+++ b/modelopt/torch/puzzletron/mip/utils.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for MIP optimization."""
+
+from typing import Any
+
+
+class InfeasibleError(Exception):
+ """Exception raised when optimization problem is infeasible."""
+
+
+def sort_replacements(layer_replacements: list[dict]) -> list[dict]:
+ """Sort layer replacements by parent layer indices.
+
+ Args:
+ layer_replacements: List of replacement dictionaries containing "parent_layer_indices"
+
+ Returns:
+ Sorted list of replacements
+ """
+ return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"])
+
+
+def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any:
+ """Access nested dictionary values using dot notation.
+
+ If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"]
+
+ Args:
+ dictionary: Dictionary to access
+ nested_key: Dot-separated key path (e.g., "a.b.c")
+
+ Returns:
+ Value at the nested key location
+ """
+ value = dictionary
+ for key in nested_key.split("."):
+ value = value[key]
+ return value
+
+
+def consecutive_ngrams(l: int, n: int) -> list[list[int]]:
+ """Generate all consecutive n-grams from range(l).
+
+ Splits range(l) into all consecutive n-grams.
+
+ Args:
+ l: Length of the range
+ n: Size of each n-gram
+
+ Returns:
+ List of consecutive n-grams
+
+ Example:
+ consecutive_ngrams(7, 2) = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]
+ """
+ ngrams = []
+ for i in range(l - n + 1):
+ ngrams.append(list(range(i, i + n)))
+ return ngrams
diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
new file mode 100644
index 0000000000..e5025dea7d
--- /dev/null
+++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
@@ -0,0 +1,235 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Puzzletron NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146).
+
+It is used by mtn.convert() to convert a model from HF format to Puzzletron heterogeneous format + do pruning scoring
+and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search.
+"""
+
+import datetime
+from pathlib import Path
+
+import hydra
+import torch
+from torch import nn
+
+import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
+import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts
+import modelopt.torch.puzzletron.scoring.scoring as scoring
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.nas.conversion import NASModeRegistry
+from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+from modelopt.torch.opt.mode import (
+ ConvertEntrypoint,
+ ConvertReturnType,
+ MetadataDict,
+ ModeDescriptor,
+ RestoreEntrypoint,
+)
+from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict
+from modelopt.torch.puzzletron import build_library_and_stats
+from modelopt.torch.puzzletron.activation_scoring import score_pruning_activations
+from modelopt.torch.puzzletron.anymodel.converter import ConverterFactory
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
+from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir
+from modelopt.torch.puzzletron.tools.logger import mprint
+
+
+class PuzzletronModel(nn.Module):
+ pass # No model implementation is needed for the puzzletron mode
+
+
+class PuzzletronConfig(ModeloptBaseConfig):
+ """Configuration for Puzzletron NAS algorithm."""
+
+ # Input model path to compress in the HF format
+ input_model_path: str = ModeloptField(
+ default="",
+ title="",
+ description="",
+ )
+
+ # Hydra config directory containing the search space definition
+ hydra_config_dir: str = ModeloptField(
+ default="",
+ title="",
+ description="",
+ )
+
+ # Hydra config name containing the search space definition
+ hydra_config_name: str = ModeloptField(
+ default="",
+ title="",
+ description="",
+ )
+
+ # Directory to save the compressed model and intermediate results
+ puzzle_dir: str = ModeloptField(
+ default="",
+ title="",
+ description="",
+ )
+
+ # Dataset path to use for scoring in prunining and NAS search
+ dataset_path: str = ModeloptField(
+ default="",
+ title="",
+ description="",
+ )
+
+
+def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType:
+ """1. Convert the model from HF format to AnyModel format.
+ 2. Score the pruning activations.
+ 3. Prune the model and save pruned checkpoints
+
+ The output of this step will be used by mnt.search() to perform the NAS search.
+ """
+ # Required for mtn.search() to read NAS configuration
+ model.hydra_config_dir = config.hydra_config_dir
+ model.hydra_config_name = config.hydra_config_name
+ model.puzzle_dir = config.puzzle_dir
+ model.dataset_path = config.dataset_path
+
+ # Load hydra config
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=config.hydra_config_dir,
+ config_name=config.hydra_config_name,
+ overrides=[
+ f"puzzle_dir={config.puzzle_dir}",
+ f"dataset_path={config.dataset_path}",
+ ],
+ )
+ # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
+ hydra_cfg = hydra.utils.instantiate(hydra_cfg)
+
+ # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config)
+ if dist.is_master():
+ mprint(
+ "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)"
+ )
+ hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable
+
+ # Get descriptor and converter from the hydra config
+ descriptor_name = hydra_cfg.descriptor
+ descriptor = ModelDescriptorFactory.get(descriptor_name)
+ converter = ConverterFactory.get(descriptor_name)
+
+ converter.convert(
+ descriptor=descriptor,
+ input_dir=Path(config.input_model_path),
+ output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir,
+ )
+ dist.barrier()
+
+ # Score_pruning_activations (distributed processing)
+ mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)")
+ score_pruning_activations.launch_score_activations(hydra_cfg)
+
+ # Prune the model and save pruned checkpoints
+ if dist.is_master():
+ mprint(
+ "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)"
+ )
+ pruning_ckpts.launch_prune_ckpt(hydra_cfg)
+ dist.barrier()
+
+ return model, {}
+
+
+def restore_puzzletron_model(
+ model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict
+) -> nn.Module:
+ """Restore is not needed for the puzzletron mode as we are not saving any model state"""
+ return model
+
+
+@NASModeRegistry.register_mode
+class PuzzletronDescriptor(ModeDescriptor):
+ """Descriptor for the Puzzletron mode."""
+
+ @property
+ def name(self) -> str:
+ """String identifier for this mode."""
+ return "puzzletron"
+
+ @property
+ def config_class(self) -> type[ModeloptBaseConfig]:
+ """Configuration class for this mode."""
+ return PuzzletronConfig
+
+ @property
+ def search_algorithm(self) -> type[BaseSearcher]:
+ """Return the associated searcher implementation."""
+
+ return PuzzletronSearcher
+
+ @property
+ def convert(self) -> ConvertEntrypoint:
+ """Entrypoint to convert a model."""
+ return convert_puzzletron_model
+
+ @property
+ def restore(self) -> RestoreEntrypoint:
+ """Entrypoint to restore a model."""
+ return restore_puzzletron_model
+
+ @property
+ def export_mode(self) -> str | None:
+ """The mode that corresponds to the export mode.
+ For now, this will be a no-op as there is no modelopt's concept of search space defined
+ for the puzzletron algorithm.
+ """
+ return "export_nas"
+
+
+class PuzzletronSearcher(BaseSearcher):
+ """Runs NAS search for the Puzzletron mode."""
+
+ @property
+ def default_state_dict(self) -> SearchStateDict:
+ """Not needed for the puzzletron mode as we are not saving any model state"""
+ return {}
+
+ def run_search(self) -> None:
+ # Load hydra config
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=self.model.hydra_config_dir,
+ config_name=self.model.hydra_config_name,
+ overrides=[
+ f"puzzle_dir={self.model.puzzle_dir}",
+ f"dataset_path={self.model.dataset_path}",
+ ],
+ )
+ # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
+ hydra_cfg = hydra.utils.instantiate(hydra_cfg)
+
+ # Build_library_and_stats (single process)
+ if dist.is_master():
+ mprint(
+ "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)"
+ )
+ build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
+ dist.barrier()
+
+ # Calc_one_block_scores (distributed processing)
+ mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)")
+ scoring.launch_scoring(hydra_cfg)
+
+ # mip_and_realize_models (distributed processing)
+ mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+ mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
diff --git a/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py
new file mode 100644
index 0000000000..42c4ad8f51
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/expert_removal_pruning_mixin.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+from transformers import PretrainedConfig
+
+from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook
+from modelopt.torch.prune.importance_hooks.expert_removal_hooks import (
+ NemotronHRemoveExpertsIndependentHook,
+ Qwen3VLRemoveExpertsIndependentHook,
+ RankedChoiceVotingHook,
+ RankedChoiceVotingHookNemotronH,
+)
+from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn
+from modelopt.torch.puzzletron.pruning.pruning_utils import MlpInitMode, _init_moe_module
+
+
+@dataclass
+class ExpertRemovalLayerDescriptor(LayerDescriptor):
+ """
+ TODO - Add Shared expert weights in case it's prunable.
+ TODO - consider removing the segmentation between weight and bias, doesn't seem to affect the pruning algo.
+ Attributes:
+ target_name: module name required to register hooks for scoring_activations, can be a regex if start with the prefix `regex:`
+ moe_prefix_name: moe prefix layer name, should include a placeholder for `layer_idx` to be repeated for all layers. i.e: `model.layers.{layer_idx}.moe`
+ expert_prefix_name: expert prefix layer name relative to moe_prefix, should include a placeholder for `expert_idx` to be repeated for all experts. i.e: `experts.{expert_idx}`
+ router_weights: List of the router weight names relative to moe_prefix.
+ router_biases: List of the router bias names relative to moe_prefix.
+ expert_weights: List of the expert weight names relative to expert_prefix (for per-expert format).
+ expert_biases: List of the expert bias names relative to expert_prefix (for per-expert format).
+ is_fused_experts: If True, experts are stored as single fused tensors with shape [num_experts, ...].
+ If False (default), experts are stored as separate tensors per expert.
+ fused_expert_weights: List of fused expert weight names relative to moe_prefix (for fused format).
+ e.g., ["experts.gate_up_proj", "experts.down_proj"]
+ """
+
+ target_name: str
+ moe_prefix_name: str
+ expert_prefix_name: str = ""
+ router_weights: List[str] = field(default_factory=list)
+ router_biases: List[str] = field(default_factory=list)
+ expert_weights: List[str] = field(default_factory=list)
+ expert_biases: List[str] = field(default_factory=list)
+ is_fused_experts: bool = False
+ fused_expert_weights: List[str] = field(default_factory=list)
+
+ def module_name_regex(self) -> str:
+ return self.target_name
+
+ def moe_prefix(self, layer_idx: int) -> str:
+ return self.moe_prefix_name.format(layer_idx=layer_idx)
+
+ def expert_prefix(self, layer_idx: int, expert_idx: int) -> str:
+ _expert_prefix = self.moe_prefix_name + "." + self.expert_prefix_name
+ return _expert_prefix.format(layer_idx=layer_idx, expert_idx=expert_idx)
+
+
+class ExpertRemovalPruningMixIn(PruningMixIn):
+ def __init__(self, layer_descriptor: ExpertRemovalLayerDescriptor):
+ assert isinstance(layer_descriptor, ExpertRemovalLayerDescriptor)
+ super().__init__(layer_descriptor)
+
+ def supported_hooks(self) -> List[Type[ForwardHook]]:
+ return [
+ RankedChoiceVotingHook,
+ RankedChoiceVotingHookNemotronH,
+ NemotronHRemoveExpertsIndependentHook,
+ Qwen3VLRemoveExpertsIndependentHook,
+ ]
+
+ def prune_single_layer(
+ self,
+ layer_idx: int,
+ parent_state_dict: dict,
+ new_state_dict: dict,
+ original_config: PretrainedConfig,
+ new_config: PretrainedConfig,
+ mlp_init_mode: MlpInitMode,
+ mlp_init_config: Optional[dict[str, Any]],
+ keys: dict,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ layer_out_state_dict = {}
+
+ child_block_config = new_config.block_configs[layer_idx]
+ parent_block_config = original_config.block_configs[layer_idx]
+
+ if not parent_block_config.ffn.is_moe:
+ return layer_out_state_dict
+
+ new_num_experts = child_block_config.ffn.moe.num_local_experts
+ orig_num_experts = parent_block_config.ffn.moe.num_local_experts
+
+ child_router_keys, new_experts_keys = self._generate_moe_keys(layer_idx, new_num_experts)
+ parent_router_keys, orig_experts_keys = self._generate_moe_keys(layer_idx, orig_num_experts)
+
+ # Pop parent's router keys from copy list; child-only router keys will be initialized below
+ for rk in sum(parent_router_keys.values(), []):
+ if rk in keys:
+ keys.pop(rk)
+ for key in sum(orig_experts_keys.values(), []):
+ if key in keys:
+ keys.pop(key)
+
+ if self.layer_descriptor.is_fused_experts:
+ # Fused format: unbundle single tensor [num_experts, ...] into list of per-expert tensors
+ orig_experts_weights = {}
+ for name, fused_keys in orig_experts_keys.items():
+ fused_tensor = parent_state_dict[fused_keys[0]] # Single fused tensor
+ orig_experts_weights[name] = [fused_tensor[i] for i in range(orig_num_experts)]
+
+ new_experts_weights = {}
+ for name, fused_keys in new_experts_keys.items():
+ fused_tensor = new_state_dict[fused_keys[0]] # Single fused tensor
+ new_experts_weights[name] = [fused_tensor[i] for i in range(new_num_experts)]
+ else:
+ # Per-expert format: load each expert tensor separately
+ orig_experts_weights = {
+ name: [parent_state_dict[key] for key in orig_experts_module_keys]
+ for name, orig_experts_module_keys in orig_experts_keys.items()
+ }
+ new_experts_weights = {
+ name: [new_state_dict[key] for key in new_experts_module_keys]
+ for name, new_experts_module_keys in new_experts_keys.items()
+ }
+
+ orig_router_weights = {
+ name: [parent_state_dict[key] for key in _module_router_keys]
+ for name, _module_router_keys in parent_router_keys.items()
+ }
+ new_router_weights = {
+ name: [new_state_dict[key] for key in _module_router_keys]
+ for name, _module_router_keys in child_router_keys.items()
+ }
+
+ out_router_weights, out_experts_weights = _init_moe_module(
+ layer_idx=layer_idx,
+ mlp_init_mode=mlp_init_mode,
+ mlp_init_config=mlp_init_config,
+ orig_router_weights=orig_router_weights,
+ orig_experts_weights=orig_experts_weights,
+ new_router_weights=new_router_weights,
+ new_experts_weights=new_experts_weights,
+ orig_num_experts=orig_num_experts,
+ new_num_experts=new_num_experts,
+ )
+ assert new_experts_keys.keys() == out_experts_weights.keys(), (
+ "new_experts_keys and out_experts_weights must have the same keys"
+ )
+ assert child_router_keys.keys() == out_router_weights.keys(), (
+ "child_router_keys and out_router_weights must have the same keys"
+ )
+
+ for name in child_router_keys.keys():
+ layer_out_state_dict.update(zip(child_router_keys[name], out_router_weights[name]))
+
+ if self.layer_descriptor.is_fused_experts:
+ # Fused format: rebundle list of per-expert tensors into single fused tensor
+ for name in new_experts_keys.keys():
+ fused_key = new_experts_keys[name][0] # Single key for fused tensor
+ fused_tensor = torch.stack(out_experts_weights[name], dim=0) # [num_experts, ...]
+ layer_out_state_dict[fused_key] = fused_tensor
+ else:
+ # Per-expert format: each expert has its own key
+ for name in new_experts_keys.keys():
+ layer_out_state_dict.update(zip(new_experts_keys[name], out_experts_weights[name]))
+
+ return layer_out_state_dict
+
+ def _generate_moe_keys(
+ self, layer_idx: int, num_experts: int
+ ) -> Tuple[Dict[str, List[str]], dict[str, list[str]]]:
+ """
+ Generate MoE weight keys for router and experts.
+ TODO simplify or better define the data structure of the moe keys returned.
+
+ :return: tuple of router_keys and expert_keys, all are absolute names relative to the model root:
+ * router_keys structure:
+ {"weight: [], bias: []"}
+ * expert_keys structure (per-expert format):
+ {": []}
+ i.e:
+ {
+ "down_proj.weight": ["model...experts.0.down_proj.weight", ..., "model...experts.N.down_proj.weight"],
+ ...
+ }
+ * expert_keys structure (fused format):
+ {": []}
+ i.e:
+ {
+ "experts.gate_up_proj": ["model...experts.gate_up_proj"],
+ "experts.down_proj": ["model...experts.down_proj"],
+ }
+ """
+ self.layer_descriptor: ExpertRemovalLayerDescriptor
+ moe_prefix = self.layer_descriptor.moe_prefix(layer_idx)
+
+ router_keys = {
+ "weight": [
+ f"{moe_prefix}.{_weight}" for _weight in self.layer_descriptor.router_weights
+ ],
+ "bias": [f"{moe_prefix}.{_bias}" for _bias in self.layer_descriptor.router_biases],
+ }
+
+ if self.layer_descriptor.is_fused_experts:
+ # Fused format: single tensor per weight type with shape [num_experts, ...]
+ experts_module_names = {}
+ for fused_weight in self.layer_descriptor.fused_expert_weights:
+ experts_module_names[fused_weight] = [f"{moe_prefix}.{fused_weight}"]
+ else:
+ # Per-expert format: separate tensor for each expert
+ expert_key_names = (
+ self.layer_descriptor.expert_weights + self.layer_descriptor.expert_biases
+ )
+ experts_module_names = {}
+ for key_name in expert_key_names:
+ experts_module_names[key_name] = [
+ f"{self.layer_descriptor.expert_prefix(layer_idx, expert_idx)}.{key_name}"
+ for expert_idx in range(num_experts)
+ ]
+
+ return router_keys, experts_module_names
diff --git a/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py
new file mode 100644
index 0000000000..9b7993de1e
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/ffn_intermediate_pruning_mixin.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+from transformers import PretrainedConfig
+
+from modelopt.torch.prune.importance_hooks.base_hooks import (
+ ForwardHook,
+ IndependentChannelContributionHook,
+ IterativeChannelContributionHook,
+)
+from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn
+from modelopt.torch.puzzletron.tools.bypassed_training.child_init import (
+ MlpInitMode,
+ _init_mlp_module,
+)
+
+
+@dataclass
+class FFNIntermediateLayerDescriptor(LayerDescriptor):
+ down_proj_name: str
+ ffn_prefix_name: str
+ linear_weight_names: List[str] = field(default_factory=list)
+
+ def module_name_regex(self) -> str:
+ return self.down_proj_name
+
+ def ffn_prefix(self, layer_idx: int) -> str:
+ return self.ffn_prefix_name.format(layer_idx=layer_idx)
+
+
+class FFNIntermediatePruningMixIn(PruningMixIn):
+ def __init__(self, layer_descriptor: FFNIntermediateLayerDescriptor):
+ assert isinstance(layer_descriptor, FFNIntermediateLayerDescriptor)
+ super().__init__(layer_descriptor)
+
+ def supported_hooks(self) -> List[Type[ForwardHook]]:
+ return [IndependentChannelContributionHook, IterativeChannelContributionHook]
+
+ def prune_single_layer(
+ self,
+ layer_idx: int,
+ parent_state_dict: dict,
+ new_state_dict: dict,
+ original_config: PretrainedConfig,
+ new_config: PretrainedConfig,
+ mlp_init_mode: MlpInitMode,
+ mlp_init_config: Optional[dict[str, Any]],
+ keys: dict,
+ keys_to_remove: dict,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ layer_out_state_dict = {}
+ # Hardcoded strings
+ mlp_prefix = self.layer_descriptor.ffn_prefix(layer_idx)
+ mlp_key_names = [
+ f"{mlp_prefix}.{name}.weight" for name in self.layer_descriptor.linear_weight_names
+ ]
+ mlp_keys = [keys.get(module_name) for module_name in mlp_key_names]
+ mlp_keys = [k for k in mlp_keys if k is not None]
+
+ for key in mlp_keys:
+ keys_to_remove[f"{mlp_prefix}.{key.split('.')[-2]}.weight"] = key
+
+ pruned_filters = None
+ projection_matrix = None
+
+ for mlp_key in mlp_keys:
+ expanded_dim = 1 if self.layer_descriptor.down_proj_name in mlp_key else 0
+ if mlp_key in new_state_dict.keys():
+ mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module(
+ mlp_init_mode,
+ mlp_prefix,
+ expanded_dim,
+ layer_idx,
+ new_state_dict[mlp_key],
+ new_config,
+ parent_state_dict[mlp_key],
+ original_config,
+ mlp_init_config,
+ pruned_filters,
+ projection_matrix,
+ )
+ layer_out_state_dict[mlp_key] = mlp_module_weight
+
+ return layer_out_state_dict
diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
new file mode 100644
index 0000000000..4a6fe53a34
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+from dataclasses import dataclass, field
+from typing import Any, List, Optional, Type
+
+from transformers import PretrainedConfig
+
+from modelopt.torch.prune.importance_hooks.base_hooks import (
+ ForwardHook,
+ IndependentKvHeadContributionHook,
+)
+from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn
+from modelopt.torch.puzzletron.pruning.pruning_utils import (
+ GQAInitMode,
+ _init_attention_biases,
+ _init_attention_weights,
+)
+
+
+@dataclass
+class KVHeadsLayerDescriptor(LayerDescriptor):
+ o_proj_name: str
+ attn_prefix_name: str
+ qkvo_weight_names: List[str] = field(default_factory=list)
+
+ def module_name_regex(self) -> str:
+ return self.o_proj_name
+
+ def attn_prefix(self, layer_idx: int) -> str:
+ return self.attn_prefix_name.format(layer_idx=layer_idx)
+
+
+class KVHeadsPruningMixIn(PruningMixIn):
+ def __init__(self, layer_descriptor: KVHeadsLayerDescriptor):
+ assert isinstance(layer_descriptor, KVHeadsLayerDescriptor)
+ super().__init__(layer_descriptor)
+
+ def supported_hooks(self) -> List[Type[ForwardHook]]:
+ return [IndependentKvHeadContributionHook]
+
+ def prune_single_layer(
+ self,
+ layer_idx: int,
+ parent_state_dict: dict,
+ new_state_dict: dict,
+ original_config: PretrainedConfig,
+ new_config: PretrainedConfig,
+ gqa_init_mode: GQAInitMode,
+ mlp_init_config: Optional[dict[str, Any]],
+ is_original_mha: bool,
+ keys: dict,
+ keys_to_remove: dict,
+ **kwargs,
+ ):
+ layer_out_state_dict = {}
+
+ attn_prefix = self.layer_descriptor.attn_prefix(layer_idx)
+ q_name, k_name, v_name, o_name = [
+ f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names
+ ]
+
+ head_size = new_config.head_dim
+ for part in ["weight", "bias"]:
+ attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]]
+ q_key, k_key, v_key, o_key = attn_keys
+
+ # Drop attn keys that don't exist and required to be in the new state_dict
+ attn_keys = [key for key in attn_keys if key in new_state_dict.keys()]
+ if len(attn_keys) > 0 and all(key in keys for key in attn_keys):
+ for key in attn_keys:
+ keys_to_remove[key] = keys[key]
+ is_student_and_teacher_have_same_attention_implementation = all(
+ key in new_state_dict.keys() for key in attn_keys
+ )
+ if is_student_and_teacher_have_same_attention_implementation:
+ if part == "weight":
+ wq, wk, wv, wo = _init_attention_weights(
+ gqa_init_mode=gqa_init_mode,
+ layer_idx=layer_idx,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ original_state_dict=parent_state_dict,
+ original_config=original_config,
+ q_key=q_key,
+ k_key=k_key,
+ v_key=v_key,
+ o_key=o_key,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ mlp_init_config=mlp_init_config,
+ )
+ layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk
+ layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo
+ else:
+ bias_sd = _init_attention_biases(
+ gqa_init_mode=gqa_init_mode,
+ layer_idx=layer_idx,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ original_state_dict=parent_state_dict,
+ original_config=original_config,
+ q_key=q_key,
+ k_key=k_key,
+ v_key=v_key,
+ o_key=o_key,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ mlp_init_config=mlp_init_config,
+ )
+ for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]):
+ if bias_key in bias_sd.keys():
+ layer_out_state_dict[sd_key] = bias_sd[bias_key]
+
+ return layer_out_state_dict
diff --git a/modelopt/torch/puzzletron/pruning/pruning_ckpts.py b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py
new file mode 100644
index 0000000000..b9cfd75faf
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/pruning_ckpts.py
@@ -0,0 +1,354 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for creating pruned model checkpoints.
+
+This module provides functions to generate pruned checkpoints by modifying model architectures
+(FFN intermediate sizes, attention head groups, hidden dimensions) and initializing child pruned models
+from parent checkpoints.
+"""
+
+# mypy: ignore-errors
+import json
+import os
+import time
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
+from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+ FFNIntermediatePruningMixIn,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn
+from modelopt.torch.puzzletron.pruning.pruning_utils import (
+ GQAInitMode,
+ HiddenSizeInitMode,
+ LinearInitMode,
+ MlpInitMode,
+ resolve_pruning_mixin,
+)
+from modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent import (
+ init_child_from_parent,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config
+from modelopt.torch.puzzletron.tools.logger import mprint
+
+
+def launch_ffn_intermediates_prune_ckpt(
+ cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
+):
+ for intermediate_size in cfg.pruning.intermediate_size_list:
+ dirname = f"ffn_{intermediate_size}_attn_no_op"
+
+ if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)):
+ mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved")
+ continue
+
+ mprint("Process intermediate_size {}".format(intermediate_size))
+
+ model_config_overrides_json = {"ffn": [{"intermediate_size": intermediate_size}]}
+ mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
+
+ output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)
+
+ # Profile the overall init_child_from_parent call with optimizations
+ mprint("Starting init_child_from_parent...")
+ start_time = time.time()
+ init_child_from_parent(
+ descriptor=cfg.descriptor,
+ pruning_mixin=cfg.pruning.pruning_mixin,
+ parent_checkpoint_dir=cfg.teacher_dir,
+ model_config_overrides_dict=model_config_overrides_json,
+ output_checkpoint_dir=output_dir,
+ gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
+ mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
+ mlp_init_config_yaml=mlp_init_config_yaml,
+ linear_init_mode=LinearInitMode.FromTeacher, # dummy default value
+ max_workers=max_save_workers, # Will auto-calculate if None
+ max_layer_workers=max_layer_workers, # Will auto-calculate if None
+ )
+ init_child_from_parent_time = time.time() - start_time
+ mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds")
+
+ # Create symlink in puzzle_dir/ckpts
+ ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts")
+ os.makedirs(ckpt_path, exist_ok=True)
+ os.symlink(output_dir, os.path.join(ckpt_path, dirname))
+
+ mprint(f"=== COMPLETED FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===")
+ mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n")
+
+
+def launch_attn_groups_prune_ckpt(
+ cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
+):
+ descriptor = cfg.descriptor
+ parent_model_config = load_model_config(
+ cfg.teacher_dir, trust_remote_code=descriptor.requires_trust_remote_code()
+ )
+ num_attention_heads = parent_model_config.num_attention_heads
+
+ for n_heads_in_group in cfg.pruning.n_heads_in_group_list:
+ dirname = f"n_heads_in_group{n_heads_in_group}"
+
+ if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)):
+ mprint(f"Process n_heads_in_group {n_heads_in_group} has already been pruned & saved")
+ continue
+
+ mprint("Process n_heads_in_group {}".format(n_heads_in_group))
+ mprint(f"=== STARTING ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===")
+
+ num_key_value_heads = num_attention_heads // n_heads_in_group
+ model_config_overrides_json = {"attention": [{"num_key_value_heads": num_key_value_heads}]}
+ mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
+
+ output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)
+
+ # Profile the overall init_child_from_parent call with optimizations
+ mprint("Starting init_child_from_parent...")
+ start_time = time.time()
+ init_child_from_parent(
+ descriptor=cfg.descriptor,
+ pruning_mixin=cfg.pruning.pruning_mixin,
+ parent_checkpoint_dir=cfg.teacher_dir,
+ model_config_overrides_dict=model_config_overrides_json,
+ output_checkpoint_dir=output_dir,
+ gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
+ mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
+ mlp_init_config_yaml=mlp_init_config_yaml,
+ linear_init_mode=LinearInitMode.FromTeacher, # dummy default value
+ max_workers=max_save_workers, # Will auto-calculate if None
+ max_layer_workers=max_layer_workers, # Will auto-calculate if None
+ )
+ init_child_from_parent_time = time.time() - start_time
+ mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds")
+
+ # Create symlink in puzzle_dir/ckpts
+ ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts")
+ os.makedirs(ckpt_path, exist_ok=True)
+ os.symlink(output_dir, os.path.join(ckpt_path, dirname))
+
+ mprint(f"=== COMPLETED ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===")
+ mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n")
+
+
+def launch_hidden_dim_prune_ckpt(cfg: DictConfig):
+ """Launch hidden dimension pruning using channel importance ranking."""
+ # Get channel importance results from the activations log directory
+ activations_log_dir = cfg.pruning.activations_log_dir
+ channel_importance_path = os.path.join(activations_log_dir, "channel_importance_results.json")
+
+ if not os.path.exists(channel_importance_path):
+ raise FileNotFoundError(
+ f"Channel importance results not found at {channel_importance_path}. "
+ f"Make sure to run the activation collection step first."
+ )
+
+ # Load parent model config to get FFN configuration
+ descriptor = ModelDescriptorFactory.get(cfg.descriptor)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ parent_model_config = load_model_config(
+ cfg.pruning.model_name_or_path, trust_remote_code=trust_remote_code
+ )
+ parent_hidden_size = parent_model_config.hidden_size
+
+ # Get teacher's FFN configuration
+ intermediate_sizes = []
+ for block_config in parent_model_config.block_configs:
+ if block_config.ffn.intermediate_size is not None:
+ intermediate_sizes.append(block_config.ffn.intermediate_size)
+ else:
+ intermediate_sizes.append(None)
+
+ mprint(f"Teacher config:")
+ mprint(f" - hidden_size: {parent_hidden_size}")
+ mprint(f" - intermediate_sizes: {intermediate_sizes}")
+ os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True)
+
+ for hidden_size in cfg.pruning.hidden_size_list:
+ mprint(f"\n######################################################################")
+ mprint(f"Hidden Size = {hidden_size}")
+ mprint(f"######################################################################\n")
+
+ mprint(f"Child config:")
+ mprint(f" - hidden_size: {hidden_size}")
+
+ # Create model config overrides with proper FFN configuration
+ model_config_overrides_json = json.dumps(
+ {
+ "hidden_size": hidden_size,
+ "ffn": [
+ {
+ "intermediate_size": intermediate_size,
+ }
+ for intermediate_size in intermediate_sizes
+ ],
+ }
+ )
+
+ mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
+ dirname = f"hidden_size_{hidden_size}"
+ output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)
+
+ mprint(f"Creating checkpoint with hidden_size={hidden_size}")
+ mprint(f"Model config overrides: {model_config_overrides_json}")
+
+ init_child_from_parent(
+ descriptor=cfg.descriptor,
+ pruning_mixin=cfg.pruning.pruning_mixin,
+ parent_checkpoint_dir=cfg.pruning.model_name_or_path,
+ model_config_overrides_dict=model_config_overrides_json,
+ output_checkpoint_dir=output_dir,
+ gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
+ mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
+ mlp_init_config_yaml=mlp_init_config_yaml,
+ linear_init_mode=LinearInitMode(cfg.pruning.linear_init_mode),
+ hidden_size_init_mode=HiddenSizeInitMode(cfg.pruning.hidden_size_init_mode),
+ channel_importance_path=channel_importance_path,
+ )
+
+ # Create symlink in puzzle_dir/ckpts
+ ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts")
+ os.makedirs(ckpt_path, exist_ok=True)
+ os.symlink(output_dir, os.path.join(ckpt_path, dirname))
+ mprint(f"Created pruned checkpoint at: {output_dir}")
+
+
+def launch_experts_prune_ckpt(
+ cfg: DictConfig,
+ max_save_workers: Optional[int] = None,
+ max_layer_workers: Optional[int] = None,
+ symlink_suffix: Optional[str] = None,
+):
+ for num_experts in cfg.pruning.num_experts_to_keep_list:
+ dirname = f"num_experts_{num_experts}"
+ # Create symlink name with optional suffix
+ symlink_name = f"{dirname}_{symlink_suffix}" if symlink_suffix else dirname
+ if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", symlink_name)):
+ mprint(
+ f"Process num_experts {num_experts} (symlink: {symlink_name}) has already been pruned & saved"
+ )
+ continue
+ mprint(f"Process num_experts {num_experts}")
+ mprint(f"=== STARTING EXPERT PRUNING FOR num_experts={num_experts} ===")
+ model_config_overrides_json = {"ffn": [{"moe": {"num_local_experts": num_experts}}]}
+
+ mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
+
+ output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)
+
+ # Profile the overall init_child_from_parent call with optimizations
+ mprint("Starting init_child_from_parent...")
+ start_time = time.time()
+ init_child_from_parent(
+ descriptor=cfg.descriptor,
+ pruning_mixin=cfg.pruning.pruning_mixin,
+ parent_checkpoint_dir=cfg.teacher_dir,
+ model_config_overrides_dict=model_config_overrides_json,
+ output_checkpoint_dir=output_dir,
+ gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
+ mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
+ mlp_init_config_yaml=mlp_init_config_yaml,
+ linear_init_mode=LinearInitMode.FromTeacher, # dummy default value
+ max_workers=max_save_workers, # Will auto-calculate if None
+ max_layer_workers=max_layer_workers, # Will auto-calculate if None
+ )
+ init_child_from_parent_time = time.time() - start_time
+ mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds")
+
+ # Create symlink in puzzle_dir/ckpts
+ ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts")
+ os.makedirs(ckpt_path, exist_ok=True)
+ os.symlink(output_dir, os.path.join(ckpt_path, symlink_name))
+
+ mprint(f"=== COMPLETED EXPERT PRUNING FOR num_experts={num_experts} ===")
+ mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n")
+
+
+def launch_moe_ffn_intermediates_prune_ckpt(
+ cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
+):
+ for intermediate_size in cfg.pruning.intermediate_size_list:
+ dirname = f"moe_ffn_{intermediate_size}_attn_no_op"
+
+ if os.path.exists(os.path.join(cfg.puzzle_dir, "ckpts", dirname)):
+ mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved")
+ continue
+
+ mprint("Process intermediate_size {}".format(intermediate_size))
+
+ model_config_overrides_json = {
+ "attention": [{"no_op": True, "llama4": None}],
+ "ffn": [{"moe": {"expert_intermediate_dim": intermediate_size}}],
+ }
+ mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
+
+ output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)
+
+ # Profile the overall init_child_from_parent call with optimizations
+ mprint("Starting init_child_from_parent...")
+ start_time = time.time()
+ init_child_from_parent(
+ descriptor=cfg.descriptor,
+ pruning_mixin=cfg.pruning.pruning_mixin,
+ parent_checkpoint_dir=cfg.teacher_dir,
+ model_config_overrides_dict=model_config_overrides_json,
+ output_checkpoint_dir=output_dir,
+ gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
+ mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
+ mlp_init_config_yaml=mlp_init_config_yaml,
+ linear_init_mode=LinearInitMode.FromTeacher, # dummy default value
+ max_workers=max_save_workers, # Will auto-calculate if None
+ max_layer_workers=max_layer_workers, # Will auto-calculate if None
+ )
+ init_child_from_parent_time = time.time() - start_time
+ mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds")
+
+ # Create symlink in puzzle_dir/ckpts
+ os.symlink(output_dir, os.path.join(cfg.puzzle_dir, "ckpts", dirname))
+
+ mprint(f"=== COMPLETED MOE FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===")
+ mprint(f"Total processing time: {init_child_from_parent_time:.2f} seconds\n")
+
+
+def launch_prune_ckpt(cfg: DictConfig):
+ cfg.descriptor = ModelDescriptorFactory.get(cfg.descriptor)
+ # Resolve pruning_mixin from config (could be string, enum, or PruningMixIn)
+ cfg.pruning.pruning_mixin = resolve_pruning_mixin(cfg.pruning.pruning_mixin, cfg.descriptor)
+ pruning_mixin = cfg.pruning.pruning_mixin
+
+ # I/O optimization settings - same as FFN pruning
+ max_save_workers = None # Will auto-calculate as min(CPU count, num files)
+ if "PRUNING_SAVE_WORKERS" in os.environ:
+ max_save_workers = int(os.environ["PRUNING_SAVE_WORKERS"])
+
+ # Layer workers now auto-calculate but can still be overridden
+ max_layer_workers = None # Will auto-calculate as min(CPU count, num layers)
+ if "PRUNING_LAYER_WORKERS" in os.environ:
+ max_layer_workers = int(os.environ["PRUNING_LAYER_WORKERS"])
+
+ if isinstance(pruning_mixin, FFNIntermediatePruningMixIn):
+ launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers)
+ elif isinstance(pruning_mixin, KVHeadsPruningMixIn):
+ launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers)
+ elif isinstance(pruning_mixin, ExpertRemovalPruningMixIn):
+ launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers)
+ # elif target_layer == "layernorm":
+ # launch_hidden_dim_prune_ckpt(cfg)
+ else:
+ raise NotImplementedError(
+ f"checkpoint pruning is not currently supported for pruning mixin: {pruning_mixin.__class__.__name__}"
+ )
diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py
new file mode 100644
index 0000000000..21685848bf
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import re
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Type
+
+from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook
+
+
+class LayerDescriptor:
+ def module_name_regex(self) -> str:
+ return ""
+
+ def block_idx_from_module_name(self, module_name: str) -> Optional[int]:
+ block_idx_match = re.search(r"\.(\d+)\.", module_name)
+ if block_idx_match:
+ return int(block_idx_match.group(1))
+ return None
+
+ def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
+ target_layer = self.module_name_regex()
+ if target_layer.startswith("regex:"):
+ target_layer_regex = target_layer[len("regex:") :]
+ pattern = re.compile(target_layer_regex)
+ match_predicate = lambda module_name: pattern.search(module_name)
+ else:
+ match_predicate = lambda module_name: module_name.endswith(target_layer)
+
+ module_names_to_hook = []
+ for module_name, module in model.named_modules():
+ if match_predicate(module_name):
+ module_names_to_hook.append(
+ (self.block_idx_from_module_name(module_name), module_name)
+ )
+ return module_names_to_hook
+
+
+class PruningMixIn(ABC):
+ def __init__(self, layer_descriptor: LayerDescriptor):
+ self.layer_descriptor = layer_descriptor
+
+ def get_module_names_to_hook(self, model) -> List[Tuple[int, str]]:
+ return self.layer_descriptor.get_modules_names_to_hook(model)
+
+ @abstractmethod
+ def supported_hooks(self) -> List[Type[ForwardHook]]:
+ raise NotImplementedError
+
+ # @abstractmethod
+ # def prune_single_layer(
+ # self,
+ # layer_idx: int,
+ # parent_state_dict: dict,
+ # new_state_dict: dict,
+ # original_config: PretrainedConfig,
+ # new_config: PretrainedConfig,
+ # **kwargs
+ # ):
+ # raise NotImplementedError
diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py
new file mode 100644
index 0000000000..82ba675c94
--- /dev/null
+++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py
@@ -0,0 +1,652 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+import json
+import math
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import torch
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn
+
+
+class GQAInitMode(Enum):
+ RandomKV = "RandomKV"
+ AverageKV = "AverageKV"
+ FirstKV = "FirstKV"
+ RandomBlock = "RandomBlock"
+ CopyAsIs = "CopyAsIs"
+ Degrouping = "Degrouping"
+ PruneKVHeads = "PruneKVHeads"
+
+
+class MlpInitMode(Enum):
+ Random = "Random"
+ Truncate = "Truncate"
+ CopyAsIs = "CopyAsIs"
+ PruneByActivationsLog = "PruneByActivationsLog"
+ ExpertRemoval = "ExpertRemoval"
+ ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN"
+
+
+class LinearInitMode(Enum):
+ Random = "Random"
+ FromTeacher = "FromTeacher"
+
+
+class HiddenSizeInitMode(Enum):
+ Random = "Random"
+ Truncate = "Truncate"
+ PruneByChannelRanking = "PruneByChannelRanking"
+ CopyAsIs = "CopyAsIs"
+
+
+def resolve_pruning_mixin(
+ pruning_mixin, descriptor: Type[ModelDescriptor]
+) -> PruningMixIn | List[PruningMixIn]:
+ """
+ Convert pruning_mixin argument to PruningMixIn instance(s).
+
+ Args:
+ pruning_mixin: Can be a string identifier, PruningMixIn instance,
+ or a list of any of those types.
+ descriptor: ModelDescriptor class that provides the pruning_mixins() mapping.
+
+ Returns:
+ PruningMixIn or List[PruningMixIn] depending on input type.
+ """
+ # Handle list of values recursively
+ if isinstance(pruning_mixin, list):
+ return [resolve_pruning_mixin(item, descriptor) for item in pruning_mixin]
+
+ # Handle single value
+ # If it's already a PruningMixIn, return as is
+ if isinstance(pruning_mixin, PruningMixIn):
+ return pruning_mixin
+
+ # Get the pruning mixins mapping from the descriptor
+ mixins_dict = descriptor.pruning_mixins()
+
+ if isinstance(pruning_mixin, str):
+ if pruning_mixin not in mixins_dict:
+ available_methods = list(mixins_dict.keys())
+ raise ValueError(
+ f"Pruning method '{pruning_mixin}' is not supported by {descriptor.__name__}. "
+ f"Available methods: {available_methods}"
+ )
+ return mixins_dict[pruning_mixin]
+
+ raise ValueError(f"Unsupported pruning_mixin type: {type(pruning_mixin)}")
+
+
+def _init_mlp_module(
+ mlp_init_mode: Union[MlpInitMode, str],
+ mlp_prefix: str,
+ expanded_dim: int,
+ layer_idx: int,
+ new_item: torch.Tensor,
+ new_config: PretrainedConfig,
+ orig_item: torch.Tensor,
+ original_config: PretrainedConfig,
+ mlp_init_config: Optional[dict[str, Any]],
+ pruned_filters: Optional[torch.Tensor] = None,
+ projection_matrix: Optional[dict[str, torch.Tensor]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]:
+ if isinstance(mlp_init_mode, str):
+ mlp_init_mode = MlpInitMode(mlp_init_mode)
+ assert orig_item.ndim == 2, f"{orig_item.ndim=}"
+ assert new_item.ndim == 2, f"{new_item.ndim=}"
+
+ assert new_config.num_hidden_layers == original_config.num_hidden_layers, (
+ f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})"
+ )
+
+ new_intermediate_size = new_config.block_configs[layer_idx].ffn.intermediate_size
+ original_intermediate_size = original_config.block_configs[layer_idx].ffn.intermediate_size
+
+ if mlp_init_mode == MlpInitMode.CopyAsIs:
+ assert new_intermediate_size == original_intermediate_size, (
+ f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is."
+ )
+ mlp_module_weight = orig_item
+
+ elif mlp_init_mode == MlpInitMode.Random:
+ mlp_module_weight = new_item
+
+ elif new_intermediate_size == original_intermediate_size:
+ mlp_module_weight = orig_item
+
+ elif mlp_init_mode in (
+ MlpInitMode.Truncate,
+ MlpInitMode.PruneByActivationsLog,
+ ):
+ assert original_intermediate_size >= new_intermediate_size, (
+ f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated."
+ )
+ orig_ffn_size = orig_item.shape[expanded_dim]
+ new_ffn_size = new_item.shape[expanded_dim]
+
+ if mlp_init_mode == MlpInitMode.Truncate:
+ truncated_weight = torch.narrow(
+ orig_item, dim=expanded_dim, start=0, length=new_ffn_size
+ )
+ mlp_module_weight = truncated_weight
+
+ elif mlp_init_mode == MlpInitMode.PruneByActivationsLog:
+ if pruned_filters is None:
+ filter_importance = _load_activations_log(
+ mlp_init_config, module_name=f"{mlp_prefix}.down_proj"
+ )
+ filters_sorted_by_importance = torch.argsort(filter_importance, descending=True)
+ pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device)
+
+ pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters)
+ if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1:
+ pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size)
+ mlp_module_weight = pruned_weight
+
+ elif (
+ mlp_init_mode == MlpInitMode.ExpertRemoval
+ ): # the case of mlp layers of maverick. for now we only support copy as is
+ assert new_intermediate_size == original_intermediate_size, (
+ f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is."
+ )
+ mlp_module_weight = orig_item
+
+ else:
+ raise ValueError(f"Unsupported {mlp_init_mode=}")
+
+ return mlp_module_weight, pruned_filters, projection_matrix
+
+
+def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor:
+ _cache_activations_log(mlp_init_config)
+ module_log = ACTIVATIONS_LOG[module_name]
+ filter_importance = module_log["score"]
+ return filter_importance
+
+
+ACTIVATIONS_LOG = dict()
+
+
+def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None:
+ if len(ACTIVATIONS_LOG) == 0:
+ assert "activations_log_dir" in mlp_init_config
+ activations_log_dir = mlp_init_config["activations_log_dir"]
+ print(f"Loading activations_log from {activations_log_dir}")
+ # Only load rank_*.pth files to avoid loading hook_states_*.pth checkpoint files
+ ACTIVATIONS_LOG.update(
+ {
+ module_name: module_log
+ for p in Path(activations_log_dir).glob("rank_*.pth")
+ for module_name, module_log in torch.load(p).items()
+ }
+ )
+
+
+def _init_attention_weights(
+ gqa_init_mode,
+ layer_idx,
+ new_state_dict,
+ new_config,
+ original_state_dict,
+ q_key,
+ k_key,
+ v_key,
+ o_key,
+ original_config,
+ is_original_mha,
+ head_size,
+ mlp_init_config,
+):
+ assert new_config.num_attention_heads == original_config.num_attention_heads, (
+ f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
+ )
+ num_q_heads = new_config.num_attention_heads
+ num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
+ orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
+
+ # new_w* are typically randomly initialized
+ new_wq = new_state_dict[q_key]
+ new_wk = new_state_dict[k_key]
+ new_wv = new_state_dict[v_key]
+ new_wo = new_state_dict[o_key]
+
+ # w* are from the parent model
+ wq = original_state_dict[q_key]
+ wk = original_state_dict[k_key]
+ wv = original_state_dict[v_key]
+ wo = original_state_dict[o_key]
+
+ if "bias" in k_key:
+ for tensor in [wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]:
+ assert tensor.ndim == 1
+ tensor.unsqueeze_(1)
+ dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases
+
+ if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock):
+ wk, wv = new_wk, new_wv
+ elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV):
+ assert orig_num_kv_heads % num_kv_heads == 0, (
+ f"({orig_num_kv_heads=}) % ({num_kv_heads=}) != 0"
+ )
+ n_heads_to_aggregate = orig_num_kv_heads // num_kv_heads
+
+ wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1)
+ wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1)
+
+ if gqa_init_mode == GQAInitMode.AverageKV:
+ wk = wk.mean(dim=1)
+ wv = wv.mean(dim=1)
+ else:
+ wk = wk[:, 0]
+ wv = wv[:, 0]
+ elif gqa_init_mode == GQAInitMode.CopyAsIs:
+ assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})"
+ assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})"
+ assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})"
+ assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})"
+
+ elif gqa_init_mode == GQAInitMode.Degrouping:
+ assert not is_original_mha, (
+ "Degrouping can only be done on original models that are GQA themselves."
+ )
+ n_groups = num_kv_heads
+ orig_n_groups = orig_num_kv_heads
+ assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}"
+ n_repeats = n_groups // orig_n_groups
+ if n_repeats > 1:
+ print(f"Degrouping {orig_n_groups} into {n_groups}")
+
+ def degroup_w(w):
+ w = w.view(orig_n_groups, head_size, dim1)
+ w = torch.repeat_interleave(w, repeats=n_repeats, dim=0)
+ w = w.reshape(n_groups * head_size, dim1)
+ return w
+
+ wk = degroup_w(wk)
+ wv = degroup_w(wv)
+
+ elif gqa_init_mode == GQAInitMode.PruneKVHeads:
+ wk = wk.view(orig_num_kv_heads, head_size, dim1)
+ wv = wv.view(orig_num_kv_heads, head_size, dim1)
+ wq = wq.view(orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size, dim1)
+ wo = wo.view(dim1, orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size)
+
+ o_proj_module_name = o_key.replace(".weight", "")
+ kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name)
+ kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True)
+ kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads]
+ kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:]
+
+ wk = wk[kv_heads_to_keep]
+ wv = wv[kv_heads_to_keep]
+
+ reduction_factor = orig_num_kv_heads // num_kv_heads
+
+ prune_via_duplication = False
+ if prune_via_duplication:
+ ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads.
+ wq = wq[kv_heads_to_keep]
+ wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0)
+
+ ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups.
+ ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise.
+ wo = wo[:, kv_heads_to_keep]
+ wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1)
+ wo = wo / reduction_factor
+
+ else: # prune via zeroing out
+ ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA.
+ ## We need to interleave them to keep the matching between queries and kv heads.
+ kv_heads_to_keep = kv_heads_to_keep.tolist()
+ kv_heads_to_remove = kv_heads_to_remove.tolist()
+ kv_head_ordering = []
+ zero_out_mask = []
+ for i_head in range(orig_num_kv_heads):
+ if i_head % reduction_factor == 0:
+ kv_head_ordering.append(kv_heads_to_keep.pop(0))
+ zero_out_mask.append(False)
+ else:
+ kv_head_ordering.append(kv_heads_to_remove.pop(0))
+ zero_out_mask.append(True)
+
+ wq = wq[kv_head_ordering]
+
+ ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads.
+ ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model.
+ ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training.
+ ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them.
+ wo = wo[:, kv_head_ordering]
+ wo[:, zero_out_mask] = 0.0
+
+ else:
+ raise ValueError(f"{gqa_init_mode=} not supported")
+
+ wk = wk.reshape(-1, dim1)
+ wv = wv.reshape(-1, dim1)
+ wq = wq.reshape(-1, dim1)
+ wo = wo.reshape(dim1, -1)
+ return wq, wk, wv, wo
+
+
+def _init_attention_biases(
+ gqa_init_mode,
+ layer_idx,
+ new_state_dict,
+ new_config,
+ original_state_dict,
+ q_key,
+ k_key,
+ v_key,
+ o_key,
+ original_config,
+ is_original_mha,
+ head_size,
+ mlp_init_config,
+):
+ assert new_config.num_attention_heads == original_config.num_attention_heads, (
+ f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
+ )
+ num_q_heads = new_config.num_attention_heads
+ num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
+ orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
+ n_heads_in_group = num_q_heads // num_kv_heads
+ orig_n_heads_in_group = num_q_heads // orig_num_kv_heads
+
+ o_proj_bias = new_config.o_proj_bias
+ attention_bias = new_config.attention_bias
+
+ # If no biases
+ if not (o_proj_bias or attention_bias):
+ return {}
+
+ new_bias_sd = {}
+ bias_sd = {}
+ # new_w* are typically randomly initialized
+ if o_proj_bias:
+ new_bias_sd["o"] = new_state_dict[o_key]
+ bias_sd["o"] = original_state_dict[o_key]
+ if attention_bias:
+ for bias_key, key in zip("qkv", [q_key, k_key, v_key]):
+ new_bias_sd[bias_key] = new_state_dict[key]
+ bias_sd[bias_key] = original_state_dict[key]
+
+ # maybe unsqueeze all tensors
+ for tensor in list(new_bias_sd.values()) + list(bias_sd.values()):
+ assert tensor.ndim == 1
+ tensor.unsqueeze_(1)
+
+ dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases
+ if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias:
+ bias_sd["k"] = torch.zeros(
+ new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device
+ )
+ bias_sd["v"] = torch.zeros(
+ new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device
+ )
+ elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias:
+ assert n_heads_in_group % orig_n_heads_in_group == 0, (
+ f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0"
+ )
+ n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group
+
+ bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1)
+ bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1)
+
+ if gqa_init_mode == GQAInitMode.AverageKV:
+ bias_sd["k"] = bias_sd["k"].mean(dim=1)
+ bias_sd["v"] = bias_sd["v"].mean(dim=1)
+ else:
+ bias_sd["k"] = bias_sd["k"][:, 0]
+ bias_sd["v"] = bias_sd["v"][:, 0]
+ elif gqa_init_mode == GQAInitMode.CopyAsIs:
+ for key in bias_sd.keys():
+ assert new_bias_sd[key].shape == bias_sd[key].shape, (
+ f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})"
+ )
+
+ elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias:
+ assert not is_original_mha, (
+ "Degrouping can only be done on original models that are GQA themselves."
+ )
+ n_groups = new_config.num_attention_heads // n_heads_in_group
+ orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group
+ assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}"
+ n_repeats = n_groups // orig_n_groups
+ if n_repeats > 1:
+ print(f"Degrouping {orig_n_groups} into {n_groups}")
+
+ def degroup_w(w):
+ w = w.view(orig_n_groups, head_size, dim1)
+ w = torch.repeat_interleave(w, repeats=n_repeats, dim=0)
+ w = w.reshape(n_groups * head_size, dim1)
+ return w
+
+ bias_sd["k"] = degroup_w(bias_sd["k"])
+ bias_sd["v"] = degroup_w(bias_sd["v"])
+
+ elif gqa_init_mode == GQAInitMode.PruneKVHeads:
+ if o_proj_bias:
+ o_proj_module_name = o_key.rsplit(".", 1)[0]
+ else:
+ # Here we assume that the o_proj layer is called "o_proj"
+ o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj"
+
+ kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name)
+ kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True)
+ kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads]
+ kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:]
+
+ # view as KV groups
+ if attention_bias:
+ bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1)
+ bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1)
+ bias_sd["q"] = bias_sd["q"].view(
+ orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1
+ )
+ # Keep important KV heads and prune the others
+ bias_sd["k"] = bias_sd["k"][kv_heads_to_keep]
+ bias_sd["v"] = bias_sd["v"][kv_heads_to_keep]
+ if o_proj_bias:
+ bias_sd["o"] = bias_sd["o"].view(
+ dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size
+ )
+
+ reduction_factor = orig_num_kv_heads // num_kv_heads
+
+ prune_via_duplication = False
+ if prune_via_duplication:
+ if attention_bias:
+ ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads.
+ bias_sd["q"] = bias_sd["q"][kv_heads_to_keep]
+ bias_sd["q"] = torch.repeat_interleave(
+ bias_sd["q"], repeats=reduction_factor, dim=0
+ )
+
+ if o_proj_bias:
+ ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups.
+ ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise.
+ bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep]
+ bias_sd["o"] = torch.repeat_interleave(
+ bias_sd["o"], repeats=reduction_factor, dim=1
+ )
+ bias_sd["o"] = bias_sd["o"] / reduction_factor
+
+ else: # prune via zeroing out
+ ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA.
+ ## We need to interleave them to keep the matching between queries and kv heads.
+ kv_heads_to_keep = kv_heads_to_keep.tolist()
+ kv_heads_to_remove = kv_heads_to_remove.tolist()
+ kv_head_ordering = []
+ zero_out_mask = []
+ for i_head in range(orig_num_kv_heads):
+ if i_head % reduction_factor == 0:
+ kv_head_ordering.append(kv_heads_to_keep.pop(0))
+ zero_out_mask.append(False)
+ else:
+ kv_head_ordering.append(kv_heads_to_remove.pop(0))
+ zero_out_mask.append(True)
+
+ if attention_bias:
+ bias_sd["q"] = bias_sd["q"][kv_head_ordering]
+
+ if o_proj_bias:
+ ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads.
+ ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model.
+ ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training.
+ ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them.
+ bias_sd["o"] = bias_sd["o"][:, kv_head_ordering]
+ bias_sd["o"][:, zero_out_mask] = 0.0
+
+ else:
+ raise ValueError(f"{gqa_init_mode=} not supported")
+
+ if attention_bias:
+ for bias_key in "qkv":
+ bias_sd[bias_key] = bias_sd[bias_key].reshape(-1)
+ if o_proj_bias:
+ bias_sd["o"] = bias_sd["o"].reshape(-1)
+ return bias_sd
+
+
+def _init_moe_module(
+ mlp_init_mode: Union[MlpInitMode, str],
+ mlp_init_config: Optional[Dict[str, Any]],
+ layer_idx: int,
+ orig_router_weights: Dict[str, List[torch.Tensor]],
+ orig_experts_weights: Dict[str, List[torch.Tensor]],
+ new_router_weights: Dict[str, List[torch.Tensor]],
+ new_experts_weights: Dict[str, List[torch.Tensor]],
+ orig_num_experts: int,
+ new_num_experts: int,
+) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]:
+ if isinstance(mlp_init_mode, str):
+ mlp_init_mode = MlpInitMode(mlp_init_mode)
+
+ if mlp_init_mode != MlpInitMode.ExpertRemoval:
+ raise ValueError(f"Unsupported {mlp_init_mode=}")
+
+ selected_experts = _select_expert_indices(
+ mlp_init_config=mlp_init_config,
+ layer_idx=layer_idx,
+ orig_num_experts=orig_num_experts,
+ new_num_experts=new_num_experts,
+ )
+
+ # Router: prefer parent tensors when available; if child has bias only, slice from child
+ result_router_weights: dict[str, list[torch.Tensor]] = {}
+ for name, new_list in new_router_weights.items():
+ result_router_weights[name] = [
+ tensor_to_slice[selected_experts] for tensor_to_slice in orig_router_weights[name]
+ ]
+
+ # Experts: for each name present in the child, take from parent if available, else from child
+ result_experts_weights: dict[str, list[torch.Tensor]] = {}
+ for name, new_list in new_experts_weights.items():
+ if name in orig_experts_weights:
+ src_list = orig_experts_weights[name]
+ else:
+ src_list = new_list
+ result_experts_weights[name] = [src_list[i] for i in selected_experts]
+
+ # Validate shapes
+ assert result_router_weights.keys() == new_router_weights.keys(), (
+ "result_router_weights and new_router_weights must have the same keys"
+ )
+ for name in new_router_weights.keys():
+ assert len(new_router_weights[name]) == len(result_router_weights[name])
+ for new_router_weight, result_router_weight in zip(
+ new_router_weights[name], result_router_weights[name]
+ ):
+ assert new_router_weight.shape == result_router_weight.shape
+
+ assert result_experts_weights.keys() == new_experts_weights.keys(), (
+ "result_experts_weights and new_experts_weights must have the same keys"
+ )
+ for name in result_experts_weights.keys():
+ assert len(new_experts_weights[name]) == len(result_experts_weights[name])
+ for new_expert_weight, result_expert_weight in zip(
+ new_experts_weights[name], result_experts_weights[name]
+ ):
+ assert new_expert_weight.shape == result_expert_weight.shape
+
+ return result_router_weights, result_experts_weights
+
+
+def _select_expert_indices(
+ *, mlp_init_config: dict[str, Any], layer_idx: int, orig_num_experts: int, new_num_experts: int
+) -> list[int]:
+ expert_scores = _load_expert_scores(mlp_init_config, layer_idx)
+ assert len(expert_scores) == orig_num_experts
+ higher_is_better = mlp_init_config.get("higher_is_better", True)
+ selected_experts = sorted(
+ range(orig_num_experts),
+ key=lambda i: (
+ expert_scores[i]
+ if not math.isnan(expert_scores[i])
+ else (float("-inf") if higher_is_better else float("inf"))
+ ),
+ reverse=higher_is_better,
+ )[:new_num_experts]
+ return selected_experts
+
+
+def _load_expert_scores(
+ mlp_init_config: Optional[dict[str, Any]], layer_idx: int
+) -> list[list[int | float]]:
+ assert mlp_init_config is not None
+ if "expert_scores_file" in mlp_init_config:
+ expert_scores_file = mlp_init_config["expert_scores_file"]
+ with open(expert_scores_file, "r") as f:
+ expert_scores = json.load(f)
+ elif "activations_log_dir" in mlp_init_config:
+ _cache_activations_log(mlp_init_config)
+ # Use layer_prefix_template from pruning config, or fall back to legacy nemotron_h format
+ # TODO - get from descriptors
+ layer_prefix_template = mlp_init_config.get(
+ "layer_prefix_template", "backbone.layers.{layer_idx}."
+ )
+ layer_prefix = layer_prefix_template.format(layer_idx=layer_idx)
+ candidate_layer_keys = [
+ key for key in ACTIVATIONS_LOG.keys() if key.startswith(layer_prefix)
+ ]
+ if len(candidate_layer_keys) == 0:
+ raise ValueError(f"No layer keys found for {layer_prefix=}. {ACTIVATIONS_LOG.keys()=}")
+ elif len(candidate_layer_keys) > 1:
+ if "layer_suffix" not in mlp_init_config:
+ raise ValueError(
+ f"Multiple candidate layer keys found for {layer_prefix=}, you must specify a layer_suffix in the mlp_init_config. {candidate_layer_keys=}"
+ )
+ layer_suffix = mlp_init_config["layer_suffix"]
+ layer_key = f"{layer_prefix}{layer_suffix}"
+ else:
+ layer_key = candidate_layer_keys[0]
+ layer_log = ACTIVATIONS_LOG[layer_key]
+
+ expert_scores_key = mlp_init_config.get("expert_scores_key", "expert_ranks")
+ if expert_scores_key not in layer_log:
+ raise ValueError(
+ f"Expert scores key {expert_scores_key=} not found in {layer_log.keys()=}"
+ )
+ expert_scores = layer_log[expert_scores_key]
+ else:
+ raise ValueError(f"Unsupported {mlp_init_config=}")
+ return expert_scores
diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py
new file mode 100644
index 0000000000..5a1484e07a
--- /dev/null
+++ b/modelopt/torch/puzzletron/puzzletron.py
@@ -0,0 +1,76 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module provides the main compression function for a model using MIP-based NAS search algorithm."""
+
+import hydra
+from omegaconf import DictConfig
+
+import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations
+import modelopt.torch.puzzletron.build_library_and_stats as build_library_and_stats
+import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
+import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts
+import modelopt.torch.puzzletron.scoring.scoring as scoring
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir
+
+
+def puzzletron(
+ hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str
+) -> DictConfig:
+ """Compress a model using the MIP-based NAS search algorithm from Puzzletron.
+
+ Args:
+ hydra_config_dir (str): path to a hydra_config_dir that defines the search space
+ hydra_config (str): the corresponding hydra config file
+ puzzle_dir (str): directory with a puzzletron model to compress
+ dataset_path (str): dataset used for scoring and distillation
+
+ Returns:
+ Hydra config object after compressing the model.
+ The same hydra configuration object is used across all compression steps.
+ TODO: Investigate if this config object is immutable across steps and clarify
+ """
+ # Step 0: Load puzzletron hydra config
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=hydra_config_dir,
+ config_name=hydra_config,
+ overrides=[
+ f"puzzle_dir={puzzle_dir}",
+ f"dataset_path={dataset_path}",
+ ],
+ )
+ hydra_cfg = hydra.utils.instantiate(hydra_cfg)
+
+ # Step 1: score_pruning_activations (distributed processing)
+ score_pruning_activations.launch_score_activations(hydra_cfg)
+
+ # Step 2: pruning_ckpts (single process)
+ if dist.is_master():
+ pruning_ckpts.launch_prune_ckpt(hydra_cfg)
+ dist.barrier()
+
+ # Step 4: build_library_and_stats (single process)
+ if dist.is_master():
+ build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
+ dist.barrier()
+
+ # Step 5: calc_one_block_scores (distributed processing)
+ scoring.launch_scoring(hydra_cfg)
+
+ # Step 6: mip_and_realize_models (distributed processing)
+ mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+
+ return hydra_cfg
diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
new file mode 100644
index 0000000000..a7ed3f7d37
--- /dev/null
+++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
@@ -0,0 +1,616 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module constructs the replacement library JSON files from a puzzle directory containing
+multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock
+configurations, builds a library of available replacements, and generates solutions for layer
+replacement in compressed models. The resulting replacement library can then be used by
+ReplacementLibrary to efficiently load models with mixed teacher/student layers.
+
+Standard Puzzle Usage:
+======================
+python -m modelopt.torch.puzzletron.replacement_library.build_replacement_library PUZZLE_DIR
+
+Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended)
+though you can supply an explicit --teacher_checkpoint_dir.
+
+--add_ffn_no_ops and --add_attention_no_ops are optional (default True),
+
+
+"""
+# mypy: ignore-errors
+
+import json
+from pathlib import Path
+from typing import Any, Type
+
+import pandas as pd
+from omegaconf import DictConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+from modelopt.torch.puzzletron.replacement_library.replacement_utils import (
+ is_replacement_identical_to_teacher,
+ replacement_is_teacher,
+ sort_replacements,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import (
+ SAFETENSORS_SUBBLOCKS_DIR_NAME,
+ is_valid_decilm_checkpoint,
+ load_model_config,
+)
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.robust_json import json_dump
+from modelopt.torch.puzzletron.utils.parsing import format_global_config
+from modelopt.torch.puzzletron.utils.utils import block_config_to_str, subblock_config_to_str
+
+UNIQUE_SUBBLOCK_IDENTIFIER = ["block_config", "attention_config", "ffn_config", "block_idx"]
+CHECKPOINTS_DIR_NAME = "ckpts"
+
+
+def build_replacement_library(
+ master_puzzle_dir: Path | str,
+ descriptor: ModelDescriptor,
+ teacher_checkpoint_dir: Path | str | None = None,
+ add_ffn_no_ops: bool = True,
+ add_attention_no_ops: bool = True,
+) -> None:
+ """
+ For normal puzzle runs, use default values.
+ For advanced use cases, see the Usage section.
+ """
+ master_puzzle_dir = Path(master_puzzle_dir)
+ (master_puzzle_dir / "ckpts").mkdir(exist_ok=True)
+ teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir)
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ subblocks_df = _build_subblocks_df(
+ master_puzzle_dir,
+ teacher_checkpoint_dir,
+ add_ffn_no_ops,
+ add_attention_no_ops,
+ trust_remote_code=trust_remote_code,
+ )
+ block_library_df = _build_block_library_from_subblocks(subblocks_df)
+
+ layer_replacements = _build_layer_replacements(
+ block_library_df, master_puzzle_dir, teacher_checkpoint_dir, trust_remote_code
+ )
+
+ single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions(
+ layer_replacements, teacher_checkpoint_dir, trust_remote_code
+ )
+
+ json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json")
+ json_dump(subblocks_df.to_dict(orient="records"), master_puzzle_dir / "subblock_library.json")
+ json_dump(layer_replacements, master_puzzle_dir / "replacement_library.json")
+ json_dump(
+ single_sequence_replacement_solutions,
+ master_puzzle_dir / "single_sequence_replacement_solutions.json",
+ )
+ mprint("done")
+
+
+def launch_build_replacement_library(cfg: DictConfig) -> None:
+ """
+ Launch the build replacement library function with Hydra configuration.
+ """
+ mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}")
+ mprint(f"Teacher directory: {cfg.teacher_dir}")
+ mprint(
+ f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}"
+ )
+
+ descriptor = ModelDescriptorFactory.get(cfg.descriptor)
+ build_replacement_library(
+ master_puzzle_dir=cfg.puzzle_dir,
+ teacher_checkpoint_dir=cfg.teacher_dir,
+ add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops,
+ add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops,
+ descriptor=descriptor,
+ )
+
+
+def infer_teacher_dir(
+ master_puzzle_dir: Path | str,
+ teacher_checkpoint_dir: Path | str | None = None,
+) -> Path:
+ if teacher_checkpoint_dir is None:
+ teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher"
+ if not teacher_checkpoint_dir.exists():
+ raise ValueError(
+ f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the "
+ f"teacher dir under '{{PUZZLE_DIR}}/ckpts'."
+ )
+ teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute()
+ return teacher_checkpoint_dir
+
+
+def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame:
+ joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy()
+ constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df)
+
+ is_constructed_block_has_joint_variant = pd.Series(
+ map(tuple, constructed_blocks_df[["block_config", "block_idx"]].values)
+ ).isin(pd.Series(map(tuple, joint_blocks_df[["block_config", "block_idx"]].values)))
+ constructed_blocks_df = constructed_blocks_df[~is_constructed_block_has_joint_variant]
+
+ block_library_df = pd.concat([joint_blocks_df, constructed_blocks_df])
+ block_library_df["block_repr"] = block_library_df["block_config"].apply(block_config_to_str)
+
+ dups = block_library_df.loc[
+ block_library_df[["block_config", "block_idx"]].duplicated()
+ ].sort_values(by=["block_config", "block_idx"])
+ if len(dups) > 0:
+ mprint(f"Found {len(dups)} duplicate blocks in the block library. Here are some examples:")
+ dup_block_idx = dups["block_idx"].iloc[0]
+ dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx]
+ for _, row in dups_with_same_block_idx.head(10).iterrows():
+ mprint(row.to_dict())
+ json_dump(block_library_df.to_dict(orient="records"), "ERROR_block_library.json")
+ json_dump(subblocks_df.to_dict(orient="records"), "ERROR_subblock_library.json")
+ raise ValueError(
+ f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details."
+ )
+
+ return block_library_df
+
+
+def _construct_blocks_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame:
+ columns = subblocks_df.columns
+ decomp_blocks_df = subblocks_df[subblocks_df["block_config"].isna()].drop(
+ columns=columns[columns.str.contains("block_config|joint|block_repr")]
+ )
+
+ attention_df = decomp_blocks_df.dropna(subset="attention_config").drop(
+ columns=columns[columns.str.contains("ffn")]
+ )
+ ffn_df = decomp_blocks_df.dropna(subset="ffn_config").drop(
+ columns=columns[columns.str.contains("attention")]
+ )
+ constructed_blocks_df = pd.merge(attention_df, ffn_df, on="block_idx")
+
+ constructed_blocks_df["block_config"] = constructed_blocks_df.apply(
+ lambda row: BlockConfig(ffn=row["ffn_config"], attention=row["attention_config"]), axis=1
+ )
+
+ return constructed_blocks_df
+
+
+def _build_subblocks_df(
+ master_puzzle_dir: Path | str,
+ teacher_checkpoint_dir: Path | str,
+ add_ffn_no_ops: bool,
+ add_attention_no_ops: bool,
+ trust_remote_code: bool = False,
+) -> pd.DataFrame:
+ teacher_checkpoint_dir = Path(teacher_checkpoint_dir)
+ checkpoint_dirs = _get_last_checkpoint_from_each_experiment(
+ master_puzzle_dir, trust_remote_code=trust_remote_code
+ )
+ checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir})
+ checkpoints_to_split = [teacher_checkpoint_dir]
+
+ subblock_rows = []
+ for checkpoint_dir in checkpoint_dirs:
+ subblocks_to_extract = _infer_subblocks_to_extract(checkpoint_dir, checkpoints_to_split)
+ if len(subblocks_to_extract) > 0:
+ subblock_rows_from_current_checkpoint = (
+ _construct_subblock_rows_from_current_checkpoint(
+ checkpoint_dir, subblocks_to_extract, trust_remote_code=trust_remote_code
+ )
+ )
+ subblock_rows.extend(subblock_rows_from_current_checkpoint)
+
+ subblocks_df = pd.DataFrame(subblock_rows)
+
+ subblocks_df = _drop_duplicates_of_decomp_no_op(subblocks_df)
+ assert subblocks_df.duplicated().sum() == 0
+
+ if add_ffn_no_ops or add_attention_no_ops:
+ subblocks_df = _add_no_op_subblock_rows(subblocks_df, add_ffn_no_ops, add_attention_no_ops)
+
+ subblocks_df = _drop_duplicates_of_teacher(subblocks_df, teacher_checkpoint_dir)
+
+ subblocks_that_have_multiple_sources = list(
+ subblocks_df[subblocks_df.duplicated(UNIQUE_SUBBLOCK_IDENTIFIER, keep=False)].groupby(
+ UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False
+ )
+ )
+ if len(subblocks_that_have_multiple_sources) > 0:
+ mprint(
+ f"Found {len(subblocks_that_have_multiple_sources)} subblock types with multiple sources. Dropping duplicates..."
+ )
+ for subblock_identifier, duplicates_df in subblocks_that_have_multiple_sources:
+ mprint("\n================================")
+ mprint(dict(zip(UNIQUE_SUBBLOCK_IDENTIFIER, subblock_identifier)))
+ for _, row in duplicates_df.iterrows():
+ mprint(row.to_dict())
+
+ # Drop duplicates, keeping the first occurrence (which should be from teacher)
+ mprint(f"Dropping duplicates. Original count: {len(subblocks_df)}")
+ subblocks_df = subblocks_df.drop_duplicates(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first")
+ mprint(f"After dropping duplicates: {len(subblocks_df)}")
+
+ subblocks_df["ffn_repr"] = subblocks_df["ffn_config"].apply(subblock_config_to_str)
+ subblocks_df["attention_repr"] = subblocks_df["attention_config"].apply(subblock_config_to_str)
+ subblocks_df["block_repr"] = subblocks_df["block_config"].apply(block_config_to_str)
+
+ return subblocks_df
+
+
+def _drop_duplicates_of_teacher(
+ subblocks_df: pd.DataFrame,
+ teacher_checkpoint_dir: Path | str,
+) -> pd.DataFrame:
+ orig_subblocks_df = subblocks_df.copy()
+
+ attention_is_teacher = subblocks_df["attention_checkpoint_dir"] == str(teacher_checkpoint_dir)
+ ffn_is_teacher = subblocks_df["ffn_checkpoint_dir"] == str(teacher_checkpoint_dir)
+ is_joint_teacher = attention_is_teacher & ffn_is_teacher
+
+ is_decomp_attention = subblocks_df["ffn_config"].isna()
+ is_decomp_ffn = subblocks_df["attention_config"].isna()
+ is_joint_block = ~is_decomp_attention & ~is_decomp_ffn
+
+ student_indices_that_have_teacher_dups = []
+
+ for current_subset, is_teacher in [
+ (is_decomp_attention, attention_is_teacher),
+ (is_decomp_ffn, ffn_is_teacher),
+ (is_joint_block, is_joint_teacher),
+ ]:
+ subblocks_df = orig_subblocks_df.copy().loc[current_subset]
+
+ subblocks_df["is_student"] = ~is_teacher.loc[current_subset]
+
+ def get_student_indices_that_have_teacher_dups(grouped_is_student: pd.Series) -> list:
+ if grouped_is_student.all():
+ return []
+ return grouped_is_student.index[grouped_is_student].tolist()
+
+ current_student_indices_that_have_teacher_dups = [
+ dup_index
+ for dup_list in subblocks_df.groupby(UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False)[
+ "is_student"
+ ].apply(get_student_indices_that_have_teacher_dups)
+ for dup_index in dup_list
+ ]
+ student_indices_that_have_teacher_dups.extend(
+ current_student_indices_that_have_teacher_dups
+ )
+
+ dedup_subblocks_df = orig_subblocks_df.drop(index=student_indices_that_have_teacher_dups)
+ return dedup_subblocks_df
+
+
+def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame:
+ is_decomp = subblocks_df["block_config"].isna()
+ is_ffn_no_op = subblocks_df["ffn_config"].apply(lambda conf: conf is not None and conf.no_op)
+ is_attention_no_op = subblocks_df["attention_config"].apply(
+ lambda conf: conf is not None and conf.no_op
+ )
+ is_duplicated = subblocks_df.duplicated(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first")
+ is_dup_of_decomp_no_op = is_duplicated & is_decomp & (is_ffn_no_op | is_attention_no_op)
+ subblocks_df = subblocks_df[~is_dup_of_decomp_no_op]
+ return subblocks_df
+
+
+def _construct_subblock_rows_from_current_checkpoint(
+ checkpoint_dir: Path, subblocks_to_extract: list[str], trust_remote_code: bool = False
+) -> list[dict[str, Any]]:
+ subblock_rows_from_current_checkpoint = []
+ model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code)
+ for block_idx, block_config in enumerate(model_config.block_configs):
+ for subblock_to_extract in subblocks_to_extract:
+ subblock_row = _init_empty_subblock_row(block_idx)
+
+ if subblock_to_extract == "block":
+ subblock_row["block_config"] = block_config
+ subblock_row["attention_config"] = block_config.attention
+ subblock_row["attention_checkpoint_dir"] = (
+ str(checkpoint_dir) if not block_config.attention.no_op else None
+ )
+ subblock_row["ffn_config"] = block_config.ffn
+ subblock_row["ffn_checkpoint_dir"] = (
+ str(checkpoint_dir) if not block_config.ffn.no_op else None
+ )
+ elif subblock_to_extract == "ffn":
+ subblock_row["ffn_config"] = block_config.ffn
+ subblock_row["ffn_checkpoint_dir"] = (
+ str(checkpoint_dir) if not block_config.ffn.no_op else None
+ )
+ elif subblock_to_extract == "attention":
+ subblock_row["attention_config"] = block_config.attention
+ subblock_row["attention_checkpoint_dir"] = (
+ str(checkpoint_dir) if not block_config.attention.no_op else None
+ )
+ else:
+ raise ValueError()
+
+ subblock_rows_from_current_checkpoint.append(subblock_row)
+ return subblock_rows_from_current_checkpoint
+
+
+def _add_no_op_subblock_rows(
+ subblocks_df: pd.DataFrame,
+ add_ffn_no_op: bool,
+ add_attention_no_op: bool,
+) -> pd.DataFrame:
+ n_layer = subblocks_df["block_idx"].max() + 1
+
+ no_op_subblocks = []
+ if add_ffn_no_op:
+ no_op_subblocks.append("ffn")
+ if add_attention_no_op:
+ no_op_subblocks.append("attention")
+
+ additional_no_op_rows = []
+ for no_op_subblock in no_op_subblocks:
+ rows_with_no_op_subblock, subblock_cls = _get_rows_with_no_op_subblock(
+ subblocks_df, no_op_subblock
+ )
+ existing_no_op_indices = rows_with_no_op_subblock["block_idx"].values
+ missing_no_op_indices = list(set(range(n_layer)) - set(existing_no_op_indices))
+ for block_idx in missing_no_op_indices:
+ no_op_subblock_row = {
+ **_init_empty_subblock_row(block_idx),
+ f"{no_op_subblock}_config": subblock_cls(no_op=True),
+ }
+ additional_no_op_rows.append(no_op_subblock_row)
+
+ subblocks_df = pd.concat([subblocks_df, pd.DataFrame(additional_no_op_rows)])
+
+ for no_op_subblock in no_op_subblocks:
+ rows_with_no_op_subblock, _ = _get_rows_with_no_op_subblock(subblocks_df, no_op_subblock)
+ assert len(rows_with_no_op_subblock) == n_layer, (
+ f"Got {len(rows_with_no_op_subblock)} rows with {no_op_subblock}=no_op, but we have {n_layer} layers"
+ )
+ return subblocks_df
+
+
+def _get_rows_with_no_op_subblock(
+ subblocks_df: pd.DataFrame, no_op_subblock: str
+) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]:
+ other_subblock = "ffn" if no_op_subblock == "attention" else "attention"
+ subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig
+ no_op_subblock_config = subblock_cls(no_op=True)
+ rows_with_no_op_subblock = subblocks_df[
+ (subblocks_df[f"{no_op_subblock}_config"] == no_op_subblock_config)
+ & subblocks_df[f"{other_subblock}_config"].isna()
+ ]
+ return rows_with_no_op_subblock, subblock_cls
+
+
+def _get_last_checkpoint_from_each_experiment(
+ master_puzzle_dir: Path | str, trust_remote_code: bool = False
+) -> set[Path]:
+ master_puzzle_dir = Path(master_puzzle_dir)
+ master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME
+ subdirs_of_master_checkpoints_dir = [
+ p.resolve() for p in master_checkpoints_dir.iterdir() if p.is_dir()
+ ]
+ checkpoint_dirs = [
+ p.parent
+ for subdir in subdirs_of_master_checkpoints_dir
+ for p in subdir.rglob("config.json")
+ ]
+
+ for checkpoint_dir in checkpoint_dirs:
+ if checkpoint_dir == master_checkpoints_dir:
+ raise ValueError(
+ f"We need at least 1 hierarchy level under the '{CHECKPOINTS_DIR_NAME}' dir. "
+ "Name your checkpoints, preferably with meaningful names. "
+ "If you are Ido Galil, tell Tomer that you got this exception ;) "
+ )
+
+ # Filter out checkpoints without block_configs (e.g. unconverted raw HF layouts)
+ valid_checkpoint_dirs = [
+ cp
+ for cp in checkpoint_dirs
+ if is_valid_decilm_checkpoint(cp, trust_remote_code=trust_remote_code)
+ ]
+
+ experiment_dirs = [
+ p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs
+ ]
+
+ deduped_checkpoint_dirs = set(
+ pd.DataFrame({"checkpoint_dir": valid_checkpoint_dirs, "experiment_dir": experiment_dirs})
+ .sort_values("checkpoint_dir")
+ .drop_duplicates(subset="experiment_dir", keep="last")["checkpoint_dir"]
+ .tolist()
+ )
+ return deduped_checkpoint_dirs
+
+
+def _infer_subblocks_to_extract(
+ checkpoint_dir: Path,
+ checkpoints_to_split: list[Path],
+) -> list[str]:
+ if (checkpoint_dir / "replacement_library.json").exists():
+ return []
+ bypass_config_path = checkpoint_dir / "bypass_config.json"
+ if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()):
+ subblocks_to_extract = ["block", "attention", "ffn"]
+ else:
+ bypass_config = json.loads(bypass_config_path.read_text())
+ keys_to_learn = bypass_config.get("keys_to_learn", "entire_block")
+ if keys_to_learn == "entire_block":
+ subblocks_to_extract = ["block"]
+ elif "mlp" in keys_to_learn and "attn" not in keys_to_learn:
+ subblocks_to_extract = ["ffn"]
+ elif "attn" in keys_to_learn and "mlp" not in keys_to_learn:
+ subblocks_to_extract = ["attention"]
+ else:
+ raise ValueError(f"Unrecognized {keys_to_learn=}")
+ return subblocks_to_extract
+
+
+def _init_empty_subblock_row(block_idx: int) -> dict[str, Any]:
+ return {
+ "attention_checkpoint_dir": None,
+ "ffn_checkpoint_dir": None,
+ "block_config": None,
+ "attention_config": None,
+ "ffn_config": None,
+ "block_idx": block_idx,
+ "block_repr": None,
+ "attention_repr": None,
+ "ffn_repr": None,
+ }
+
+
+def _build_layer_replacements(
+ block_library_df: pd.DataFrame,
+ master_puzzle_dir: Path,
+ teacher_checkpoint_dir: Path,
+ trust_remote_code: bool = False,
+) -> list[dict]:
+ layer_replacements_from_blocks = _build_layer_replacements_from_block_library(block_library_df)
+ layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints(
+ master_puzzle_dir, trust_remote_code=trust_remote_code
+ )
+ layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints
+ layer_replacements = _filter_duplicate_teacher_replacements(
+ layer_replacements, teacher_checkpoint_dir, trust_remote_code
+ )
+ return layer_replacements
+
+
+def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) -> list[dict]:
+ layer_replacements = []
+ for _, row in block_library_df.iterrows():
+ block_idx = row["block_idx"]
+ block_config = row["block_config"]
+ weight_paths = []
+ for subblock_name in ["attention", "ffn"]:
+ checkpoint_dir = row[f"{subblock_name}_checkpoint_dir"]
+ if checkpoint_dir is not None:
+ subblock_path = (
+ Path(checkpoint_dir)
+ / SAFETENSORS_SUBBLOCKS_DIR_NAME
+ / f"block_{block_idx}_{subblock_name}.safetensors"
+ )
+ weight_paths.append(subblock_path)
+ weight_paths = sorted(set(weight_paths))
+ layer_replacement = {
+ "parent_layer_indices": [block_idx],
+ "child_block_configs": [block_config],
+ "weight_paths": weight_paths,
+ }
+ layer_replacements.append(layer_replacement)
+ return layer_replacements
+
+
+def _gather_layer_replacements_from_checkpoints(
+ master_puzzle_dir: str | Path, trust_remote_code: bool = False
+) -> list[dict]:
+ gathered_layer_replacements = []
+ checkpoint_dirs = _get_last_checkpoint_from_each_experiment(
+ master_puzzle_dir, trust_remote_code=trust_remote_code
+ )
+ for checkpoint_dir in checkpoint_dirs:
+ if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists():
+ layer_replacements = json.loads(layer_replacements_path.read_text())
+ for layer_replacement in layer_replacements:
+ layer_replacement["child_block_configs"] = [
+ BlockConfig(**block_config_dict)
+ for block_config_dict in layer_replacement["child_block_configs"]
+ ]
+ layer_replacement["weight_paths"] = sorted(
+ set(Path(p) for p in layer_replacement["weight_paths"])
+ )
+ gathered_layer_replacements.extend(layer_replacements)
+ return gathered_layer_replacements
+
+
+def _filter_duplicate_teacher_replacements(
+ layer_replacements: list[dict],
+ teacher_checkpoint_dir: Path,
+ trust_remote_code: bool = False,
+) -> list[dict]:
+ teacher_model_config = load_model_config(
+ teacher_checkpoint_dir, trust_remote_code=trust_remote_code
+ )
+ filtered_layer_replacements = []
+ for layer_replacement in layer_replacements:
+ if replacement_is_teacher(
+ layer_replacement, teacher_model_config, teacher_checkpoint_dir
+ ) or not is_replacement_identical_to_teacher(layer_replacement, teacher_model_config):
+ filtered_layer_replacements.append(layer_replacement)
+ return filtered_layer_replacements
+
+
+def _build_single_sequence_replacement_solutions(
+ layer_replacements: list[dict],
+ teacher_checkpoint_dir: Path,
+ trust_remote_code: bool = False,
+) -> list[dict]:
+ teacher_model_config = load_model_config(
+ teacher_checkpoint_dir, trust_remote_code=trust_remote_code
+ )
+ n_layer = teacher_model_config.num_hidden_layers
+
+ teacher_replacements = dict()
+ student_replacements = []
+ for layer_replacement in layer_replacements:
+ if replacement_is_teacher(layer_replacement, teacher_model_config, teacher_checkpoint_dir):
+ block_idx = layer_replacement["parent_layer_indices"][0]
+ teacher_replacements[block_idx] = layer_replacement
+ else:
+ student_replacements.append(layer_replacement)
+
+ teacher_indices_represented_in_replacements = sorted(teacher_replacements.keys())
+ assert teacher_indices_represented_in_replacements == list(range(n_layer)), (
+ f"{n_layer=}, {teacher_indices_represented_in_replacements=}"
+ )
+
+ student_replacements = sort_replacements(student_replacements)
+
+ solutions = []
+ for layer_replacement in student_replacements:
+ block_indices_not_represented_in_replacement = sorted(
+ set(range(n_layer)) - set(layer_replacement["parent_layer_indices"])
+ )
+ chosen_replacements = sort_replacements(
+ [layer_replacement]
+ + [
+ teacher_replacements[block_idx]
+ for block_idx in block_indices_not_represented_in_replacement
+ ]
+ )
+
+ block_configs = [
+ block_config
+ for replacement in chosen_replacements
+ for block_config in replacement["child_block_configs"]
+ ]
+
+ solutions.append(
+ {
+ "single_sequence_replacement": layer_replacement,
+ "chosen_replacements": chosen_replacements,
+ "block_configs": block_configs,
+ }
+ )
+
+ return solutions
diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py
new file mode 100644
index 0000000000..c1eb0b9b48
--- /dev/null
+++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py
@@ -0,0 +1,171 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Replacement library for loading models with layer replacements (AnyModel / sharded HF checkpoints).
+"""
+# mypy: ignore-errors
+
+import copy
+import json
+import tempfile
+from pathlib import Path
+from typing import List, Optional
+
+from immutabledict import immutabledict
+from safetensors import safe_open
+from transformers import PretrainedConfig, PreTrainedModel
+
+from modelopt.torch.puzzletron.anymodel.converter.converter import Converter
+from modelopt.torch.puzzletron.replacement_library.replacement_utils import (
+ extract_block_configs_and_locations,
+ parse_layer_replacement,
+ weights_path_to_checkpoint_dir,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import (
+ SAFETENSORS_SUBBLOCKS_DIR_NAME,
+ load_model_config,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config
+from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model
+
+
+class ReplacementLibrary:
+ def __init__(
+ self,
+ replacement_library_path: str | Path,
+ descriptor,
+ model_config_overrides: Optional[dict] = None,
+ ):
+ self.descriptor = descriptor
+ self.replacement_library = self._load_replacement_library(replacement_library_path)
+ self._ensure_all_checkpoints_are_split()
+ self.model_config_overrides = (
+ immutabledict(model_config_overrides) if (model_config_overrides is not None) else None
+ )
+
+ self._model_config = None
+ self._arbitrary_checkpoint_dir = None
+
+ @staticmethod
+ def _load_replacement_library(replacement_library_path: str | Path) -> list[dict]:
+ replacement_library = json.loads(Path(replacement_library_path).read_text())
+ replacement_library = [
+ parse_layer_replacement(layer_replacement) for layer_replacement in replacement_library
+ ]
+ return replacement_library
+
+ def _ensure_all_checkpoints_are_split(self) -> None:
+ checkpoint_dirs = self._get_all_checkpoint_dirs()
+ unsplit_checkpoints = []
+ for checkpoint_dir in checkpoint_dirs:
+ if not (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists():
+ unsplit_checkpoints.append(checkpoint_dir)
+ assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}"
+
+ @property
+ def model_config(self) -> PretrainedConfig:
+ if self._model_config is None:
+ trust_remote_code = self.descriptor.requires_trust_remote_code()
+ self._model_config = load_model_config(
+ self.get_arbitrary_checkpoint_dir(),
+ self.model_config_overrides,
+ ignore_unexpected_config_keys=True,
+ trust_remote_code=trust_remote_code,
+ )
+ return self._model_config
+
+ def create_model_config(self, layer_replacements: list[dict]):
+ block_configs, _ = extract_block_configs_and_locations(layer_replacements)
+ model_config = copy.deepcopy(self.model_config)
+ model_config.block_configs = block_configs
+ model_config.num_hidden_layers = len(block_configs)
+ return model_config
+
+ def _get_arbitrary_non_block_checkpoint_paths(self):
+ checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())
+ subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME
+ non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name]
+ return non_block_paths
+
+ def create_index_file_from_weights(self, weight_paths: List[str]):
+ weight_map = {}
+ for weight_path in weight_paths:
+ weight_path = Path(weight_path)
+ with safe_open(str(weight_path), framework="pt", device="cpu") as f:
+ for tensor_name in f.keys():
+ weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}"
+ index = {"metadata": {"format": "pt"}, "weight_map": weight_map}
+ return index
+
+ def prepare_tmp_checkpoint_dir(
+ self,
+ tmpdir: Path,
+ model_config: PretrainedConfig,
+ layer_replacements: List[dict],
+ ):
+ arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())
+
+ weight_paths = self._get_arbitrary_non_block_checkpoint_paths()
+ for layer_replacement in layer_replacements:
+ weight_paths += layer_replacement["weight_paths"]
+
+ weights_index = self.create_index_file_from_weights(weight_paths)
+ index_path = tmpdir / "model.safetensors.index.json"
+ with index_path.open("w", encoding="utf-8") as out:
+ json.dump(weights_index, out, indent=2, sort_keys=True)
+
+ Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir)
+ save_model_config(model_config, tmpdir)
+
+ # create symlinks inside tmpdir
+ subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME
+ subblocks_dir.mkdir(exist_ok=True)
+ for weight_path in weight_paths:
+ link_path = subblocks_dir / weight_path.name
+ link_path.symlink_to(weight_path)
+
+ def load_model(
+ self,
+ layer_replacements: list[dict],
+ ) -> PreTrainedModel:
+ """Load model using AnyModel approach with temporary checkpoint directory."""
+ model_config = self.create_model_config(layer_replacements)
+ with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir:
+ tmpdir = Path(tmpdir)
+ self.prepare_tmp_checkpoint_dir(
+ tmpdir, model_config=model_config, layer_replacements=layer_replacements
+ )
+ model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir)
+ return model
+
+ def get_arbitrary_checkpoint_dir(self) -> Path:
+ if self._arbitrary_checkpoint_dir is None:
+ self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir()
+ return self._arbitrary_checkpoint_dir
+
+ def _get_arbitrary_checkpoint_dir(self) -> Path:
+ for layer_replacement in self.replacement_library:
+ weight_paths = layer_replacement["weight_paths"]
+ if len(weight_paths) > 0:
+ return weights_path_to_checkpoint_dir(weight_paths[0])
+
+ def _get_all_checkpoint_dirs(self) -> list[Path]:
+ checkpoint_dirs = set()
+ for layer_replacement in self.replacement_library:
+ weight_paths = layer_replacement["weight_paths"]
+ for weights_path in weight_paths:
+ checkpoint_dir = weights_path_to_checkpoint_dir(weights_path)
+ checkpoint_dirs.add(checkpoint_dir)
+ return list(checkpoint_dirs)
diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py
new file mode 100644
index 0000000000..269e5e63ea
--- /dev/null
+++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module provides helper functions for parsing, sorting, and analyzing layer replacement
+configurations used in the replacement library for model compression.
+"""
+
+# mypy: ignore-errors
+import json
+from copy import deepcopy
+from pathlib import Path
+
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
+from modelopt.torch.puzzletron.mip.utils import sort_replacements
+
+
+def parse_layer_replacement(layer_replacement: dict | str) -> dict:
+ if isinstance(layer_replacement, str):
+ layer_replacement = json.loads(layer_replacement)
+ else:
+ layer_replacement = deepcopy(layer_replacement)
+
+ if "layer_replacement" in layer_replacement: # happens in puzzle solutions
+ layer_replacement = layer_replacement["layer_replacement"]
+
+ layer_replacement["child_block_configs"] = [
+ BlockConfig(**block_config) if isinstance(block_config, dict) else block_config
+ for block_config in layer_replacement["child_block_configs"]
+ ]
+ layer_replacement["weight_paths"] = [Path(p) for p in layer_replacement["weight_paths"]]
+ return layer_replacement
+
+
+# sort_replacements moved to modelopt.torch.puzzletron.mip.utils and imported above
+
+
+def extract_block_configs_and_locations(
+ layer_replacements: list[dict],
+) -> tuple[list[BlockConfig], list[tuple[dict, int]]]:
+ layer_replacements = sort_replacements(layer_replacements)
+ block_configs = []
+ block_locations = []
+ for layer_replacement in layer_replacements:
+ child_block_configs = layer_replacement["child_block_configs"]
+ if not isinstance(child_block_configs, list | tuple):
+ child_block_configs = [child_block_configs]
+ for block_idx_in_replacement, block_config in enumerate(child_block_configs):
+ block_configs.append(block_config)
+ block_locations.append((layer_replacement, block_idx_in_replacement))
+ return block_configs, block_locations
+
+
+def weights_path_to_checkpoint_dir(weights_path: Path) -> Path:
+ checkpoint_dir: Path = weights_path
+ while checkpoint_dir != Path("/"):
+ if (checkpoint_dir / "config.json").exists():
+ return checkpoint_dir
+ checkpoint_dir = checkpoint_dir.parent
+ raise FileNotFoundError(f"Couldn't find checkpoint dir for weights path {weights_path}")
+
+
+def replacement_is_teacher(
+ layer_replacement: dict,
+ teacher_model_config: PretrainedConfig,
+ teacher_checkpoint_dir: Path,
+) -> bool:
+ paths_all_teacher = all(
+ p.is_relative_to(teacher_checkpoint_dir) for p in layer_replacement["weight_paths"]
+ )
+ return paths_all_teacher and is_replacement_identical_to_teacher(
+ layer_replacement, teacher_model_config
+ )
+
+
+def is_replacement_identical_to_teacher(
+ layer_replacement: dict,
+ teacher_model_config: PretrainedConfig,
+) -> bool:
+ if len(layer_replacement["parent_layer_indices"]) == 1:
+ block_idx = layer_replacement["parent_layer_indices"][0]
+ teacher_block_config = teacher_model_config.block_configs[block_idx]
+ if len(child_block_configs := layer_replacement["child_block_configs"]) == 1:
+ replacement_block_config: BlockConfig = child_block_configs[0]
+ if replacement_block_config == teacher_block_config:
+ return True
+ else:
+ parallel_blocks = getattr(replacement_block_config, "parallel_blocks", None)
+ if (
+ parallel_blocks is not None
+ and len(parallel_blocks) == 1
+ and parallel_blocks[0].attention == teacher_block_config.attention
+ and parallel_blocks[0].ffn == teacher_block_config.ffn
+ ):
+ return True
+ return False
+
+
+def split_replacements_to_teacher_and_student(
+ replacements: list[dict],
+ teacher_model_config: PretrainedConfig,
+ teacher_checkpoint_dir: Path,
+) -> tuple[list[dict], list[dict]]:
+ teacher_replacements, student_replacements = [], []
+ for replacement in replacements:
+ if replacement_is_teacher(replacement, teacher_model_config, teacher_checkpoint_dir):
+ teacher_replacements.append(replacement)
+ else:
+ student_replacements.append(replacement)
+ return teacher_replacements, student_replacements
diff --git a/modelopt/torch/puzzletron/scoring/scoring.py b/modelopt/torch/puzzletron/scoring/scoring.py
new file mode 100644
index 0000000000..8f1871de89
--- /dev/null
+++ b/modelopt/torch/puzzletron/scoring/scoring.py
@@ -0,0 +1,90 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Validates and scores model compression solutions by evaluating puzzle solution candidates."""
+
+# mypy: ignore-errors
+import os
+import re
+from glob import glob
+
+import hydra
+import numpy as np
+import pandas as pd
+from omegaconf import DictConfig
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.validate_puzzle_with_multi_replacements import (
+ validate_puzzle_solutions,
+)
+
+
+def extract_solution_id(filename):
+ pattern = r"solution_(\d+)\.json"
+ match = re.search(pattern, filename)
+
+ if match:
+ solution_id = match.group(1)
+ return int(solution_id)
+ else:
+ mprint(f"Couldn't extract solutions_id from file {filename}")
+
+
+def find_missing_solutions(solutions_df, validation_dir):
+ all_solutions = np.arange(solutions_df.shape[0])
+
+ benchmarked_solutions = list(glob(f"{validation_dir}/solution*.json"))
+ benchmarked_solutions = [
+ extract_solution_id(os.path.basename(s)) for s in benchmarked_solutions
+ ]
+ benchmarked_solutions = [s for s in benchmarked_solutions if s is not None]
+
+ unbenchmarked_solutions = np.setdiff1d(all_solutions, benchmarked_solutions)
+ return unbenchmarked_solutions.tolist()
+
+
+def get_solutions_to_validate(cfg: DictConfig):
+ _solutions_to_validate = cfg.scoring.solutions_to_validate
+ if _solutions_to_validate is None:
+ single_block_replacement_solutions = pd.read_json(cfg.scoring.solutions_path)
+ if cfg.scoring.skip_existing_solutions:
+ _solutions_to_validate = find_missing_solutions(
+ single_block_replacement_solutions, cfg.scoring.output_dir
+ )
+ else:
+ _solutions_to_validate = np.arange(single_block_replacement_solutions.shape[0]).tolist()
+ return _solutions_to_validate
+
+
+def launch_scoring(cfg: DictConfig):
+ cfg.scoring.solutions_to_validate = get_solutions_to_validate(cfg)
+ mprint(f"Solutions to validate: {cfg.scoring.solutions_to_validate}")
+ validate_puzzle_solutions(args=cfg.scoring)
+
+
+@hydra.main("", version_base="1.3")
+def main(cfg: DictConfig) -> None:
+ cfg = hydra.utils.instantiate(cfg)
+ mprint(cfg)
+ dist.setup(timeout=cfg.nccl_timeout_minutes)
+ launch_scoring(cfg)
+ dist.cleanup()
+
+
+if __name__ == "__main__":
+ register_hydra_resolvers()
+ main()
diff --git a/modelopt/torch/puzzletron/sewing_kit/__init__.py b/modelopt/torch/puzzletron/sewing_kit/__init__.py
new file mode 100644
index 0000000000..c8f7ffa013
--- /dev/null
+++ b/modelopt/torch/puzzletron/sewing_kit/__init__.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+from .core import (
+ CantResolveNodeDependenciesException,
+ ConstantTarget,
+ ExternalTarget,
+ FunctionTarget,
+ InputsLoopFoundException,
+ KnotException,
+ LoopFoundException,
+ ModuleTarget,
+ MultipleExternalNodesException,
+ Needle,
+ OnlyInternalNodesException,
+ OutputsLoopFoundException,
+ RemoteTarget,
+ StitchedModule,
+ StitchedModuleException,
+ StitchedModuleOutput,
+)
+from .passage import InputArgs, always_false_predicate, always_true_predicate
diff --git a/modelopt/torch/puzzletron/sewing_kit/core.py b/modelopt/torch/puzzletron/sewing_kit/core.py
new file mode 100644
index 0000000000..fb9055c3ed
--- /dev/null
+++ b/modelopt/torch/puzzletron/sewing_kit/core.py
@@ -0,0 +1,883 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from abc import ABC
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union
+
+from typing_extensions import override
+
+try:
+ from typing import Self
+except ImportError:
+ from typing_extensions import Self
+
+import torch
+import torch.distributed
+import torch.nn as nn
+
+from .passage import (
+ InputArgs,
+ OutputValue,
+ Passage,
+ PassageInputAdapter,
+ PassageInputOverrides,
+ PassageOutputAdapter,
+ PassageOutputOverrides,
+ Predicate,
+ always_false_predicate,
+)
+from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip
+
+InputAdapter = Callable[[InputArgs], InputArgs]
+OutputAdapter = Callable[..., OutputValue]
+
+
+def default_input_adapter_fn(input_values: InputArgs) -> InputArgs:
+ return input_values
+
+
+def default_output_adapter_fn(v: OutputValue) -> OutputValue:
+ return v
+
+
+@dataclass
+class IOReducer:
+ pass
+
+
+def default_input_reducer_fn(acc: InputArgs, input_override: InputArgs, *args):
+ return acc + input_override
+
+
+@dataclass
+class InputReducer(IOReducer):
+ reducer_fn: Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs] = (
+ default_input_reducer_fn
+ )
+
+ def __call__(
+ self,
+ acc: InputArgs,
+ input_override: InputArgs,
+ original_input: InputArgs,
+ index: int,
+ all_input_overrides: list[InputArgs],
+ ) -> InputArgs:
+ result = self.reducer_fn(acc, input_override, original_input, index, all_input_overrides)
+ return result
+
+ @classmethod
+ def default(cls) -> InputReducer:
+ return InputReducer()
+
+
+def default_output_reducer_fn(acc: OutputValue, input_override: OutputValue, *args):
+ return input_override
+
+
+@dataclass
+class OutputReducer(IOReducer):
+ reducer_fn: Callable[
+ [OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue
+ ] = default_output_reducer_fn
+ requires_original_output: bool = False
+
+ def __call__(
+ self,
+ acc: OutputValue,
+ output_override: OutputValue,
+ original_output: Optional[OutputValue],
+ index: int,
+ all_output_overrides: list[OutputValue],
+ ) -> InputArgs:
+ result = self.reducer_fn(acc, output_override, original_output, index, all_output_overrides)
+ return result
+
+ @classmethod
+ def default(cls) -> OutputReducer:
+ return OutputReducer()
+
+
+class Singleton(type):
+ _instances = {}
+
+ @override
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+@dataclass
+class Target:
+ @override
+ def __hash__(self) -> int:
+ return id(self)
+
+
+@dataclass
+class TargetWithInput(Target):
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+ def input(
+ self,
+ adapter: InputAdapter = default_input_adapter_fn,
+ reducer: InputReducer = InputReducer.default(),
+ ) -> InputDescriptor:
+ result = InputDescriptor(self, input_name="", input_adapter=adapter, reducer=reducer)
+ return result
+
+
+@dataclass
+class TargetWithNamedInputs(Target):
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+ def input(
+ self,
+ name: str,
+ adapter: InputAdapter = default_input_adapter_fn,
+ reducer: InputReducer = InputReducer.default(),
+ ) -> InputDescriptor:
+ result = InputDescriptor(self, input_name=name, input_adapter=adapter, reducer=reducer)
+ return result
+
+
+@dataclass
+class TargetWithOutput(Target):
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+ def output(
+ self,
+ adapter: OutputAdapter = default_output_adapter_fn,
+ reducer: OutputReducer = OutputReducer.default(),
+ ) -> OutputDescriptor:
+ result = OutputDescriptor(self, output_name="", output_adapter=adapter, reducer=reducer)
+ return result
+
+
+@dataclass
+class TargetWithNamedOutputs(Target):
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+ def output(
+ self,
+ name: str,
+ adapter: OutputAdapter = default_output_adapter_fn,
+ reducer: OutputReducer = OutputReducer.default(),
+ ) -> OutputDescriptor:
+ result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer)
+ return result
+
+
+@dataclass
+class ExternalTarget(TargetWithNamedInputs, TargetWithNamedOutputs, metaclass=Singleton):
+ """External target for stitched modules."""
+
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+
+@dataclass
+class ConstantTarget(TargetWithOutput):
+ name: str
+ value: Any
+
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+
+@dataclass
+class FunctionTarget(TargetWithInput, TargetWithOutput):
+ name: str
+ function: Callable[..., Any]
+
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+
+@dataclass
+class ModuleTarget(TargetWithNamedInputs, TargetWithNamedOutputs):
+ name: str
+ module: nn.Module
+
+ @override
+ def __str__(self) -> str:
+ return f"ModuleTarget({self.name})"
+
+ @override
+ def __repr__(self) -> str:
+ return str(self)
+
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+
+@dataclass
+class RemoteTarget(Target):
+ peer_rank: Union[int, Sequence[int]]
+ process_group: Optional[torch.distributed.ProcessGroup] = None
+ blocking: bool = True
+
+ @override
+ def __hash__(self) -> int:
+ return super().__hash__()
+
+ def value(
+ self,
+ name: str,
+ adapter: OutputAdapter = default_output_adapter_fn,
+ reducer: OutputReducer = OutputReducer.default(),
+ ) -> OutputDescriptor:
+ result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer)
+ return result
+
+
+@dataclass(frozen=True, eq=True)
+class RemoteDataDescriptor(ABC):
+ key: str
+
+
+@dataclass(frozen=True, eq=True)
+class RemoteTensorDataDescriptor(RemoteDataDescriptor):
+ device: Literal["cuda", "cpu"]
+ dtype: torch.dtype
+ shape: torch.Size
+
+
+@dataclass(frozen=True, eq=True)
+class RemotePythonDataDescriptor(RemoteDataDescriptor):
+ value: Any
+
+
+@dataclass
+class Node:
+ target: Target
+ stitches_to: list[StitchDescriptor] = field(default_factory=list)
+ stitches_from: list[StitchDescriptor] = field(default_factory=list)
+
+ @override
+ def __hash__(self) -> int:
+ return id(self)
+
+
+@dataclass
+class InputDescriptor:
+ target: Target
+ input_name: str = ""
+ input_adapter: InputAdapter = field(default=default_input_adapter_fn)
+ reducer: InputReducer = field(default_factory=InputReducer.default)
+
+ @override
+ def __hash__(self) -> int:
+ return id(self)
+
+
+@dataclass
+class OutputDescriptor:
+ target: Target
+ output_name: str = ""
+ output_adapter: OutputAdapter = field(default=default_output_adapter_fn)
+ reducer: OutputReducer = field(default_factory=OutputReducer.default)
+
+ @override
+ def __hash__(self) -> int:
+ return id(self)
+
+
+IODescriptor = Union[InputDescriptor, OutputDescriptor]
+
+
+@dataclass
+class StitchDescriptor:
+ source_descriptor: IODescriptor
+ destination_descriptor: IODescriptor
+
+ @override
+ def __hash__(self) -> int:
+ return id(self)
+
+
+@dataclass
+class StitchedModuleOutput:
+ captured_inputs: dict[str, InputArgs]
+ captured_outputs: dict[str, Any]
+
+
+class StitchedModuleException(Exception):
+ pass
+
+
+class CantResolveNodeDependenciesException(StitchedModuleException):
+ pass
+
+
+class StitchedModule(nn.Module):
+ def __init__(
+ self,
+ nodes: dict[Target, Node],
+ capture_cache_outputs_predicate: Predicate = always_false_predicate,
+ early_exit=True,
+ ignore_extra_overrides=False,
+ ) -> None:
+ super().__init__()
+ self.nodes = nodes
+ self.ignore_extra_overrides = ignore_extra_overrides
+ external_nodes = [n for n in nodes.values() if isinstance(n.target, ExternalTarget)]
+ remote_nodes = [n for n in nodes.values() if isinstance(n.target, RemoteTarget)]
+ assert len(external_nodes) <= 1
+ assert len(remote_nodes) + len(external_nodes) > 0
+ self.external_node = external_nodes[0] if len(external_nodes) > 0 else None
+ self.internal_nodes = [
+ n for n in nodes.values() if not isinstance(n.target, ExternalTarget)
+ ]
+ self.values_from_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict)
+ self.values_to_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict)
+
+ self.node_passages: dict[Node, Passage] = {
+ node: Passage.create(
+ module=node.target.module,
+ inputs_to_capture=set(
+ s.source_descriptor.input_name
+ for s in node.stitches_from
+ if isinstance(s.source_descriptor, InputDescriptor)
+ ),
+ outputs_to_capture=set(
+ s.source_descriptor.output_name
+ for s in node.stitches_from
+ if isinstance(s.source_descriptor, OutputDescriptor)
+ ),
+ capture_cache_outputs_predicate=capture_cache_outputs_predicate,
+ early_exit=early_exit,
+ name=getattr(node.target, "name", None),
+ )
+ for node in self.internal_nodes
+ if isinstance(node.target, ModuleTarget)
+ }
+
+ self.passage_modules = nn.ModuleDict(
+ {
+ f"node_{node_index}": self.node_passages[node]
+ for node_index, node in enumerate(nodes.values())
+ if node in self.node_passages
+ }
+ )
+ self.adapter_modules = nn.ModuleDict(
+ {
+ f"node_{node_index}__stitch_{stitch_index}__{descriptor_name}": adapter
+ for node_index, node in enumerate(nodes.values())
+ for stitch_index, stitch in enumerate(node.stitches_from + node.stitches_to)
+ for descriptor_name, descriptor in (
+ ("source", stitch.source_descriptor),
+ ("destination", stitch.destination_descriptor),
+ )
+ for adapter in [
+ descriptor.input_adapter
+ if isinstance(descriptor, InputDescriptor)
+ else descriptor.output_adapter
+ ]
+ if isinstance(adapter, nn.Module)
+ }
+ )
+
+ def create_input_overrides(
+ self, values_to_node: dict[IODescriptor, Any]
+ ) -> PassageInputOverrides:
+ input_descriptors_by_group = defaultdict[str, list[InputDescriptor]](list)
+ for io_descriptor in values_to_node.keys():
+ if isinstance(io_descriptor, InputDescriptor):
+ input_descriptors_by_group[io_descriptor.input_name].append(io_descriptor)
+
+ input_overrides = PassageInputOverrides()
+ for group, input_descriptors in input_descriptors_by_group.items():
+ reducers = [d.reducer for d in input_descriptors]
+
+ def create_reducer(input_descriptors=input_descriptors, reducers=reducers):
+ inputs = [values_to_node[d] for d in input_descriptors]
+
+ def reducer_fn(
+ original_input: InputArgs,
+ module_name: Optional[str],
+ module: Optional[nn.Module],
+ ) -> InputArgs:
+ acc = InputArgs()
+ for i, (input_, reducer) in enumerate(zip(inputs, reducers)):
+ acc = reducer(acc, input_, original_input, i, inputs)
+ return acc
+
+ return reducer_fn
+
+ input_override = PassageInputAdapter(create_reducer())
+ input_overrides[group] = input_override
+
+ return input_overrides
+
+ def create_output_overrides(
+ self, values_to_node: dict[IODescriptor, Any]
+ ) -> PassageOutputOverrides:
+ output_descriptors_by_group = defaultdict[str, list[OutputDescriptor]](list)
+ for io_descriptor in values_to_node.keys():
+ if isinstance(io_descriptor, OutputDescriptor):
+ output_descriptors_by_group[io_descriptor.output_name].append(io_descriptor)
+
+ output_overrides = PassageOutputOverrides()
+ for group, output_descriptors in output_descriptors_by_group.items():
+ reducers = [d.reducer for d in output_descriptors]
+ requires_original_output = any(r.requires_original_output for r in reducers)
+
+ def create_reducer(reducers=reducers):
+ outputs = [values_to_node[d] for d in output_descriptors]
+
+ def reducer_fn(
+ original_output: Optional[OutputValue],
+ module_name: Optional[str],
+ module: Optional[nn.Module],
+ ) -> OutputValue:
+ acc = None
+ for i, (output, reducer) in enumerate(zip(outputs, reducers)):
+ acc = reducer(acc, output, original_output, i, outputs)
+ return acc
+
+ return reducer_fn
+
+ reducer_fn = create_reducer()
+ if requires_original_output:
+ output_override = PassageOutputAdapter(reducer_fn)
+ else:
+ output_override = reducer_fn(None, None, None)
+
+ output_overrides[group] = output_override
+
+ return output_overrides
+
+ @override
+ def __call__(
+ self,
+ input_overrides: dict[str, Any],
+ output_overrides: dict[str, Any],
+ *args,
+ **kwargs,
+ ) -> StitchedModuleOutput:
+ return super().__call__(input_overrides, output_overrides, *args, **kwargs)
+
+ @override
+ @dynamo_skip
+ def forward(
+ self,
+ input_overrides: dict[str, Any],
+ output_overrides: dict[str, Any],
+ *args,
+ **kwargs,
+ ) -> StitchedModuleOutput:
+ input_overrides = {k: InputArgs.from_value(v) for k, v in input_overrides.items()}
+
+ self.values_from_node.clear()
+ self.values_to_node.clear()
+
+ unresolved_count: int = 0
+ nodes_stack: list[Node] = (
+ [] if self.external_node is None else [self.external_node]
+ ) + self.internal_nodes
+ while len(nodes_stack) > 0:
+ node = nodes_stack.pop(0)
+ values_from_node = self.values_from_node[node]
+ values_to_node = self.values_to_node[node]
+
+ if isinstance(node.target, ExternalTarget):
+ assert self.external_node is not None
+
+ if not self.ignore_extra_overrides:
+ input_override_names = set(input_overrides.keys())
+ external_node_input_names = set(
+ s.source_descriptor.input_name
+ for s in self.external_node.stitches_from
+ if isinstance(s.source_descriptor, InputDescriptor)
+ )
+ assert input_override_names == external_node_input_names
+ output_override_names = set(output_overrides.keys())
+ external_node_output_names = set(
+ s.source_descriptor.output_name
+ for s in self.external_node.stitches_from
+ if isinstance(s.source_descriptor, OutputDescriptor)
+ )
+ assert output_override_names == external_node_output_names
+
+ for stitch in self.external_node.stitches_from:
+ if isinstance(stitch.source_descriptor, InputDescriptor):
+ orig_input_override = input_overrides[stitch.source_descriptor.input_name]
+ input_override = stitch.source_descriptor.input_adapter(orig_input_override)
+ values_from_node[stitch.source_descriptor] = input_override
+ elif isinstance(stitch.source_descriptor, OutputDescriptor):
+ orig_output_override = output_overrides[
+ stitch.source_descriptor.output_name
+ ]
+ output_override = stitch.source_descriptor.output_adapter(
+ orig_output_override
+ )
+ values_from_node[stitch.source_descriptor] = output_override
+ else:
+ raise RuntimeError("Shouldn't happen")
+
+ else:
+ if len(values_to_node) < len(node.stitches_to):
+ nodes_stack.append(node)
+ unresolved_count += 1
+ if unresolved_count >= len(nodes_stack):
+ raise CantResolveNodeDependenciesException(
+ "Can't resolve nodes dependencies"
+ )
+ continue
+
+ if isinstance(node.target, ConstantTarget):
+ assert len(values_to_node) == 0
+
+ output_value = node.target.value
+
+ for stitch in node.stitches_from:
+ assert isinstance(stitch.source_descriptor, OutputDescriptor)
+ assert stitch.source_descriptor.output_name == ""
+ value = stitch.source_descriptor.output_adapter(output_value)
+ values_from_node[stitch.source_descriptor] = value
+
+ elif isinstance(node.target, FunctionTarget):
+ assert all(
+ isinstance(v, InputDescriptor) and v.input_name == ""
+ for v in values_to_node
+ )
+
+ function_input_overrides = self.create_input_overrides(values_to_node)[""]
+
+ if isinstance(function_input_overrides, InputArgs):
+ input_args = function_input_overrides
+ else:
+ input_args = function_input_overrides(InputArgs(), None, None)
+
+ function_output = node.target.function(*input_args.args, **input_args.kwargs)
+
+ for stitch in node.stitches_from:
+ assert isinstance(stitch.source_descriptor, OutputDescriptor)
+ assert stitch.source_descriptor.output_name == ""
+ value = stitch.source_descriptor.output_adapter(function_output)
+ values_from_node[stitch.source_descriptor] = value
+
+ elif isinstance(node.target, ModuleTarget):
+ passage = self.node_passages[node]
+ passage.input_overrides = self.create_input_overrides(values_to_node)
+ passage.output_overrides = self.create_output_overrides(values_to_node)
+ passage_output = passage(*args, **kwargs)
+
+ for stitch in node.stitches_from:
+ if isinstance(stitch.source_descriptor, InputDescriptor):
+ captured_input = passage_output.captured_inputs[
+ stitch.source_descriptor.input_name
+ ]
+ value = stitch.source_descriptor.input_adapter(captured_input)
+ values_from_node[stitch.source_descriptor] = value
+ elif isinstance(stitch.source_descriptor, OutputDescriptor):
+ captured_output = passage_output.captured_outputs[
+ stitch.source_descriptor.output_name
+ ]
+ value = stitch.source_descriptor.output_adapter(captured_output)
+ values_from_node[stitch.source_descriptor] = value
+ else:
+ raise RuntimeError("Shouldn't happen")
+
+ elif isinstance(node.target, RemoteTarget):
+ assert all(
+ isinstance(v, OutputDescriptor) and v.output_name != ""
+ for v in values_from_node
+ )
+ assert all(
+ isinstance(v, OutputDescriptor) and v.output_name != ""
+ for v in values_to_node
+ )
+
+ process_group = node.target.process_group
+ peers = node.target.peer_rank
+ if not isinstance(peers, Sequence):
+ peers = [peers]
+
+ if len(values_to_node) > 0:
+ items_to_send = list(self.create_output_overrides(values_to_node).items())
+
+ data_descriptors: list[RemoteDataDescriptor] = []
+ tensors_to_send: list[torch.Tensor] = []
+
+ for key, value in items_to_send:
+ if isinstance(value, torch.Tensor):
+ if value.is_cuda:
+ tensor_device = "cuda"
+ elif value.is_cpu:
+ tensor_device = "cpu"
+ else:
+ raise RuntimeError(
+ f"Invalid tensor device to send to remote target: {value.device}"
+ )
+
+ data_descriptor = RemoteTensorDataDescriptor(
+ key=key,
+ device=tensor_device,
+ dtype=value.dtype,
+ shape=value.shape,
+ )
+ tensors_to_send.append(value)
+
+ else:
+ data_descriptor = RemotePythonDataDescriptor(
+ key=key,
+ value=value,
+ )
+
+ data_descriptors.append(data_descriptor)
+
+ works: list[Optional[torch.distributed.Work]] = []
+ for peer in peers:
+ if process_group is not None:
+ peer = torch.distributed.get_global_rank(process_group, peer)
+
+ peer_works = distributed_isend_obj(data_descriptors, dst=peer)
+ works.extend(peer_works)
+
+ for tensor in tensors_to_send:
+ work = torch.distributed.isend(tensor, dst=peer)
+ works.append(work)
+
+ if node.target.blocking:
+ for work in works:
+ if work is not None:
+ work.wait()
+
+ if len(node.stitches_from) > 0:
+ assert len(peers) == 1, (
+ f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})"
+ )
+ (peer,) = peers
+
+ if process_group is not None:
+ peer = torch.distributed.get_global_rank(process_group, peer)
+
+ data_descriptors = distributed_recv_obj(src=peer)
+ assert isinstance(data_descriptors, list)
+
+ tensors_to_recv: list[torch.Tensor] = []
+ received_values: dict[str, Any] = {}
+ for data_descriptor in data_descriptors:
+ if isinstance(data_descriptor, RemoteTensorDataDescriptor):
+ tensor = torch.empty(
+ data_descriptor.shape,
+ dtype=data_descriptor.dtype,
+ device=data_descriptor.device,
+ )
+ tensors_to_recv.append(tensor)
+ received_values[data_descriptor.key] = tensor
+ elif isinstance(data_descriptor, RemotePythonDataDescriptor):
+ received_values[data_descriptor.key] = data_descriptor.value
+ else:
+ raise RuntimeError(
+ f"Received invalid data descriptor from remote peer: {data_descriptor}"
+ )
+
+ works: list[Optional[torch.distributed.Work]] = []
+ for tensor in tensors_to_recv:
+ work = torch.distributed.irecv(tensor, src=peer)
+ works.append(work)
+
+ for work in works:
+ if work is not None:
+ work.wait()
+
+ for stitch in node.stitches_from:
+ if isinstance(stitch.source_descriptor, OutputDescriptor):
+ remote_output = received_values[
+ stitch.source_descriptor.output_name
+ ]
+ value = stitch.source_descriptor.output_adapter(remote_output)
+ values_from_node[stitch.source_descriptor] = value
+ else:
+ raise RuntimeError("Shouldn't happen")
+ else:
+ raise RuntimeError("Shouldn't happen")
+
+ for stitch in node.stitches_from:
+ dst_node = self.nodes[stitch.destination_descriptor.target]
+ value = values_from_node[stitch.source_descriptor]
+
+ if isinstance(stitch.destination_descriptor, InputDescriptor):
+ value = stitch.destination_descriptor.input_adapter(value)
+ elif isinstance(stitch.destination_descriptor, OutputDescriptor):
+ value = stitch.destination_descriptor.output_adapter(value)
+ else:
+ raise RuntimeError("Shouldn't happen")
+
+ self.values_to_node[dst_node][stitch.destination_descriptor] = value
+
+ unresolved_count = 0
+
+ values_to_external_node = (
+ {} if self.external_node is None else self.values_to_node[self.external_node]
+ )
+ output = StitchedModuleOutput(
+ captured_inputs={
+ k.input_name: v
+ for k, v in values_to_external_node.items()
+ if isinstance(k, InputDescriptor)
+ },
+ captured_outputs={
+ k.output_name: v
+ for k, v in values_to_external_node.items()
+ if isinstance(k, OutputDescriptor)
+ },
+ )
+
+ self.values_from_node.clear()
+ self.values_to_node.clear()
+
+ return output
+
+
+class KnotException(Exception):
+ pass
+
+
+class LoopFoundException(KnotException):
+ pass
+
+
+class InputsLoopFoundException(LoopFoundException):
+ pass
+
+
+class OutputsLoopFoundException(LoopFoundException):
+ pass
+
+
+class MultipleExternalNodesException(KnotException):
+ pass
+
+
+class OnlyInternalNodesException(KnotException):
+ pass
+
+
+class Needle:
+ def __init__(self) -> None:
+ self.nodes = dict[Target, Node]()
+
+ def get_node_for_target(self, target: Target) -> Node:
+ if target not in self.nodes:
+ node = Node(target=target)
+ self.nodes[target] = node
+ else:
+ node = self.nodes[target]
+
+ return node
+
+ def stitch(self, src: IODescriptor, dst: IODescriptor) -> Self:
+ descriptor = StitchDescriptor(source_descriptor=src, destination_descriptor=dst)
+
+ src_node = self.get_node_for_target(descriptor.source_descriptor.target)
+ dst_node = self.get_node_for_target(descriptor.destination_descriptor.target)
+
+ if descriptor not in src_node.stitches_from:
+ src_node.stitches_from.append(descriptor)
+
+ if descriptor not in dst_node.stitches_to:
+ dst_node.stitches_to.append(descriptor)
+
+ return self
+
+ def _search_loops(
+ self,
+ node: Node,
+ expand_fn: Callable[[Node], Iterable[IODescriptor]],
+ traversed_nodes: Optional[set[Node]] = None,
+ ) -> bool:
+ if isinstance(node.target, ExternalTarget):
+ return False
+
+ if traversed_nodes is None:
+ traversed_nodes = set()
+
+ if node in traversed_nodes:
+ found_loop = True
+ else:
+ traversed_nodes = traversed_nodes | {node}
+ found_loop = False
+ descriptors = expand_fn(node)
+ for descriptor in descriptors:
+ stitch_node = self.get_node_for_target(descriptor.target)
+ found_loop |= self._search_loops(stitch_node, expand_fn, traversed_nodes)
+
+ return found_loop
+
+ def _validate_nodes(self):
+ # internal_nodes = [n for n in self.nodes.values() if not isinstance(n.target, (ExternalTarget, RemoteTarget))]
+ external_nodes = [n for n in self.nodes.values() if isinstance(n.target, ExternalTarget)]
+ remote_nodes = [n for n in self.nodes.values() if isinstance(n.target, RemoteTarget)]
+
+ if len(external_nodes) + len(remote_nodes) == 0:
+ raise OnlyInternalNodesException(f"Has only internal nodes")
+
+ if len(external_nodes) > 1:
+ raise MultipleExternalNodesException(
+ f"Expected no more than 1 external node, found {len(external_nodes)}"
+ )
+
+ for i, node in enumerate(self.nodes.values()):
+ found_inputs_loop = self._search_loops(
+ node, lambda n: [s.source_descriptor for s in n.stitches_to]
+ )
+ if found_inputs_loop:
+ raise InputsLoopFoundException(f"Found a loop in inputs of node {i}: {node}")
+
+ found_outputs_loop = self._search_loops(
+ node, lambda n: [s.destination_descriptor for s in n.stitches_from]
+ )
+ if found_outputs_loop:
+ raise OutputsLoopFoundException(f"Found a loop in outputs of node {i}: {node}")
+
+ def knot(
+ self,
+ capture_cache_outputs_predicate=always_false_predicate,
+ early_exit=True,
+ ignore_extra_overrides=False,
+ ) -> StitchedModule:
+ self._validate_nodes()
+
+ module = StitchedModule(
+ nodes=self.nodes,
+ capture_cache_outputs_predicate=capture_cache_outputs_predicate,
+ early_exit=early_exit,
+ ignore_extra_overrides=ignore_extra_overrides,
+ )
+
+ return module
diff --git a/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py b/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py
new file mode 100644
index 0000000000..02f4c19eac
--- /dev/null
+++ b/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .core import (
+ InputArgs,
+ OutputValue,
+ Passage,
+ PassageInputAdapter,
+ PassageInputOverrides,
+ PassageOutput,
+ PassageOutputAdapter,
+ PassageOutputOverrides,
+ Predicate,
+ always_false_predicate,
+ always_true_predicate,
+)
diff --git a/modelopt/torch/puzzletron/sewing_kit/passage/core.py b/modelopt/torch/puzzletron/sewing_kit/passage/core.py
new file mode 100644
index 0000000000..c0fcb4b123
--- /dev/null
+++ b/modelopt/torch/puzzletron/sewing_kit/passage/core.py
@@ -0,0 +1,453 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# mypy: ignore-errors
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from typing import Any, ContextManager, Iterable, Mapping, Optional, Union
+
+import torch.nn as nn
+from typing_extensions import override
+
+from ..utils import (
+ ActivityContext,
+ dynamo_skip,
+ fake_tensors,
+ has_fake_tensor,
+ is_submodule_of,
+ is_submodule_or_same,
+ real_tensors,
+)
+
+
+@dataclass
+class InputArgs:
+ """Container for input arguments to modules."""
+
+ args: list[Any]
+ kwargs: dict[str, Any]
+
+ def __init__(self, *args, **kwargs):
+ self.args = list(args)
+ self.kwargs = dict(kwargs)
+
+ def __add__(self, other: Any) -> InputArgs:
+ assert isinstance(other, InputArgs)
+ result = InputArgs(*self.args, *other.args, **{**self.kwargs, **other.kwargs})
+ return result
+
+ def drop_args(self, index: int | slice | None = None) -> InputArgs:
+ new_args = InputArgs(*self.args, **self.kwargs)
+ if index is None:
+ new_args.args.clear()
+ else:
+ del new_args.args[index]
+
+ return new_args
+
+ def drop_kwargs(self, keys: Sequence[str] | None = None) -> InputArgs:
+ new_args = InputArgs(*self.args, **self.kwargs)
+ if keys is None:
+ new_args.kwargs.clear()
+ else:
+ for key in keys:
+ new_args.kwargs.pop(key, None)
+
+ return new_args
+
+ @classmethod
+ def from_value(cls, v):
+ if isinstance(v, cls):
+ return v
+ elif isinstance(v, InputArgs):
+ return cls(*v.args, **v.kwargs)
+ elif isinstance(v, Sequence):
+ return cls(*v)
+ else:
+ return cls(v)
+
+
+OutputValue = Any
+
+
+@dataclass
+class PassageInputAdapter:
+ adapter_fn: Callable[[InputArgs, Optional[str], Optional[nn.Module]], InputArgs]
+
+ def __call__(
+ self, original_input: InputArgs, module_name: Optional[str], module: Optional[nn.Module]
+ ) -> InputArgs:
+ result = self.adapter_fn(original_input, module_name, module)
+ return result
+
+
+@dataclass
+class PassageOutputAdapter:
+ adapter_fn: Callable[[Any, Optional[str], Optional[nn.Module]], Any]
+
+ def __call__(
+ self, original_output: Any, module_name: Optional[str], module: Optional[nn.Module]
+ ) -> Any:
+ result = self.adapter_fn(original_output, module_name, module)
+ return result
+
+
+class PassageInputOverrides(dict[str, Union[PassageInputAdapter, InputArgs]]):
+ def __init__(self, input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}):
+ for k, v in input_overrides.items():
+ self[k] = v
+
+ # def __setitem__(self, key: str, value: InputAdapter | InputArgs) -> None:
+ # if isinstance(key, InputArgs):
+ # def adapter_fn(original_input: InputArgs) -> InputArgs:
+ # assert isinstance(value, InputArgs)
+ # return value
+ # self[key] = InputAdapter(adapter_fn)
+ # else:
+ # self[key] = value
+
+
+class PassageOutputOverrides(dict[str, Union[PassageOutputAdapter, Any]]):
+ def __init__(self, output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}):
+ for k, v in output_overrides.items():
+ self[k] = v
+
+
+class NoActivePassageContextError(RuntimeError):
+ pass
+
+
+class RequiredPassageOutputsCapturedSignal(Exception):
+ pass
+
+
+@dataclass
+class PassageOutput:
+ captured_inputs: dict[str, InputArgs]
+ captured_outputs: dict[str, Any]
+ captured_fake_outputs: dict[str, Any]
+ module_output: Any
+
+
+Predicate = Callable[[str, nn.Module], bool]
+
+
+def always_false_predicate(module_name: str, module: nn.Module) -> bool:
+ return False
+
+
+def always_true_predicate(module_name: str, module: nn.Module) -> bool:
+ return True
+
+
+class Passage(nn.Module):
+ create_fn_context = ActivityContext[None](max_depth=1)
+ active_passages_context = ActivityContext["Passage"](no_duplicates=True, reversed=True)
+
+ def __init__(
+ self,
+ module: nn.Module,
+ *,
+ inputs_to_capture: Iterable[str] = [],
+ outputs_to_capture: Iterable[str] = [],
+ input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {},
+ output_overrides: Mapping[str, PassageOutputAdapter | Any] = {},
+ outputs_cache: dict[str, Any] = {},
+ capture_fake_outputs_predicate: Predicate = always_false_predicate,
+ capture_cache_outputs_predicate: Predicate = always_false_predicate,
+ early_exit: bool = False,
+ name: Optional[str] = None,
+ ):
+ super().__init__()
+
+ if not self.create_fn_context.is_active():
+ raise RuntimeError("Please use Passage.create(...) in order to create a new Passage")
+
+ self.active_context_manager: Optional[ContextManager] = None
+
+ self.name = name
+ self.module = module
+ self.module_to_name_mapping = {id(v): k for k, v in module.named_modules()}
+ self.inputs_to_capture = set(inputs_to_capture)
+ self.outputs_to_capture = set(outputs_to_capture)
+ self.input_overrides = input_overrides
+ self.output_overrides = output_overrides
+ self.outputs_cache = outputs_cache
+ self.capture_fake_outputs_predicate = capture_fake_outputs_predicate
+ self.capture_cache_outputs_predicate = capture_cache_outputs_predicate
+ self.early_exit = early_exit
+
+ self.reset()
+
+ @property
+ def input_overrides(self) -> PassageInputOverrides:
+ return self._input_overrides
+
+ @input_overrides.setter
+ def input_overrides(self, value: Mapping[str, PassageInputAdapter | InputArgs]):
+ self._input_overrides = PassageInputOverrides(value)
+
+ @property
+ def output_overrides(self) -> PassageOutputOverrides:
+ return self._output_overrides
+
+ @output_overrides.setter
+ def output_overrides(self, value: Mapping[str, PassageOutputAdapter | Any]):
+ self._output_overrides = PassageOutputOverrides(value)
+
+ def reset(self):
+ self.required_capture_count = (
+ (len(self.inputs_to_capture) + len(self.outputs_to_capture))
+ if self.early_exit
+ else None
+ )
+ self.captured_outputs: dict[str, Any] = {}
+ self.captured_inputs: dict[str, InputArgs] = {}
+ self.captured_fake_outputs: dict[str, Any] = {}
+
+ @classmethod
+ def module_name_relative_to_active_passage(cls, module: PatchedModule) -> str:
+ root_passage = Passage.active_passages_context.get_active()
+ assert root_passage is not None
+ module_name = root_passage.module_to_name_mapping[id(module)]
+ return module_name
+
+ @classmethod
+ def create(
+ cls,
+ module: nn.Module,
+ *,
+ inputs_to_capture: Iterable[str] = [],
+ outputs_to_capture: Iterable[str] = [],
+ input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {},
+ output_overrides: Mapping[str, PassageOutputAdapter | Any] = {},
+ outputs_cache: dict[str, Any] = {},
+ capture_fake_outputs_predicate: Predicate = always_false_predicate,
+ capture_cache_outputs_predicate: Predicate = always_false_predicate,
+ early_exit: bool = False,
+ name: Optional[str] = None,
+ ) -> Passage:
+ with cls.create_fn_context(None):
+ passage = cls(
+ module=module,
+ inputs_to_capture=inputs_to_capture,
+ outputs_to_capture=outputs_to_capture,
+ input_overrides=input_overrides,
+ output_overrides=output_overrides,
+ outputs_cache=outputs_cache,
+ capture_fake_outputs_predicate=capture_fake_outputs_predicate,
+ capture_cache_outputs_predicate=capture_cache_outputs_predicate,
+ early_exit=early_exit,
+ name=name,
+ )
+
+ for submodule_name, submodule in module.named_modules(remove_duplicate=False):
+ patch_module(submodule_name, submodule)
+
+ # register_passage_hooks(module, descriptor)
+
+ return passage
+
+ def is_active(self) -> bool:
+ result = self.active_context_manager is not None
+ return result
+
+ def __enter__(self):
+ assert self.active_context_manager is None
+ self.active_context_manager = Passage.active_passages_context(self)
+ self.active_context_manager.__enter__()
+ self.module_to_name_mapping = {id(v): k for k, v in self.named_modules()}
+
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert self.active_context_manager is not None
+ self.active_context_manager.__exit__(exc_type, exc_val, exc_tb)
+
+ def freeze(self):
+ self.eval()
+ self.requires_grad_(False)
+
+ def unfreeze(self):
+ self.train()
+ self.requires_grad_(True)
+
+ def run(self, *args, **kwargs) -> PassageOutput:
+ return self(*args, **kwargs)
+
+ @override
+ def __call__(self, *args, **kwargs) -> PassageOutput:
+ return super().__call__(*args, **kwargs)
+
+ @dynamo_skip
+ @override
+ def forward(self, *args, **kwargs) -> PassageOutput:
+ self.reset()
+
+ with Passage.active_passages_context(self):
+ try:
+ module_output = self.module(*args, **kwargs)
+ except RequiredPassageOutputsCapturedSignal:
+ module_output = None
+
+ output = PassageOutput(
+ captured_inputs=self.captured_inputs,
+ captured_outputs=self.captured_outputs,
+ captured_fake_outputs=self.captured_fake_outputs,
+ module_output=module_output,
+ )
+
+ self.reset()
+
+ return output
+
+
+class PatchedModule: ...
+
+
+def patch_module(module_name_: str, module: nn.Module):
+ # orig_forward = module.forward
+
+ if isinstance(module, PatchedModule):
+ # if module_name != Passage.module_name_relative_to_active_passage(module):
+ # logger.warn(f'Module "{module_name}" already patched for module "{Passage.module_name_relative_to_active_passage(module)}". Could lead to bugs.')
+ return
+
+ orig_class = module.__class__
+
+ class PassageModuleWrapper(orig_class, PatchedModule):
+ # Defined as a static method to avoid potential collision with original class methods
+ @staticmethod
+ @dynamo_skip
+ def can_be_skipped(_self: PassageModuleWrapper, depth: int) -> bool:
+ passages_beyond_depth = Passage.active_passages_context[depth:]
+ module_name = Passage.module_name_relative_to_active_passage(_self)
+
+ results = [
+ (
+ module_name in passage.outputs_cache
+ and not any(
+ is_submodule_or_same(k, module_name) for k in passage.outputs_to_capture
+ )
+ and not any(
+ is_submodule_of(k, module_name)
+ for k, v in passage.input_overrides.items()
+ if v is not None
+ )
+ and not any(
+ is_submodule_of(k, module_name)
+ for k, v in passage.output_overrides.items()
+ if v is not None
+ )
+ )
+ for passage in passages_beyond_depth
+ ]
+
+ result = all(results)
+
+ return result
+
+ # Defined as a static method to avoid potential collision with original class methods
+ @staticmethod
+ @dynamo_skip
+ def run_passage(_self: PassageModuleWrapper, depth: int, args, kwargs):
+ if depth + 1 > len(Passage.active_passages_context):
+ output = super(PassageModuleWrapper, _self).__call__(*args, **kwargs)
+ return output
+
+ module_name = Passage.module_name_relative_to_active_passage(_self)
+ passage = Passage.active_passages_context[depth]
+
+ has_output_override = module_name in passage.output_overrides
+ output_override = passage.output_overrides.get(module_name)
+
+ if has_output_override and not isinstance(output_override, PassageOutputAdapter):
+ output = output_override
+ else:
+ input_override = passage.input_overrides.get(module_name)
+ if input_override is not None:
+ original_input_args = InputArgs(*args, **kwargs)
+
+ if isinstance(input_override, PassageInputAdapter):
+ new_input_args = input_override(original_input_args, module_name, module)
+ else:
+ new_input_args = input_override
+
+ args, kwargs = new_input_args.args, new_input_args.kwargs
+
+ if (
+ output_override is None
+ and PassageModuleWrapper.can_be_skipped(_self, depth)
+ and (has_fake_tensor(args) or has_fake_tensor(kwargs))
+ ):
+ cached_output = passage.outputs_cache[module_name]
+ return cached_output
+
+ output = PassageModuleWrapper.run_passage(
+ _self=_self,
+ depth=depth + 1,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ if isinstance(output_override, PassageOutputAdapter):
+ output = output_override(output, module_name, module)
+
+ if passage.capture_fake_outputs_predicate(module_name, module):
+ fake_output = fake_tensors(output)
+ passage.captured_fake_outputs[module_name] = fake_output
+
+ if not module_name in passage.outputs_cache and passage.capture_cache_outputs_predicate(
+ module_name, module
+ ):
+ fake_output = fake_tensors(output)
+ passage.outputs_cache[module_name] = fake_output
+
+ if module_name in passage.inputs_to_capture:
+ real_args, real_kwargs = real_tensors(args), real_tensors(kwargs)
+ passage.captured_inputs[module_name] = InputArgs(*real_args, **real_kwargs)
+
+ if passage.required_capture_count is not None:
+ passage.required_capture_count -= 1
+
+ if module_name in passage.outputs_to_capture:
+ real_output = real_tensors(output)
+ output_value = real_output
+ passage.captured_outputs[module_name] = output_value
+
+ if passage.required_capture_count is not None:
+ passage.required_capture_count -= 1
+
+ if passage.required_capture_count == 0:
+ raise RequiredPassageOutputsCapturedSignal()
+
+ return output
+
+ @dynamo_skip
+ @override
+ def __call__(self, *args, **kwargs):
+ output = self.run_passage(
+ _self=self,
+ depth=0,
+ args=args,
+ kwargs=kwargs,
+ )
+ return output
+
+ # module.forward = forward
+ PassageModuleWrapper.__name__ = f"ModuleWrapper({module.__class__.__name__})"
+ module.__class__ = PassageModuleWrapper
diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py
new file mode 100644
index 0000000000..068abef99e
--- /dev/null
+++ b/modelopt/torch/puzzletron/sewing_kit/utils.py
@@ -0,0 +1,434 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+from __future__ import annotations
+
+import inspect
+from contextlib import contextmanager
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ContextManager,
+ Generic,
+ Iterable,
+ Literal,
+ Optional,
+ Protocol,
+ TypeVar,
+ cast,
+ overload,
+)
+
+import torch
+import torch._C
+import torch._dynamo
+import torch.distributed
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils._pytree as pytree
+from torch import Tensor
+from torch._subclasses import FakeTensor, FakeTensorMode
+from typing_extensions import override
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+Fn = TypeVar("Fn", bound=Callable)
+
+
+class DynamoSkip(Protocol):
+ @overload
+ def __call__(self, fn: None = None) -> Callable[[Fn], Fn]: ...
+ @overload
+ def __call__(self, fn: Fn) -> Fn: ...
+
+
+class DynamoDisable(Protocol):
+ @overload
+ def __call__(self, fn: None = None, disable: bool = False) -> Callable[[Fn], Fn]: ...
+ @overload
+ def __call__(self, fn: Fn, disable: bool = False) -> Fn: ...
+
+
+try:
+ dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.decorators).skip
+ dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.decorators).disable
+except:
+ dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.eval_frame).skip
+ dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.eval_frame).disable
+
+
+TModule = TypeVar("TModule", bound=nn.Module)
+
+
+class ModuleRef(Generic[TModule]):
+ def __init__(self, module: TModule):
+ self.module = module
+
+
+class ActivityContextMaxDepthException(Exception):
+ pass
+
+
+class ActivityContextDuplicateException(Exception):
+ pass
+
+
+T = TypeVar("T")
+
+
+class ActivityContext(Generic[T]):
+ def __init__(self, max_depth: Optional[int] = None, no_duplicates=False, reversed=False):
+ self.activity_stack: list[T] = []
+ self.max_depth = max_depth
+ self.no_duplicates = no_duplicates
+ self.reversed = reversed
+ self.head_index = 0 if self.reversed else -1
+
+ def __contains__(self, value: T) -> bool:
+ result = value in self.activity_stack
+ return result
+
+ def __call__(self, value: T) -> ContextManager:
+ @contextmanager
+ def fn():
+ try:
+ if self.no_duplicates and value in self.activity_stack:
+ raise ActivityContextDuplicateException(
+ f"Activity stack cannot have a duplicate of item {value}"
+ )
+
+ self.activity_stack.insert(self.head_index, value)
+
+ if self.max_depth is not None and len(self) > self.max_depth:
+ raise ActivityContextMaxDepthException(
+ f"Activity stack exceeds max depth of {self.max_depth}"
+ )
+
+ yield
+ finally:
+ assert self.is_active()
+ self.activity_stack.pop(self.head_index)
+
+ return fn()
+
+ def __len__(self) -> int:
+ result = len(self.activity_stack)
+ return result
+
+ @overload
+ def __getitem__(self, key: int) -> T: ...
+ @overload
+ def __getitem__(self, key: slice) -> Sequence[T]: ...
+ def __getitem__(self, key: int | slice) -> T | Sequence[T]:
+ result = self.activity_stack[key]
+ return result
+
+ def is_active(self) -> bool:
+ result = len(self) > 0
+ return result
+
+ def get_active(self) -> Optional[T]:
+ if self.is_active:
+ return self.activity_stack[-1]
+ else:
+ return None
+
+
+def is_submodule_of(module_name: str, other_module_name: str) -> bool:
+ result = module_name.startswith(f"{other_module_name}.") or (
+ module_name != "" and other_module_name == ""
+ )
+ return result
+
+
+def is_submodule_or_same(module_name: str, other_module_name: str) -> bool:
+ result = module_name == other_module_name or is_submodule_of(module_name, other_module_name)
+ return result
+
+
+fake_mode = FakeTensorMode(
+ allow_non_fake_inputs=True,
+ # allow_fallback_kernels=False,
+)
+
+
+@overload
+def fake_tensor(t: Tensor, *, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ...
+
+
+@overload
+def fake_tensor(
+ size: Sequence[int] | torch.Size, *, dtype: Optional[torch.dtype] = None, use_meta=False
+) -> Tensor: ...
+
+
+@overload
+def fake_tensor(*args: int, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ...
+
+
+class MyFakeTensor(Tensor):
+ @dynamo_disable
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self._t: FakeTensor
+
+ @override
+ @dynamo_disable
+ def __repr__(self, *, tensor_contents=None):
+ return f"MyFakeTensor(shape={list(self._t.shape)}, dtype={self._t.dtype}, device={self._t.device})"
+
+ @classmethod
+ @override
+ @dynamo_disable
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ args, kwargs = pytree.tree_map_only(MyFakeTensor, lambda t: t._t, (args, kwargs))
+
+ types = pytree.tree_map_only(type(MyFakeTensor), lambda t: FakeTensor, types)
+
+ out = func(*args, **kwargs)
+
+ out = pytree.tree_map_only(Tensor, lambda t: MyFakeTensor.create(t), out)
+
+ return out
+
+ __torch_function__ = torch._C._disabled_torch_function_impl
+
+ # @dynamo_disable
+ # def __getattribute__(self, attr: str):
+ # if attr in {'_t', 'device', '__repr__', '__torch_function__', '__class__'}:
+ # return object.__getattribute__(self, attr)
+
+ # result = getattr(self._t, attr)
+
+ # result = pytree.tree_map_only(
+ # Tensor, lambda t: MyFakeTensor.create(t), result
+ # )
+ # print('__getattribute__', 'attr', attr, 'ret', result)
+
+ # return result
+
+ @property
+ @dynamo_disable
+ def device(self):
+ return self._t.device
+
+ # @property
+ # @dynamo_disable
+ # def shape(self):
+ # return self._t.shape
+
+ # @dynamo_disable
+ # def size(self):
+ # return self._t.size()
+
+ # @classmethod
+ # @dynamo_disable
+ # def __torch_function__(cls, func, types, args=(), kwargs=None):
+ # if kwargs is None:
+ # kwargs = {}
+
+ # args, kwargs = pytree.tree_map_only(
+ # MyFakeTensor, lambda t: t._t, (args, kwargs)
+ # )
+
+ # ret = func(*args, **kwargs)
+
+ # ret = pytree.tree_map_only(
+ # Tensor, lambda t: MyFakeTensor.create(t), ret
+ # )
+ # print('__torch_function__', 'func', func, 'ret', ret)
+
+ # return ret
+
+ @staticmethod
+ @dynamo_disable
+ def __new__(cls, elem, device) -> MyFakeTensor:
+ self = torch.Tensor._make_subclass(
+ cls,
+ elem,
+ elem.requires_grad,
+ dispatch_device=True,
+ device_for_backend_keys=device,
+ )
+ return cast("MyFakeTensor", self)
+
+ @classmethod
+ @dynamo_disable
+ def create(cls, data: Tensor) -> MyFakeTensor:
+ if isinstance(data, MyFakeTensor):
+ return data
+
+ if isinstance(data, FakeTensor):
+ t = data
+ else:
+ t = FakeTensor.from_tensor(data, fake_mode=fake_mode)
+
+ # my_fake_tensor = MyFakeTensor(torch.empty(t.shape, dtype=t.dtype, device='meta'))
+ my_fake_tensor = MyFakeTensor(
+ torch.empty(t.shape, dtype=t.dtype, device="meta"),
+ t.device,
+ )
+ my_fake_tensor._t = t
+
+ return my_fake_tensor
+
+
+@dynamo_disable
+def fake_tensor(*args, **kwargs) -> Tensor:
+ dtype: Optional[torch.dtype] = kwargs.get("dtype")
+ use_meta = kwargs.get("use_meta", False)
+ device = kwargs.get("device", "meta")
+
+ if len(args) == 1 and isinstance(args[0], Tensor):
+ if use_meta:
+ fake_tensor = torch.empty(args[0].size(), dtype=dtype or args[0].dtype, device="meta")
+ else:
+ fake_tensor = MyFakeTensor.create(args[0])
+ else:
+ fake_tensor = torch.empty(*args, dtype=dtype, device=device)
+ if not use_meta:
+ fake_tensor = MyFakeTensor.create(fake_tensor)
+
+ return fake_tensor
+
+
+@dynamo_skip
+def fake_tensor_like(t: Tensor, use_meta=False) -> Tensor:
+ return fake_tensor(t, use_meta=use_meta)
+
+
+T = TypeVar("T")
+
+
+@dynamo_skip
+def fake_tensors(value: T, use_meta=False) -> T:
+ result = pytree.tree_map_only(Tensor, lambda t: fake_tensor_like(t, use_meta), value)
+ return result
+ # if isinstance(value, Mapping):
+ # return cast(Any, value.__class__)({k: fake_tensors(v, use_meta) for k, v in value.items()})
+ # if isinstance(value, Sequence):
+ # return cast(Any, value.__class__)([fake_tensors(v, use_meta) for v in value])
+ # if isinstance(value, Tensor):
+ # return fake_tensor_like(value, use_meta)
+ # return value
+
+
+@dynamo_skip
+def real_tensors(value: Any) -> Any:
+ result = pytree.tree_map_only(Tensor, lambda t: None if is_fake_tensor(t) else t, value)
+ return result
+ # if isinstance(value, Mapping):
+ # return cast(Any, value.__class__)({k: real_tensors(v) for k, v in value.items()})
+ # if isinstance(value, Sequence):
+ # return cast(Any, value.__class__)([real_tensors(v) for v in value])
+ # if is_fake_tensor(value):
+ # return None
+ # return value
+
+
+@dynamo_skip
+def is_fake_tensor(t: Any) -> bool:
+ return isinstance(t, (MyFakeTensor, FakeTensor)) or (isinstance(t, Tensor) and t.is_meta)
+
+
+@dynamo_skip
+def has_fake_tensor(v: Any) -> bool:
+ result = pytree.tree_any(is_fake_tensor, v)
+ return result
+
+
+def _get_device_for_distributed(
+ group: Optional[torch.distributed.ProcessGroup] = None,
+) -> str:
+ """
+ Determine the appropriate device for distributed communication based on the backend.
+ NCCL backend requires CUDA tensors, while Gloo supports both CPU and CUDA.
+ """
+ if not torch.distributed.is_initialized():
+ return "cpu"
+
+ backend = torch.distributed.get_backend(group)
+ if backend == "nccl":
+ # NCCL requires CUDA tensors
+ return torch.cuda.current_device()
+ else:
+ # Gloo and other backends support CPU tensors
+ return "cpu"
+
+
+def distributed_isend_obj(
+ obj: Any,
+ dst: int = 0,
+ group: Optional[torch.distributed.ProcessGroup] = None,
+) -> list[Optional[torch.distributed.Work]]:
+ device = _get_device_for_distributed(group)
+ obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor(
+ obj, device=device, **_get_group_kwarg_if_necessary()
+ )
+ works: list[Optional[torch.distributed.Work]] = [
+ torch.distributed.isend(obj_size_tensor, dst, group),
+ torch.distributed.isend(obj_tensor, dst, group),
+ ]
+ # p2p_ops = [
+ # torch.distributed.P2POp(torch.distributed.isend, obj_size_tensor, dst, group),
+ # torch.distributed.P2POp(torch.distributed.isend, obj_tensor, dst, group),
+ # ]
+
+ # works = torch.distributed.batch_isend_irecv(p2p_ops)
+
+ return works
+
+
+def distributed_send_obj(
+ obj: Any,
+ dst: int = 0,
+ group: Optional[torch.distributed.ProcessGroup] = None,
+):
+ works = distributed_isend_obj(obj=obj, dst=dst, group=group)
+ for work in works:
+ if work is not None:
+ work.wait()
+
+
+def distributed_recv_obj(
+ src: Optional[int] = None,
+ group: Optional[torch.distributed.ProcessGroup] = None,
+) -> Any:
+ device = _get_device_for_distributed(group)
+ obj_size_tensor = torch.LongTensor(1).to(device)
+ torch.distributed.recv(obj_size_tensor, src=src, group=group)
+ obj_size = int(obj_size_tensor.item())
+
+ obj_tensor = torch.ByteTensor(obj_size).to(device)
+ torch.distributed.recv(obj_tensor, src=src, group=group)
+
+ obj = torch.distributed.distributed_c10d._tensor_to_object(
+ obj_tensor, obj_size, **_get_group_kwarg_if_necessary()
+ )
+
+ return obj
+
+
+def _get_group_kwarg_if_necessary() -> dict:
+ """For newer versions of torch"""
+ arg_names = inspect.signature(
+ torch.distributed.distributed_c10d._object_to_tensor
+ ).parameters.keys()
+ return dict(group=None) if "group" in arg_names else dict()
diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py
new file mode 100644
index 0000000000..3ea57bd7a7
--- /dev/null
+++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py
@@ -0,0 +1,339 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Calculate memory usage and parameter counts for neural network subblocks.
+
+This module provides utilities to compute memory footprints and parameter counts
+for different subblock types (FFN, Attention, Mamba, MoE) in large language models,
+considering various data types, batch sizes, and sequence lengths.
+"""
+
+import copy
+import json
+import math
+from pathlib import Path
+from typing import Type
+
+import numpy as np
+import torch
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ MambaConfig,
+ maybe_cast_block_configs,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import init_model_from_config
+from modelopt.torch.puzzletron.utils.utils import (
+ EmptyInitOnDevice,
+ calculate_kv_dim,
+ raise_unknown_subblock_config_error,
+ sizeof_dtype,
+)
+
+
+def calculate_subblock_memory(
+ subblock_config: FFNConfig | AttentionConfig,
+ batch_size: int,
+ prefill_seq_len: int,
+ generation_seq_len: int,
+ prefill_queue_size: int,
+ n_embd: int,
+ n_head: int,
+ weights_dtype: torch.dtype,
+ kv_cache_dtype: torch.dtype,
+ allocate_prefill_query: bool,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+) -> float | dict[str, float]:
+ """``model_config`` / ``descriptor`` are required (puzzletron-style); FFN uses them for meta init."""
+ if subblock_config.no_op:
+ return 0
+ if isinstance(subblock_config, FFNConfig):
+ return calculate_ffn_memory(
+ subblock_config,
+ model_config,
+ descriptor,
+ weights_dtype,
+ )
+ if isinstance(subblock_config, AttentionConfig):
+ if subblock_config.is_mamba:
+ return calculate_mamba_memory(
+ subblock_config,
+ model_config,
+ descriptor,
+ batch_size,
+ weights_dtype,
+ kv_cache_dtype,
+ )
+ else:
+ return calculate_attention_memory(
+ subblock_config,
+ model_config,
+ descriptor,
+ batch_size,
+ prefill_seq_len,
+ generation_seq_len,
+ prefill_queue_size,
+ n_embd,
+ n_head,
+ weights_dtype,
+ kv_cache_dtype,
+ allocate_prefill_query,
+ )
+ raise_unknown_subblock_config_error(subblock_config)
+
+
+def calculate_subblock_params(
+ config: PretrainedConfig,
+ layer_config: BlockConfig | FFNConfig | AttentionConfig,
+ descriptor: Type[ModelDescriptor],
+) -> int:
+ """Count parameters on one meta decoder layer."""
+ if isinstance(layer_config, FFNConfig):
+ block_config = layer_config.to_blockconfig()
+ elif isinstance(layer_config, AttentionConfig):
+ block_config = layer_config.to_blockconfig()
+ else:
+ block_config = layer_config
+
+ ffn = block_config.ffn
+ attn = block_config.attention
+ ffn_no_op = ffn is None or ffn.no_op
+ attn_no_op = attn is None or attn.no_op
+ if not (ffn_no_op or attn_no_op):
+ raise AssertionError(
+ "One of ffn or attention must be no-op for sublayer param calculation "
+ "(single subblock at a time)."
+ )
+ if ffn_no_op and attn_no_op:
+ return 0
+
+ _config = copy.deepcopy(config)
+ lm_config = descriptor.get_language_model_config(_config)
+ lm_config.num_hidden_layers = 1
+
+ block_configs = maybe_cast_block_configs([block_config])
+ _config.block_configs = block_configs
+ if lm_config is not _config:
+ lm_config.block_configs = block_configs
+
+ # Replaced earlier pattern:
+ # with EmptyInitOnDevice("meta"), deci_x_patcher(..., block_configs=block_configs):
+ # model = init_model_from_config(_config, ...)
+ #
+ # That fails on GPT-OSS with recent Transformers: ``deci_x_patcher`` runs
+ # ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` inside ``DecoderLayer.__init__``, so norms
+ # / attn / mlp are swapped for placeholders before ``GptOssModel.__init__`` finishes. At the end
+ # of ``GptOssModel.__init__`` the stack calls ``self.post_init()`` — inherited from
+ # ``PreTrainedModel`` — which then raises
+ # ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no longer matches
+ # the tree). Below we merge per-layer fields manually, init without the patcher, then call the
+ # same descriptor no-op hooks on the built layer (equivalent param count for
+ # ``num_hidden_layers == 1``).
+
+ # ``block_config_to_layer_overrides`` may include keys with value ``None``; we omit those so
+ # ``lm_config.update`` does not overwrite existing fields with ``None`` (same rule as
+ # ``override_config_with_block_configs`` inside ``deci_x_patcher``).
+ layer_overrides = descriptor.block_config_to_layer_overrides(block_configs[0])
+ lm_config.update({k: v for k, v in layer_overrides.items() if v is not None})
+
+ with EmptyInitOnDevice("meta"):
+ model = init_model_from_config(
+ _config,
+ trust_remote_code=descriptor.requires_trust_remote_code(),
+ )
+
+ decoder_layer = model.get_submodule(descriptor.layer_block_name(index=0))
+ if attn_no_op:
+ descriptor.attn_no_op_post_init(decoder_layer)
+ if ffn_no_op:
+ descriptor.mlp_no_op_post_init(decoder_layer)
+ return sum(p.numel() for p in decoder_layer.parameters())
+
+
+def calc_subblock_active_params(
+ sublayer_config: FFNConfig | AttentionConfig,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+ n_embd: int,
+ moe_stats_file: str,
+ batch_size: int,
+ block_idx: int,
+) -> int:
+ if not (isinstance(sublayer_config, FFNConfig) and sublayer_config.is_moe):
+ return calculate_subblock_params(model_config, sublayer_config, descriptor)
+ return estimate_moe_active_params(
+ sublayer_config, n_embd, moe_stats_file, batch_size, block_idx
+ )
+
+
+def load_moe_stats(stats_file: str) -> dict:
+ with open(stats_file) as f:
+ stats = json.load(f)
+ return [np.array(l) / np.sum(l) if len(l) > 0 else 0 for l in stats]
+
+
+def estimate_num_active_experts(
+ dist_over_experts: np.ndarray, batch_size: int, num_experts: int
+) -> int:
+ # cut the tail and renormalize
+ dist_over_experts = np.sort(dist_over_experts)[::-1][:num_experts]
+ dist_over_experts = dist_over_experts / (dist_over_experts.sum())
+ # calculate the probability of at least one expert being active
+ # (expectation on indicators is the expected number of active experts)
+ return (1 - (1 - dist_over_experts) ** batch_size).sum()
+
+
+def estimate_moe_active_params(
+ subblock_config: FFNConfig,
+ n_embd: int,
+ moe_stats_file: Path | str,
+ batch_size: int,
+ block_idx: int,
+) -> int:
+ assert Path(moe_stats_file).exists()
+ # if not Path(moe_stats_file).exists(): # if path is not provided, should we assume uniform distribution?
+ # return calculate_subblock_params(subblock_config, n_embd, n_head=None)
+ moe_stats = load_moe_stats(moe_stats_file)
+ dist_over_experts = moe_stats[block_idx]
+ num_experts = subblock_config.moe.num_local_experts
+
+ expected_num_active_experts = estimate_num_active_experts(
+ dist_over_experts, batch_size, num_experts
+ )
+ expert_dim = subblock_config.moe.expert_intermediate_dim
+ shared_expert_dim = subblock_config.moe.shared_expert_intermediate_dim
+ num_linear_layers = 3 # all moe experts have 3 linear layers
+
+ router_num_params = n_embd * num_experts
+ expected_num_active_experts_params = (
+ num_linear_layers * expert_dim * n_embd * expected_num_active_experts
+ )
+ shared_expert_num_params = num_linear_layers * shared_expert_dim * n_embd
+
+ expected_total_params = (
+ router_num_params + expected_num_active_experts_params + shared_expert_num_params
+ )
+ return expected_total_params
+
+
+def calculate_attention_memory(
+ attention_config: AttentionConfig,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+ batch_size: int,
+ prefill_seq_len: int,
+ generation_seq_len: int,
+ prefill_queue_size: int,
+ n_embd: int,
+ n_head: int,
+ weights_dtype: torch.dtype,
+ kv_cache_dtype: torch.dtype,
+ allocate_prefill_query: bool,
+) -> dict[str, float]:
+ """allocate_prefill_query: infery-llm style.
+ Infery used a unified Wqkv matrix, so before extracting the kv-cache,
+ the query also had to be kept in-memory, once per layer.
+ """
+ seq_len = prefill_seq_len + generation_seq_len
+ if (
+ attention_config.is_llama4
+ and (attention_chunk_size := attention_config.llama4.attention_chunk_size) is not None
+ ):
+ seq_len = min(seq_len, attention_chunk_size)
+
+ kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd)
+ total_num_tokens = seq_len * (batch_size + prefill_queue_size)
+ kv_cache_size = total_num_tokens * kv_dim
+ query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0
+ num_params = calculate_subblock_params(model_config, attention_config, descriptor)
+ total_memory = (
+ kv_cache_size * sizeof_dtype(kv_cache_dtype)
+ + query_prefill_size * sizeof_dtype(weights_dtype)
+ + num_params * sizeof_dtype(weights_dtype)
+ ) / 2**20
+ kv_cache_memory = kv_cache_size * sizeof_dtype(kv_cache_dtype) / 2**20
+ return {"memory_mib": total_memory, "kv_cache_memory_mib": kv_cache_memory}
+
+
+def calculate_mamba_memory(
+ attention_config: AttentionConfig,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+ batch_size: int,
+ weights_dtype: torch.dtype,
+ kv_cache_dtype: torch.dtype,
+) -> int:
+ assert attention_config.mamba is not None
+ mamba_config = attention_config.mamba
+ num_params = calculate_subblock_params(model_config, attention_config, descriptor)
+ return (
+ num_params * sizeof_dtype(weights_dtype)
+ + calculate_mamba_state_size(mamba_config, batch_size) * sizeof_dtype(kv_cache_dtype)
+ ) / 2**20
+
+
+def calculate_mamba_state_size(
+ mamba_config: MambaConfig,
+ batch_size: int,
+) -> int:
+ d_inner, in_proj_dim, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config)
+ conv_state_size = math.prod((batch_size, conv_dim, kernel_size))
+ ssm_state_size = math.prod(
+ (batch_size, mamba_config.num_heads, mamba_config.head_dim, mamba_config.state_dim)
+ )
+ return conv_state_size + ssm_state_size
+
+
+def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...]:
+ d_inner = mamba_config.num_heads * mamba_config.head_dim
+ in_proj_dim = (
+ d_inner * 2 + 2 * mamba_config.num_groups * mamba_config.state_dim + mamba_config.num_heads
+ )
+ conv_dim = d_inner + 2 * mamba_config.num_groups * mamba_config.state_dim
+ kernel_size = 4
+ return d_inner, in_proj_dim, conv_dim, kernel_size
+
+
+def calculate_ffn_memory(
+ ffn_config: FFNConfig,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+ weights_dtype: torch.dtype | str,
+ experts_dtype: torch.dtype | str | None = None,
+) -> float:
+ # TODO: How to separate between expert weights and the rest for any model (same as puzzletron).
+ num_params = calculate_subblock_params(model_config, ffn_config, descriptor)
+ return num_params * sizeof_dtype(weights_dtype) / 2**20
+
+
+def calculate_non_block_memory(
+ n_embd: int,
+ vocab_size: int,
+ weight_dtype: torch.dtype,
+) -> float:
+ return calculate_non_block_params(n_embd, vocab_size) * sizeof_dtype(weight_dtype) / 2**20
+
+
+def calculate_non_block_params(
+ n_embd: int,
+ vocab_size: int,
+) -> int:
+ return vocab_size * n_embd * 2 + n_embd
diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py
new file mode 100644
index 0000000000..549d994f07
--- /dev/null
+++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py
@@ -0,0 +1,559 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Calc subblock stats to compute memory and runtime statistics for subblocks."""
+
+import dataclasses
+import json
+import os
+import warnings
+from functools import partial
+from itertools import product
+from pathlib import Path
+from typing import Iterable, Optional, Type, TypeVar
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+import pandas as pd
+import torch
+from immutabledict import immutabledict
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from tqdm import tqdm
+from transformers import PretrainedConfig
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+ SubblockConfig,
+)
+from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement
+from modelopt.torch.puzzletron.subblock_stats.calc_subblock_params_and_memory import (
+ calc_subblock_active_params,
+ calculate_non_block_memory,
+ calculate_non_block_params,
+ calculate_subblock_memory,
+ calculate_subblock_params,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.robust_json import json_dump
+from modelopt.torch.puzzletron.utils.parsing import format_global_config
+
+# Type variable for dataclasses
+T_DataClass = TypeVar("T_DataClass")
+
+"""
+Usage:
+python -m modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ]
+
+--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime,
+ only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker.
+
+"""
+
+
+def calculate_subblock_stats(
+ calc_subblock_stats_config: DictConfig,
+ teacher_dir: Path,
+ model_config: PretrainedConfig,
+ descriptor: Type[ModelDescriptor],
+ master_puzzle_dir: Path,
+ subblock_configs: list[immutabledict[str, AttentionConfig | FFNConfig]],
+ batch_size: int,
+ prefill_seq_len: int,
+ generation_seq_len: int,
+ prefill_queue_size: int,
+ n_embd: int,
+ n_head: int,
+ vocab_size: int,
+ benchmark_iterations: Optional[int],
+ use_cuda_graph: bool,
+ weights_dtype: torch.dtype,
+ activations_dtype: torch.dtype,
+ kv_cache_dtype: torch.dtype,
+ allocate_prefill_query: bool,
+ moe_stats_file: str | Path | None = None,
+) -> dict:
+ is_calc_runtime = benchmark_iterations is not None
+ if is_calc_runtime:
+ raise NotImplementedError("Runtime stats calculation is not implemented yet")
+
+ gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name()
+ subblock_stats = {
+ "args": dict(
+ is_calc_runtime=is_calc_runtime,
+ gpu=gpu,
+ batch_size=batch_size,
+ prefill_seq_len=prefill_seq_len,
+ generation_seq_len=generation_seq_len,
+ prefill_queue_size=prefill_queue_size,
+ n_embd=n_embd,
+ n_head=n_head,
+ vocab_size=vocab_size,
+ benchmark_iterations=benchmark_iterations,
+ use_cuda_graph=use_cuda_graph,
+ weights_dtype=str(weights_dtype),
+ activations_dtype=str(activations_dtype),
+ kv_cache_dtype=str(kv_cache_dtype),
+ ),
+ "non_block": dict(),
+ "subblocks": list(),
+ }
+ # Compute runtime stats for unique subblocks only
+ if is_calc_runtime:
+ raise NotImplementedError("Runtime stats calculation is not implemented yet")
+ subblock_configs_nolayerindex = set(
+ [subblock_config["subblock_config"] for subblock_config in subblock_configs]
+ )
+
+ # dict[SubblockConfig, float], float
+ # TODO: Manage default values for calc_subblock_stats_config in one place, e.g. within a dataclass for hydra config.
+ synth_dataset_num_requests = calc_subblock_stats_config.get("runtime_stats", {}).get(
+ "synth_dataset_num_requests", 200
+ )
+ backend = calc_subblock_stats_config.get("runtime_stats", {}).get("backend", "trt_torch")
+ runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_ms_for_subblocks(
+ subblock_configs_nolayerindex,
+ vocab_size,
+ n_embd,
+ n_head,
+ master_puzzle_dir,
+ teacher_dir,
+ synth_dataset_num_requests,
+ backend,
+ )
+
+ sorted_subblock_config = sorted(
+ subblock_configs, key=lambda subblock_config: subblock_config["subblock_config"]
+ )
+ it = (
+ tqdm(sorted_subblock_config, desc="Measuring subblock runtimes")
+ if is_calc_runtime
+ else sorted_subblock_config
+ )
+ for subblock_config_indexed in it:
+ subblock_config = subblock_config_indexed["subblock_config"]
+ parent_layer_indices = subblock_config_indexed["parent_layer_indices"]
+
+ if is_calc_runtime:
+ total_runtime_ms = runtime_by_subblock_dict[subblock_config]
+ prefill_runtime_ms = None
+ decode_runtime_ms = None
+ else:
+ total_runtime_ms, prefill_runtime_ms, decode_runtime_ms = None, None, None
+
+ subblock_memory = calculate_subblock_memory(
+ subblock_config,
+ batch_size,
+ prefill_seq_len,
+ generation_seq_len,
+ prefill_queue_size,
+ n_embd,
+ n_head,
+ weights_dtype,
+ kv_cache_dtype,
+ allocate_prefill_query,
+ model_config=model_config,
+ descriptor=descriptor,
+ )
+ if not isinstance(subblock_memory, dict):
+ subblock_memory = {"memory_mib": subblock_memory, "kv_cache_memory_mib": 0.0}
+
+ subblock_params = calculate_subblock_params(model_config, subblock_config, descriptor)
+ if moe_stats_file is not None:
+ subblock_active_params = calc_subblock_active_params(
+ subblock_config,
+ model_config,
+ descriptor,
+ n_embd,
+ moe_stats_file,
+ batch_size,
+ parent_layer_indices[0],
+ )
+ else:
+ subblock_active_params = subblock_params
+ subblock_stats["subblocks"].append(
+ {
+ "subblock_config": subblock_config,
+ "subblock_config_class": type(subblock_config).__name__,
+ "runtime_ms": total_runtime_ms,
+ "prefill_runtime_ms": prefill_runtime_ms,
+ "decode_runtime_ms": decode_runtime_ms,
+ "num_params": subblock_params,
+ "active_params": subblock_active_params,
+ "parent_layer_index": parent_layer_indices[0],
+ **subblock_memory,
+ }
+ )
+
+ if is_calc_runtime:
+ # TODO: fix
+ # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms
+ # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \
+ # measure_non_block_runtime_ms(batch_size, prefill_seq_len, generation_seq_len, n_embd, vocab_size,
+ # benchmark_iterations, use_cuda_graph)
+ embedding_runtime_ms, lm_head_runtime_ms = None, None
+ else:
+ non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = None, None, None
+ non_block_memory = calculate_non_block_memory(n_embd, vocab_size, weights_dtype)
+ non_block_params = calculate_non_block_params(n_embd, vocab_size)
+
+ # TODO
+ # the semantics here is wrong why do we refer, prefill_runtime_ms as embedding_runtime_ms and lm_head_runtime_ms as decode_runtime_ms ?
+ # Prefill is the first the user prompt inference, and Decode refer to the next generation process. both processes use all the model layers.
+ subblock_stats["non_block"] = {
+ "runtime_ms": non_block_runtime_ms,
+ "prefill_runtime_ms": embedding_runtime_ms,
+ "decode_runtime_ms": lm_head_runtime_ms,
+ "memory_mib": non_block_memory,
+ "num_params": non_block_params,
+ }
+ return subblock_stats
+
+
+def launch_calc_subblock_stats(cfg: DictConfig) -> None:
+ """
+ Launch the calc subblock stats function with Hydra configuration.
+ """
+ mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}")
+ mprint(f"Teacher directory: {cfg.teacher_dir}")
+ mprint(
+ f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}"
+ )
+
+ descriptor = ModelDescriptorFactory.get(cfg.descriptor)
+ calculate_subblock_stats_for_puzzle_dir(
+ cfg.calc_subblock_stats,
+ master_puzzle_dir=cfg.puzzle_dir,
+ teacher_dir=cfg.teacher_dir,
+ descriptor=descriptor,
+ model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])),
+ ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])),
+ batch_sizes=cfg.calc_subblock_stats.batch_sizes,
+ prefill_seq_len=cfg.calc_subblock_stats.prefill_seq_len,
+ generation_seq_len=cfg.calc_subblock_stats.generation_seq_len,
+ num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None),
+ prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size,
+ allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False),
+ benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None),
+ merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats,
+ subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename,
+ moe_stats_filename=cfg.calc_subblock_stats.moe_stats_filename,
+ )
+
+
+def calculate_subblock_stats_for_puzzle_dir(
+ calc_subblock_stats_config: DictConfig,
+ master_puzzle_dir: Path | str,
+ teacher_dir: Path | str,
+ descriptor: Type[ModelDescriptor],
+ model_hidden_sizes: ListConfig,
+ ffn_hidden_sizes: ListConfig,
+ batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256),
+ prefill_seq_len: int = 2048,
+ generation_seq_len: int = 2048,
+ num_active_tokens_override: int | None = None,
+ prefill_queue_size: int = 0, # it's an infery-llm thing
+ allocate_prefill_query: bool = False,
+ benchmark_iterations: (
+ int | None
+ ) = None, # If set then compute runtime performance statistics. TODO: recommend default value, is 1000 good?
+ merge_with_existing_stats: bool = False,
+ subblock_stats_filename: str = "subblock_stats.json",
+ moe_stats_filename: str = "moe_stats.json",
+) -> None:
+ # ==== START === Setup for attach-helper ====
+ # import sys
+ # import os
+ # sys.path.insert(0, os.environ["ATTACH_HELPER_INSTALLATION_PATH"])
+ # from attach_helper import debugging_setup
+ # debugging_setup() # You can optionally pass a name to identify the job (e.g. `debugging_setup(name="my_script")`)
+ # ==== END === Setup for attach-helper ====
+ if isinstance(batch_sizes, str):
+ batch_sizes = [
+ int(batch_size) for batch_size in batch_sizes.strip("[]").replace(" ", "").split(",")
+ ]
+
+ master_puzzle_dir = Path(master_puzzle_dir)
+ teacher_dir = (
+ Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher"
+ )
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code)
+ # Get language model config for LM-specific attributes (VL models have nested config)
+ lm_config = descriptor.get_language_model_config(model_config)
+ subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes)
+
+ subblock_stats_file = master_puzzle_dir / subblock_stats_filename
+ if subblock_stats_file.exists() and not merge_with_existing_stats:
+ raise ValueError(
+ f"Subblock stats file {subblock_stats_file} already exists and `merge_with_existing_stats` was set to False."
+ )
+
+ if subblock_stats_file.exists():
+ with open(subblock_stats_file) as f:
+ subblock_stats = json.load(f)
+ else:
+ subblock_stats = []
+
+ moe_stats_file = master_puzzle_dir / moe_stats_filename
+ if not moe_stats_file.exists():
+ warnings.warn(
+ f"MOE stats file {moe_stats_file} does not exist, can't calculate num active params"
+ )
+ moe_stats_file = None
+
+ subblock_stats_args = {immutabledict(x["args"]) for x in subblock_stats}
+
+ data_types = [
+ ("nvfp4", "nvfp4", "nvfp4"),
+ (torch.int8, torch.int8, torch.int8),
+ (torch.int8, torch.int8, torch.bfloat16),
+ (torch.bfloat16, torch.bfloat16, torch.bfloat16),
+ ]
+
+ model_hidden_sizes = model_hidden_sizes + [
+ lm_config.hidden_size
+ ] # add a teacher model hidden size
+ for batch_size, (
+ weights_dtype,
+ activations_dtype,
+ kv_cache_dtype,
+ ), model_hidden_size in product(batch_sizes, data_types, model_hidden_sizes):
+ if num_active_tokens_override is not None:
+ prefill_seq_len = generation_seq_len = int(num_active_tokens_override / batch_size / 2)
+
+ curr_benchmark_iterations = (
+ benchmark_iterations if weights_dtype == torch.bfloat16 else None
+ )
+
+ curr_subblock_stats = calculate_subblock_stats(
+ calc_subblock_stats_config,
+ teacher_dir=teacher_dir,
+ model_config=model_config,
+ descriptor=descriptor,
+ master_puzzle_dir=master_puzzle_dir,
+ subblock_configs=subblock_configs,
+ batch_size=batch_size,
+ prefill_seq_len=prefill_seq_len,
+ generation_seq_len=generation_seq_len,
+ prefill_queue_size=prefill_queue_size,
+ n_embd=model_hidden_size,
+ n_head=lm_config.num_attention_heads,
+ vocab_size=lm_config.vocab_size,
+ benchmark_iterations=curr_benchmark_iterations,
+ use_cuda_graph=True,
+ weights_dtype=weights_dtype,
+ activations_dtype=activations_dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ allocate_prefill_query=allocate_prefill_query,
+ moe_stats_file=moe_stats_file,
+ )
+
+ if immutabledict(curr_subblock_stats["args"]) in subblock_stats_args:
+ raise ValueError(
+ f"Failed merging subblock_stats. The following arguments already existed in the file: {curr_subblock_stats['args']}"
+ )
+
+ subblock_stats.append(curr_subblock_stats)
+
+ # TODO fix: add_int8_runtime_estimates(subblock_stats)
+
+ json_dump(subblock_stats, subblock_stats_file)
+
+ mprint(subblock_stats_file)
+
+
+def _load_subblock_configs(
+ master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig
+) -> list[SubblockConfig]:
+ try:
+ subblock_configs = _load_subblock_configs_from_replacement_library(master_puzzle_dir)
+ except FileNotFoundError:
+ subblock_configs = _load_subblock_configs_from_subblock_library(master_puzzle_dir)
+
+ # Extend subblock stats calculation space with ffn_hidden_sizes defined in the calc_subblock_stats section of the model config yaml file.
+ extra_ffn_subblock_configs = []
+ for ffn_hidden_size in ffn_hidden_sizes:
+ # Use FFNConfig defaults (hidden_act will use its default value)
+ ffn_config = FFNConfig(intermediate_size=ffn_hidden_size)
+ extra_ffn_subblock_configs.append(
+ immutabledict({"subblock_config": ffn_config, "parent_layer_indices": tuple([-1])})
+ ) # -1 to indicate that this sublock has no parent layer
+ subblock_configs.extend(extra_ffn_subblock_configs)
+
+ return subblock_configs
+
+
+def _load_subblock_configs_from_subblock_library(master_puzzle_dir: Path) -> list[SubblockConfig]:
+ subblocks_df = pd.read_json(master_puzzle_dir / "subblock_library.json")
+ subblocks_df["attention_config"] = subblocks_df["attention_config"].apply(
+ partial(_dataclass_from_dict, cls=AttentionConfig)
+ )
+ subblocks_df["ffn_config"] = subblocks_df["ffn_config"].apply(
+ partial(_dataclass_from_dict, cls=FFNConfig)
+ )
+ attention_configs = subblocks_df["attention_config"].dropna().drop_duplicates().tolist()
+ ffn_configs = subblocks_df["ffn_config"].dropna().drop_duplicates().tolist()
+ subblock_configs = attention_configs + ffn_configs
+ return subblock_configs
+
+
+def _load_subblock_configs_from_replacement_library(
+ master_puzzle_dir: Path,
+) -> list[SubblockConfig]:
+ """Load unique subblocks from replacement_library.json, e.g.,
+ 256 = 32*8 unique sublocks will be returned for a model with 32 layers and the search space of
+ 4 intermediate_size + teacher_intermediate_size + ffn_noop + att_op (teacher) + att_noop.
+
+ Args:
+ master_puzzle_dir (Path): Directory with "replacement_library.json" file
+
+ Returns:
+ list[SubblockConfig]:
+ """
+ replacement_library = json.loads((master_puzzle_dir / "replacement_library.json").read_text())
+ subblock_configs = set()
+ for layer_replacement in replacement_library:
+ layer_replacement = parse_layer_replacement(layer_replacement)
+
+ for block_config in layer_replacement["child_block_configs"]:
+ block_config: BlockConfig
+ attention_frozen_dict = immutabledict(
+ {
+ "subblock_config": block_config.attention,
+ "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]),
+ }
+ )
+ ffn_frozen_dict = immutabledict(
+ {
+ "subblock_config": block_config.ffn,
+ "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]),
+ }
+ )
+ subblock_configs.add(attention_frozen_dict)
+ subblock_configs.add(ffn_frozen_dict)
+
+ if block_config.parallel_blocks is not None:
+ for block_idx, internal_block_config in enumerate(block_config.parallel_blocks):
+ attention_frozen_dict = immutabledict(
+ {
+ "subblock_config": internal_block_config.attention,
+ "parent_layer_indices": tuple(
+ layer_replacement["parent_layer_indices"]
+ ),
+ "inner_block_idx": block_idx,
+ }
+ )
+ ffn_frozen_dict = immutabledict(
+ {
+ "subblock_config": internal_block_config.ffn,
+ "parent_layer_indices": tuple(
+ layer_replacement["parent_layer_indices"]
+ ),
+ "inner_block_idx": block_idx,
+ }
+ )
+ subblock_configs.add(attention_frozen_dict)
+ subblock_configs.add(ffn_frozen_dict)
+
+ subblock_configs = list(subblock_configs)
+ return subblock_configs
+
+
+T_DataClass: TypeVar = Type[dataclasses.dataclass]
+
+
+def _dataclass_from_dict(
+ d: dict | T_DataClass | None,
+ cls: T_DataClass,
+) -> T_DataClass | None:
+ if isinstance(d, cls):
+ return d
+ if isinstance(d, dict):
+ return cls(**d)
+ if pd.isna(d):
+ return None
+ raise ValueError(f"_dataclass_from_dict: unrecognized {type(d)=} {d=}")
+
+
+def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None:
+ for curr_subblock_stats in subblock_stats:
+ args = curr_subblock_stats["args"]
+ if args["weights_dtype"] == "torch.int8":
+ assert args["activations_dtype"] == "torch.int8"
+ ffn_factor = 0.5
+ attention_factor = 0.5 if args["kv_cache_dtype"] == "torch.int8" else 0.8
+
+ bf16_stats = _find_corresponding_bf16_stats(args, subblock_stats)
+ if bf16_stats is not None:
+ curr_subblocks = curr_subblock_stats["subblocks"] + [
+ curr_subblock_stats["non_block"]
+ ]
+ bf16_subblocks = bf16_stats["subblocks"] + [bf16_stats["non_block"]]
+ for curr_subblock, bf16_subblock in zip(curr_subblocks, bf16_subblocks):
+ assert curr_subblock.get("subblock_config", None) == bf16_subblock.get(
+ "subblock_config", None
+ )
+ is_attention = False
+ if (subblock_config := curr_subblock.get("subblock_config")) is not None:
+ if hasattr(subblock_config, "__dataclass_fields__"):
+ subblock_config = dataclasses.asdict(subblock_config)
+ is_attention = subblock_config.get("num_key_value_heads", None) is not None
+ runtime_factor = attention_factor if is_attention else ffn_factor
+ for stat_name, stat_value in bf16_subblock.items():
+ if "runtime" in stat_name:
+ curr_subblock[stat_name] = stat_value * runtime_factor
+
+
+def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> dict | None:
+ scenario_keys = [
+ "batch_size",
+ "prefill_seq_len",
+ "generation_seq_len",
+ "prefill_queue_size",
+ "gpu",
+ "n_embd",
+ "n_head",
+ "vocab_size",
+ ]
+ corresponding_bf16_args = {
+ **{k: v for k, v in args.items() if k in scenario_keys},
+ "is_calc_runtime": True,
+ "weights_dtype": "torch.bfloat16",
+ "activations_dtype": "torch.bfloat16",
+ "kv_cache_dtype": "torch.bfloat16",
+ }
+ matching_bf16_stats = [
+ stats
+ for stats in subblock_stats
+ if all(
+ [
+ stats["args"][key] == corresponding_bf16_args[key]
+ for key in corresponding_bf16_args.keys()
+ ]
+ )
+ ]
+ if len(matching_bf16_stats) == 0:
+ return None
+ if len(matching_bf16_stats) == 1:
+ return matching_bf16_stats[0]
+ raise ValueError(f"Found more than 1 matching bf16 stats for {args=}")
diff --git a/modelopt/torch/puzzletron/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py
new file mode 100644
index 0000000000..47f1c65a15
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
new file mode 100644
index 0000000000..6b98d36a0e
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
@@ -0,0 +1,1181 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Core logic for creating pruned child model state dicts from parent models. Used by init_child_from_parent."""
+
+import concurrent.futures
+import dataclasses
+import json
+import os
+import re
+import time
+from copy import deepcopy
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import PretrainedConfig
+from typeguard import check_type
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ SUBBLOCK_CLS_DICT,
+ BlockConfig,
+ _get_dataclass_type,
+ _is_dataclass_type,
+)
+from modelopt.torch.puzzletron.pruning.pruning_utils import (
+ ACTIVATIONS_LOG,
+ GQAInitMode,
+ HiddenSizeInitMode,
+ LinearInitMode,
+ MlpInitMode,
+ _cache_activations_log,
+ _init_attention_biases,
+ _init_attention_weights,
+ _init_mlp_module,
+ _init_moe_module,
+ _load_activations_log,
+ _load_expert_scores,
+ _select_expert_indices,
+)
+from modelopt.torch.puzzletron.tools.logger import aprint, mprint
+
+IgnoreFn = Callable[[str], bool]
+
+default_ignore_fn: IgnoreFn = lambda _: False
+
+
+class Printer:
+ @staticmethod
+ def print(s: str) -> None:
+ print(s)
+
+
+def _process_single_layer(
+ layer_idx: int,
+ pruning_mixin,
+ descriptor,
+ parent_state_dict: dict,
+ new_state_dict: dict,
+ original_config: PretrainedConfig,
+ new_config: PretrainedConfig,
+ gqa_init_mode: GQAInitMode,
+ mlp_init_mode: MlpInitMode,
+ mlp_init_config: Optional[dict[str, Any]],
+ linear_init_mode: LinearInitMode,
+ ignored_keys: set,
+ keys: dict,
+ is_original_mha: bool,
+ head_size: int,
+ hidden_size: int,
+) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]:
+ """
+ Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove).
+ Thread-safe function for parallel layer processing.
+ """
+ keys_to_remove = {}
+ layer_out_state_dict = {}
+
+ # Delegate to pruning_mixin if available
+ if pruning_mixin is not None:
+ _layer_out = pruning_mixin.prune_single_layer(
+ layer_idx=layer_idx,
+ parent_state_dict=parent_state_dict,
+ new_state_dict=new_state_dict,
+ original_config=original_config,
+ new_config=new_config,
+ gqa_init_mode=gqa_init_mode,
+ mlp_init_mode=mlp_init_mode,
+ mlp_init_config=mlp_init_config,
+ linear_init_mode=linear_init_mode,
+ ignored_keys=ignored_keys,
+ keys=keys,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ hidden_size=hidden_size,
+ keys_to_remove=keys_to_remove,
+ )
+ layer_out_state_dict.update(_layer_out)
+ return layer_out_state_dict, keys_to_remove
+
+ # Legacy inline processing (fallback when no pruning_mixin)
+
+ parent_block_config = original_config.block_configs[layer_idx]
+ child_block_config = new_config.block_configs[layer_idx]
+
+ # Attention processing
+ for part in ["weight", "bias"]:
+ attn_prefix = f"model.layers.{layer_idx}.self_attn"
+ q_key = f"{attn_prefix}.q_proj.{part}"
+ k_key = f"{attn_prefix}.k_proj.{part}"
+ v_key = f"{attn_prefix}.v_proj.{part}"
+ o_key = f"{attn_prefix}.o_proj.{part}"
+ attn_keys = [q_key, k_key, v_key, o_key]
+ # Drop attn keys that don't exist and required to be in the new state_dict
+ attn_keys = [key for key in attn_keys if key in new_state_dict.keys()]
+ if len(attn_keys) > 0 and all(key in keys for key in attn_keys):
+ for key in attn_keys:
+ keys_to_remove[key] = keys[key]
+ if all(key not in ignored_keys for key in attn_keys):
+ is_student_and_teacher_have_same_attention_implementation = all(
+ key in new_state_dict.keys() for key in attn_keys
+ )
+ if is_student_and_teacher_have_same_attention_implementation:
+ if part == "weight":
+ wq, wk, wv, wo = _init_attention_weights(
+ gqa_init_mode=gqa_init_mode,
+ layer_idx=layer_idx,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ original_state_dict=parent_state_dict,
+ original_config=original_config,
+ q_key=q_key,
+ k_key=k_key,
+ v_key=v_key,
+ o_key=o_key,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ mlp_init_config=mlp_init_config,
+ )
+ layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk
+ layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo
+ else:
+ bias_sd = _init_attention_biases(
+ gqa_init_mode=gqa_init_mode,
+ layer_idx=layer_idx,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ original_state_dict=parent_state_dict,
+ original_config=original_config,
+ q_key=q_key,
+ k_key=k_key,
+ v_key=v_key,
+ o_key=o_key,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ mlp_init_config=mlp_init_config,
+ )
+ for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]):
+ if bias_key in bias_sd.keys():
+ layer_out_state_dict[sd_key] = bias_sd[bias_key]
+
+ else:
+ linear_attn_key = f"{attn_prefix}.linear_attn.weight"
+ is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys()
+ if is_student_attn_replaced_with_linear:
+ if linear_init_mode == LinearInitMode.Random:
+ layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key]
+ elif linear_init_mode == LinearInitMode.FromTeacher:
+ layer_out_state_dict[linear_attn_key] = _init_linear_attn(
+ parent_state_dict, original_config, layer_idx, v_key, o_key
+ )
+ else:
+ raise ValueError(f"Unknown {linear_init_mode=}")
+ else:
+ # student attn random init
+ for new_key in new_state_dict.keys():
+ if attn_prefix in new_key:
+ layer_out_state_dict[new_key] = new_state_dict[new_key]
+
+ # MLP/MoE processing
+ is_parent_moe = parent_block_config.ffn.is_moe
+ if not is_parent_moe: # not MoE, init the MLP
+ mlp_prefix = f"model.layers.{layer_idx}.mlp"
+ linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight"
+
+ is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys()
+ if is_student_mlp_replaced_with_linear:
+ if linear_init_mode == LinearInitMode.Random:
+ layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key]
+ elif linear_init_mode == LinearInitMode.FromTeacher:
+ teacher_mlp_state_dict = {
+ k.split(mlp_prefix + ".")[1]: v
+ for k, v in parent_state_dict.items()
+ if mlp_prefix in k
+ }
+ layer_out_state_dict[linear_mlp_key] = _init_linear_mlp(teacher_mlp_state_dict)
+ else:
+ raise ValueError(f"Unknown {linear_init_mode=}")
+ else:
+ layer_out_state_dict.update(
+ _init_mlp(
+ mlp_init_mode=mlp_init_mode,
+ layer_idx=layer_idx,
+ original_config=original_config,
+ mlp_init_config=mlp_init_config,
+ original_state_dict=parent_state_dict,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ keys=keys,
+ ignored_keys=ignored_keys,
+ )
+ )
+ else:
+ is_child_moe = child_block_config.ffn.is_moe
+ if is_child_moe:
+ parent_moe_config = original_config.block_configs[layer_idx].ffn.moe
+ child_moe_config = new_config.block_configs[layer_idx].ffn.moe
+ if parent_moe_config == child_moe_config:
+ pass # copy the MoE as is
+ elif mlp_init_mode == MlpInitMode.MoEChannelPruning:
+ for expert_idx in range(parent_moe_config.num_local_experts):
+ layer_out_state_dict.update(
+ _init_mlp(
+ mlp_init_mode=mlp_init_mode,
+ layer_idx=layer_idx,
+ original_config=original_config,
+ mlp_init_config=mlp_init_config,
+ original_state_dict=parent_state_dict,
+ new_state_dict=new_state_dict,
+ new_config=new_config,
+ keys=keys,
+ ignored_keys=ignored_keys,
+ expert_idx=expert_idx,
+ )
+ )
+
+ elif mlp_init_mode == MlpInitMode.ExpertRemoval: # remove some of the routed experts
+ router_key, new_experts_keys = _generate_moe_keys(
+ layer_idx, child_block_config.ffn.moe.num_local_experts
+ )
+ _, orig_experts_keys = _generate_moe_keys(
+ layer_idx, parent_block_config.ffn.moe.num_local_experts
+ )
+ keys_to_remove[router_key] = keys.get(router_key)
+ for key in sum(orig_experts_keys.values(), []):
+ keys_to_remove[key] = keys.get(key)
+
+ orig_experts_weights = {
+ name: [parent_state_dict[key] for key in orig_experts_module_keys]
+ for name, orig_experts_module_keys in orig_experts_keys.items()
+ }
+ new_experts_weights = {
+ name: [new_state_dict[key] for key in new_experts_module_keys]
+ for name, new_experts_module_keys in new_experts_keys.items()
+ }
+ out_router_weights, out_experts_weights = _init_moe_module(
+ layer_idx=layer_idx,
+ mlp_init_mode=mlp_init_mode,
+ mlp_init_config=mlp_init_config,
+ orig_router_weight=parent_state_dict[router_key],
+ orig_experts_weights=orig_experts_weights,
+ new_router_weight=new_state_dict[router_key],
+ new_experts_weights=new_experts_weights,
+ )
+ layer_out_state_dict[router_key] = out_router_weights
+ for name in new_experts_keys.keys():
+ layer_out_state_dict.update(
+ zip(new_experts_keys[name], out_experts_weights[name])
+ )
+ elif child_block_config.ffn.no_op: # no-op, drop this layer
+ parent_mlp_prefix = f"model.layers.{layer_idx}.mlp"
+ for key in list(keys.keys()):
+ if key.startswith(parent_mlp_prefix):
+ keys_to_remove[key] = keys[key]
+ else:
+ assert mlp_init_mode == MlpInitMode.ConcatExpertsIntoDenseFFN, (
+ "The parent layer is MoE and the child layer is a normal FFN. The only supported mode is ConcatExpertsAsMLP."
+ )
+
+ child_ffn_state_dict = _concatenate_experts_into_dense_ffn(
+ parent_state_dict,
+ mlp_init_config,
+ hidden_size,
+ layer_idx,
+ child_block_config,
+ parent_block_config,
+ )
+ layer_out_state_dict.update(child_ffn_state_dict)
+
+ for key in list(keys.keys()):
+ if key.startswith(f"model.layers.{layer_idx}.mlp"):
+ keys_to_remove[key] = keys[key]
+
+ # Handle missing keys
+ for key_possibly_missing_in_student in [
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
+ "self_attn.o_proj",
+ "mlp.gate_proj",
+ "mlp.up_proj",
+ "mlp.down_proj",
+ "input_layernorm",
+ "post_attention_layernorm",
+ ]:
+ key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}"
+ is_key_missing_from_student = (
+ len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0
+ )
+ if is_key_missing_from_student:
+ for k in list(keys.keys()):
+ if key_possibly_missing_in_student in k:
+ keys_to_remove[k] = keys[k]
+
+ return layer_out_state_dict, keys_to_remove
+
+
+@torch.no_grad()
+def create_child_state_dict(
+ pruning_mixin,
+ descriptor,
+ original_state_dict: dict,
+ new_state_dict: dict,
+ original_config: PretrainedConfig,
+ new_config: PretrainedConfig,
+ gqa_init_mode: GQAInitMode,
+ ignore_fn: IgnoreFn = default_ignore_fn,
+ mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs,
+ mlp_init_config: Optional[dict[str, Any]] = None,
+ owned_block_indexes: Optional[set[int]] = None,
+ linear_init_mode: LinearInitMode = LinearInitMode.Random,
+ hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs,
+ channel_importance_path: Optional[str] = None,
+ max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None
+):
+ mprint("=== Starting create_child_state_dict with optimizations ===")
+ total_start_time = time.time()
+
+ # Phase 1: Initial setup and validation
+ setup_start_time = time.time()
+ if owned_block_indexes is None:
+ owned_block_indexes = set(range(new_config.num_hidden_layers))
+
+ # Auto-calculate optimal layer workers: min(cpu_count, num_layers)
+ if max_layer_workers is None:
+ cpu_count = os.cpu_count() or 1
+ num_layers = len(owned_block_indexes)
+ max_layer_workers = min(cpu_count, num_layers)
+ mprint(
+ f"Auto-calculated layer workers: min({cpu_count} CPUs, {num_layers} layers) = {max_layer_workers}"
+ )
+ else:
+ mprint(f"Using specified layer workers: {max_layer_workers}")
+
+ # Memory optimization: Pre-allocate output state dict with known shapes
+ expected_keys_and_shapes = {k: v.shape for k, v in new_state_dict.items()}
+ out_state_dict = {}
+
+ # Pre-allocate tensors where possible to reduce memory fragmentation
+ for key, shape in expected_keys_and_shapes.items():
+ if key in new_state_dict:
+ tensor = new_state_dict[key]
+ # Only make contiguous if necessary (memory optimization)
+ if not tensor.is_contiguous():
+ out_state_dict[key] = tensor.contiguous()
+ else:
+ out_state_dict[key] = tensor
+
+ # Get language model config for LM-specific attributes (VL models have nested config)
+ original_lm_config = descriptor.get_language_model_config(original_config)
+ new_lm_config = descriptor.get_language_model_config(new_config)
+
+ # Check if original model is MHA (all layers have num_key_value_heads == num_attention_heads)
+ original_num_kv_heads_per_layer = [
+ b.attention.num_key_value_heads for b in original_config.block_configs
+ ]
+ num_attention_heads = original_lm_config.num_attention_heads
+ is_original_mha = all(kv == num_attention_heads for kv in original_num_kv_heads_per_layer)
+ is_same_hidden_size = original_lm_config.hidden_size == new_lm_config.hidden_size
+ head_size = _get_head_dim(new_lm_config)
+ orig_head_size = _get_head_dim(original_lm_config)
+ assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}"
+
+ # Allow different hidden sizes for pruning
+ if not is_same_hidden_size:
+ assert new_lm_config.hidden_size <= original_lm_config.hidden_size, (
+ f"New hidden size ({new_lm_config.hidden_size}) must be <= original ({original_lm_config.hidden_size})"
+ )
+ assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, (
+ "Cannot copy as is when hidden sizes differ"
+ )
+
+ hidden_size = original_lm_config.hidden_size
+
+ ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)])
+ for key in ignored_keys:
+ aprint(f"Ignoring key {key} and taking its init from new_state_dict")
+ out_state_dict[key] = new_state_dict[key]
+
+ keys = {
+ match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key
+ for key in original_state_dict.keys()
+ }
+ setup_time = time.time() - setup_start_time
+ mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s")
+
+ # Phase 2: Parallel layer processing
+ layer_processing_start_time = time.time()
+
+ # Prepare arguments for parallel processing
+ process_layer_partial = partial(
+ _process_single_layer,
+ pruning_mixin=pruning_mixin,
+ descriptor=descriptor,
+ parent_state_dict=original_state_dict,
+ new_state_dict=new_state_dict,
+ original_config=original_config,
+ new_config=new_config,
+ gqa_init_mode=gqa_init_mode,
+ mlp_init_mode=mlp_init_mode,
+ mlp_init_config=mlp_init_config,
+ linear_init_mode=linear_init_mode,
+ ignored_keys=ignored_keys,
+ keys=keys,
+ is_original_mha=is_original_mha,
+ head_size=head_size,
+ hidden_size=hidden_size,
+ )
+
+ # Process layers in parallel with optimal worker count
+ mprint(
+ f"Processing {len(owned_block_indexes)} layers in parallel with {max_layer_workers} workers..."
+ )
+ layer_results = []
+ all_keys_to_remove = {}
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_layer_workers) as executor:
+ future_to_layer = {
+ executor.submit(process_layer_partial, layer_idx): layer_idx
+ for layer_idx in owned_block_indexes
+ }
+
+ completed = 0
+ for future in concurrent.futures.as_completed(future_to_layer):
+ layer_idx = future_to_layer[future]
+ try:
+ layer_state_dict, keys_to_remove = future.result()
+ layer_results.append((layer_idx, layer_state_dict))
+ all_keys_to_remove.update(keys_to_remove)
+
+ completed += 1
+ if completed % 20 == 0 or completed == len(
+ owned_block_indexes
+ ): # More frequent progress updates
+ mprint(f"Completed {completed}/{len(owned_block_indexes)} layers")
+ except Exception as exc:
+ mprint(f"Layer {layer_idx} generated an exception: {exc}")
+ raise exc
+
+ # Merge layer results into main state dict (memory efficient)
+ for layer_idx, layer_state_dict in layer_results:
+ out_state_dict.update(layer_state_dict)
+
+ # Remove processed keys from the keys dict
+ for key_to_remove in all_keys_to_remove:
+ keys.pop(key_to_remove, None)
+
+ layer_processing_time = time.time() - layer_processing_start_time
+ mprint(
+ f"Phase 2 - Parallel layer processing: {layer_processing_time:.2f}s ({max_layer_workers} workers)"
+ )
+
+ # Phase 3: Copy remaining keys from original model
+ copy_start_time = time.time()
+ keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys
+ for key in keys_to_copy_from_orig_model:
+ # Memory optimization: avoid unnecessary copies
+ tensor = original_state_dict[key]
+ if not tensor.is_contiguous():
+ out_state_dict[key] = tensor.contiguous()
+ else:
+ out_state_dict[key] = tensor
+ copy_time = time.time() - copy_start_time
+ mprint(
+ f"Phase 3 - Copy remaining keys: {copy_time:.2f}s ({len(keys_to_copy_from_orig_model)} keys)"
+ )
+
+ # Handle hidden size pruning for remaining keys
+ if not is_same_hidden_size:
+ out_state_dict = _apply_hidden_size_pruning(
+ out_state_dict,
+ original_state_dict,
+ new_config,
+ original_config,
+ descriptor,
+ hidden_size_init_mode,
+ channel_importance_path,
+ owned_block_indexes,
+ )
+
+ # Phase 4: Verification
+ verify_start_time = time.time()
+ _verify_state_dicts_match(out_state_dict, expected_keys_and_shapes)
+ verify_time = time.time() - verify_start_time
+ mprint(f"Phase 4 - Verification: {verify_time:.2f}s")
+
+ total_time = time.time() - total_start_time
+ mprint(f"=== create_child_state_dict completed in {total_time:.2f}s ===")
+ mprint(
+ f"Breakdown: Setup {setup_time:.1f}s + ParallelProcessing {layer_processing_time:.1f}s + Copy {copy_time:.1f}s + Verify {verify_time:.1f}s"
+ )
+ mprint(
+ f"Speedup: Used {max_layer_workers} workers for {len(owned_block_indexes)} layers (CPU utilization: {max_layer_workers}/{os.cpu_count() or 1})"
+ )
+
+ return out_state_dict
+
+
+def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, list[str]]]:
+ mlp_prefix = f"model.layers.{layer_idx}.mlp"
+ router_key = f"{mlp_prefix}.router.weight"
+ names = ["gate_proj", "up_proj", "down_proj"]
+ experts_module_names = {
+ name: f"{mlp_prefix}.experts.{{expert_idx}}.{name}.weight" for name in names
+ }
+ return router_key, {
+ name: [module_name.format(expert_idx=expert_idx) for expert_idx in range(num_experts)]
+ for name, module_name in experts_module_names.items()
+ }
+
+
+def _concatenate_experts_into_dense_ffn(
+ original_state_dict: dict[str, torch.Tensor],
+ mlp_init_config: Optional[dict],
+ hidden_size: int,
+ layer_idx: int,
+ child_block_config: BlockConfig,
+ parent_block_config: BlockConfig,
+) -> dict[str, torch.Tensor]:
+ assert child_block_config.ffn.gated and child_block_config.ffn.hidden_act == "silu", (
+ "Llama4 experts use SwiGLU."
+ )
+
+ # verify sizes
+ child_intermediate_size = child_block_config.ffn.intermediate_size
+ parent_moe_config = parent_block_config.ffn.moe
+ shared_expert_intermediate_dim = parent_moe_config.shared_expert_intermediate_dim
+ routed_expert_intermediate_dim = parent_moe_config.expert_intermediate_dim
+ total_concatenated_routed_experts_size = (
+ child_intermediate_size - shared_expert_intermediate_dim
+ )
+ assert total_concatenated_routed_experts_size % routed_expert_intermediate_dim == 0, (
+ f"{child_intermediate_size=} "
+ f"{shared_expert_intermediate_dim=} "
+ f"{routed_expert_intermediate_dim=} "
+ f"{total_concatenated_routed_experts_size=} "
+ f"{total_concatenated_routed_experts_size % routed_expert_intermediate_dim=} != 0"
+ )
+ num_concatenated_routed_experts = (
+ total_concatenated_routed_experts_size // routed_expert_intermediate_dim
+ )
+
+ # if needed, concatenate some of the routed experts
+ if num_concatenated_routed_experts == 0:
+ print(
+ f"Removing all routed experts from layer {layer_idx}, turning the shared expert into a dense FFN."
+ )
+ concat_routed_state_dict = dict()
+ else:
+ print(
+ f"Concatenating {num_concatenated_routed_experts} routed experts to the shared expert in layer {layer_idx}"
+ )
+ router_key, orig_experts_keys = _generate_moe_keys(
+ layer_idx, parent_moe_config.num_local_experts
+ )
+ orig_experts_weights = {
+ name: [original_state_dict[key] for key in orig_experts_module_keys]
+ for name, orig_experts_module_keys in orig_experts_keys.items()
+ }
+ _, experts_weights = _prune_experts_by_score(
+ mlp_init_config=mlp_init_config,
+ layer_idx=layer_idx,
+ orig_router_weight=original_state_dict[router_key],
+ orig_experts_weights=orig_experts_weights,
+ new_num_experts=num_concatenated_routed_experts,
+ )
+ concat_dims = {"gate_proj": 0, "up_proj": 0, "down_proj": 1}
+ assert list(concat_dims) == list(experts_weights), (
+ "concat_dims and experts_weights must have the same keys"
+ )
+ concat_routed_state_dict = {
+ name: torch.cat(experts_weights[name], dim=concat_dims[name])
+ for name in concat_dims.keys()
+ }
+
+ # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed.
+ parent_shared_expert_prefix = f"model.layers.{layer_idx}.mlp.shared_expert"
+ child_ffn_prefix = f"model.layers.{layer_idx}.mlp"
+ child_ffn_state_dict = dict()
+
+ for module_name in [
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ ]:
+ shared_expert_key = f"{parent_shared_expert_prefix}.{module_name}.weight"
+ child_ffn_key = f"{child_ffn_prefix}.{module_name}.weight"
+ shared_expert_weight = original_state_dict[shared_expert_key]
+ concat_routed_weight = concat_routed_state_dict.get(module_name)
+
+ if concat_routed_weight is None:
+ child_weight = shared_expert_weight
+ else:
+ child_weight = torch.cat(
+ [shared_expert_weight, concat_routed_weight],
+ dim=1 if module_name == "down_proj" else 0,
+ )
+ child_ffn_state_dict[child_ffn_key] = child_weight
+
+ return child_ffn_state_dict
+
+
+def _verify_state_dicts_match(
+ state_dict: dict[str, torch.Tensor],
+ expected_keys_and_shapes: dict[str, torch.Size],
+) -> None:
+ # Verify keys match
+ expected_keys = expected_keys_and_shapes.keys()
+ missing_keys = set(expected_keys) - set(state_dict.keys())
+ unexpected_keys = set(state_dict.keys()) - set(expected_keys)
+ assert len(missing_keys) == 0 and len(unexpected_keys) == 0, (
+ f"Missing keys: {missing_keys}\nUnexpected keys: {unexpected_keys}"
+ )
+
+ # Verify shapes match
+ shape_mismatches = []
+ for key in expected_keys:
+ expected_shape = expected_keys_and_shapes[key]
+ actual_shape = state_dict[key].shape
+ if expected_shape != actual_shape:
+ shape_mismatches.append(f"{key}: expected {expected_shape}, got {actual_shape}")
+
+ assert len(shape_mismatches) == 0, "Shape mismatches found:\n" + "\n".join(shape_mismatches)
+ print("""
+############################
+create_child_state_dict: all keys and shapes matched successfully.
+############################
+""")
+
+
+def _init_mlp(
+ *,
+ mlp_init_mode: Union[MlpInitMode, str],
+ layer_idx: int,
+ original_config: PretrainedConfig,
+ mlp_init_config: Optional[dict[str, Any]],
+ original_state_dict: dict,
+ new_state_dict: dict,
+ new_config: PretrainedConfig,
+ keys: dict[str, str],
+ ignored_keys: set[str],
+ expert_idx: Optional[int] = None,
+) -> dict[str, torch.Tensor]:
+ out_state_dict = {}
+
+ if mlp_init_mode == MlpInitMode.MoEChannelPruning:
+ if expert_idx is None:
+ return {}
+ mlp_prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}"
+ else:
+ mlp_prefix = f"model.layers.{layer_idx}.mlp"
+
+ key = f"{mlp_prefix}.down_proj.weight"
+ if key not in keys:
+ return {}
+
+ mlp_c_proj_key = keys[key]
+ if mlp_c_proj_key not in ignored_keys:
+ mlp_keys = [
+ keys.pop(f"{mlp_prefix}.{module_name}.weight")
+ for module_name in ["down_proj", "gate_proj", "up_proj"]
+ ]
+ pruned_filters = None
+ projection_matrix = None
+ for mlp_key in mlp_keys:
+ expanded_dim = 1 if "down_proj" in mlp_key else 0
+ if mlp_key in new_state_dict.keys():
+ mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module(
+ mlp_init_mode,
+ mlp_prefix,
+ expanded_dim,
+ layer_idx,
+ new_state_dict[mlp_key],
+ new_config,
+ original_state_dict[mlp_key],
+ original_config,
+ mlp_init_config,
+ pruned_filters,
+ projection_matrix,
+ )
+ out_state_dict[mlp_key] = mlp_module_weight
+ else:
+ mprint(f"mlp_key {mlp_key} not in new_state_dict")
+ return out_state_dict
+
+
+def _prune_experts_by_score(
+ *,
+ mlp_init_config: dict[str, Any],
+ layer_idx: int,
+ orig_router_weight: torch.Tensor,
+ orig_experts_weights: dict[str, list[torch.Tensor]],
+ new_num_experts: int,
+) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]:
+ orig_num_experts = orig_router_weight.shape[0]
+ assert all(
+ len(orig_experts_module_weights) == orig_num_experts
+ for orig_experts_module_weights in orig_experts_weights.values()
+ )
+ expert_scores = _load_expert_scores(mlp_init_config)[layer_idx]
+ assert len(expert_scores) == orig_num_experts
+ selected_experts = sorted(
+ range(orig_num_experts),
+ key=lambda i: expert_scores[i],
+ reverse=mlp_init_config.get("higher_is_better", True),
+ )[:new_num_experts]
+ result_router_weight = orig_router_weight[selected_experts]
+ result_experts_weights = {
+ name: [orig_experts_module_weights[i] for i in selected_experts]
+ for name, orig_experts_module_weights in orig_experts_weights.items()
+ }
+ return result_router_weight, result_experts_weights
+
+
+def _init_linear_attn(
+ parent_state_dict: dict[str, torch.Tensor],
+ parent_config: PretrainedConfig,
+ layer_idx: int,
+ v_key: str,
+ o_key: str,
+) -> torch.Tensor:
+ """
+ Init a linear layer that operates like an attention layer that assigns score 1 to the current token
+ and score 0 to all others: out = (Wo @ Wv) @ x
+ """
+ n_embd = parent_config.hidden_size
+ head_size = _get_head_dim(parent_config)
+ # Get num_kv_heads from config, compute n_heads_in_group
+ n_kv_heads = parent_config.block_configs[layer_idx].attention.num_key_value_heads
+ n_heads_in_group = parent_config.num_attention_heads // n_kv_heads
+
+ wv = parent_state_dict[v_key]
+ wv = wv.view(n_kv_heads, head_size, n_embd)
+ wv_expanded = torch.repeat_interleave(wv, n_heads_in_group, dim=0).reshape(n_embd, n_embd)
+
+ wo = parent_state_dict[o_key]
+
+ w_linear = wo @ wv_expanded
+ return w_linear
+
+
+def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor:
+ """
+ A linear layer that does (W_down @ W_up) @ x, ignoring W_gate.
+ """
+ if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer
+ return teacher_mlp_state_dict["linear_mlp.weight"]
+
+ w_up = teacher_mlp_state_dict["up_proj.weight"]
+ w_down = teacher_mlp_state_dict["down_proj.weight"]
+ w_linear = w_down @ w_up
+ return w_linear
+
+
+def update_model_config(
+ model_config: PretrainedConfig,
+ model_config_overrides: None | list[dict[str, Any]] | str | dict | Path = None,
+) -> PretrainedConfig:
+ new_model_config = deepcopy(model_config)
+ if model_config_overrides is None:
+ return new_model_config
+
+ model_config_overrides = _parse_model_config_overrides(
+ model_config_overrides, model_config.num_hidden_layers
+ )
+
+ def override(item, item_overrides):
+ if item_overrides is None:
+ return item_overrides
+ if dataclasses.is_dataclass(item):
+ assert isinstance(item_overrides, dict)
+ return dataclass_override(item, item_overrides)
+ if isinstance(item, list):
+ assert isinstance(item_overrides, list)
+ return list_override(item, item_overrides)
+ return item_overrides
+
+ def list_override(ls, ls_overrides: list):
+ assert len(ls) == len(ls_overrides)
+ return [override(item, item_overrides) for item, item_overrides in zip(ls, ls_overrides)]
+
+ def dataclass_override(dc, dc_overrides: dict):
+ if not set(dc_overrides.keys()).issubset(dataclasses.asdict(dc).keys()):
+ raise ValueError(
+ f"Uknown overrides for dataclass {type(dc)}: {', '.join(set(dc_overrides.keys()) - dataclasses.asdict(dc).keys())}"
+ )
+ field_types = {field.name: field.type for field in dataclasses.fields(dc)}
+ dc_changes = {}
+ for key, item_overrides in dc_overrides.items():
+ previous_value, item_type = getattr(dc, key), field_types[key]
+ # if original block was no_op, we should not override it
+ if getattr(dc, "no_op", False):
+ return dc
+
+ if previous_value is None and _is_dataclass_type(item_type):
+ new_value = _get_dataclass_type(item_type)(**item_overrides)
+ else:
+ new_value = override(previous_value, item_overrides)
+ check_type(new_value, item_type)
+ dc_changes[key] = new_value
+ return dataclasses.replace(dc, **dc_changes)
+
+ new_model_config.block_configs = list_override(
+ new_model_config.block_configs, model_config_overrides
+ )
+
+ return new_model_config
+
+
+def _parse_model_config_overrides(
+ model_config_overrides_json: str | dict | Path | list[dict],
+ n_layer: int,
+) -> list[dict[str, Any]]:
+ """
+ example model_config_overrides_dict:
+ {
+ "attention": [{"num_key_value_heads": 4}],
+ "ffn": [{"intermediate_size": 14336}]
+ }
+ """
+ if isinstance(model_config_overrides_json, list) and isinstance(
+ model_config_overrides_json[0], dict
+ ):
+ return model_config_overrides_json
+
+ if isinstance(model_config_overrides_json, dict):
+ model_config_overrides_dict = model_config_overrides_json
+ else:
+ if os.path.exists(
+ model_config_overrides_json
+ ): # using os.path.exists, because Path.exists throws an exception on long strings
+ model_config_overrides_json = Path(model_config_overrides_json).read_text()
+ print(f"I'm json loadsing over here. {model_config_overrides_json=}")
+ model_config_overrides_dict = json.loads(model_config_overrides_json)
+
+ # Sanity checks and conversion to list of dictionaries
+ layer_wise_overrides = [{} for _ in range(n_layer)]
+ for config_key, config_value in model_config_overrides_dict.items():
+ assert config_key in SUBBLOCK_CLS_DICT, f"Unknown config key: {config_key}"
+ assert isinstance(config_value, list), (
+ f"Expected a list for {config_key}, got {config_value}"
+ )
+ assert len(config_value) == n_layer or len(config_value) == 1, (
+ f"Number of elements in {config_key} must be 1 or equal to the number of layers in the model"
+ )
+
+ if len(config_value) == 1:
+ model_config_overrides_dict[config_key] = config_value * n_layer
+
+ for layer_idx in range(n_layer):
+ layer_wise_overrides[layer_idx][config_key] = model_config_overrides_dict[config_key][
+ layer_idx
+ ]
+
+ return layer_wise_overrides
+
+
+def _apply_hidden_size_pruning(
+ out_state_dict: dict[str, torch.Tensor],
+ original_state_dict: dict[str, torch.Tensor],
+ new_config: PretrainedConfig,
+ original_config: PretrainedConfig,
+ descriptor,
+ hidden_size_init_mode: HiddenSizeInitMode,
+ channel_importance_path: Optional[str] = None,
+ owned_block_indexes: Optional[list[int]] = None,
+) -> dict[str, torch.Tensor]:
+ """
+ Apply hidden size pruning to all layers that depend on hidden_size.
+ This includes embeddings, layer norms, and any linear layers that haven't been handled yet.
+ """
+ if isinstance(hidden_size_init_mode, str):
+ hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode)
+
+ # Get language model config (for VL models this extracts the nested config)
+ original_lm_config = descriptor.get_language_model_config(original_config)
+ new_lm_config = descriptor.get_language_model_config(new_config)
+
+ original_hidden_size = original_lm_config.hidden_size
+ new_hidden_size = new_lm_config.hidden_size
+
+ if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs:
+ return out_state_dict
+
+ # Load channel ranking if needed
+ if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking:
+ if channel_importance_path is not None:
+ with open(channel_importance_path, "r") as f:
+ channel_ranking = json.load(f)["channel_importance_ranking"]
+ else:
+ raise ValueError(
+ "channel_ranking_path must be provided in hidden_size_init_config for PruneByChannelRanking mode"
+ )
+
+ # Handle embedding layer
+ embed_key = "model.embed_tokens.weight"
+ if embed_key in out_state_dict and embed_key in original_state_dict:
+ out_state_dict[embed_key] = _prune_hidden_size_dimension(
+ original_state_dict[embed_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1,
+ )
+ else:
+ raise ValueError(
+ f"Embed key {embed_key} not found in out_state_dict or original_state_dict"
+ )
+
+ # Handle final layer norm
+ norm_key = "model.norm.weight"
+ if norm_key in out_state_dict and norm_key in original_state_dict:
+ out_state_dict[norm_key] = _prune_hidden_size_dimension(
+ original_state_dict[norm_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=0,
+ )
+
+ # Handle LM head
+ lm_head_key = "lm_head.weight"
+ if lm_head_key in out_state_dict and lm_head_key in original_state_dict:
+ if out_state_dict[lm_head_key].shape[1] != new_hidden_size:
+ out_state_dict[lm_head_key] = _prune_hidden_size_dimension(
+ original_state_dict[lm_head_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1,
+ )
+
+ for block_idx in owned_block_indexes:
+ if new_config.block_configs[block_idx].parallel_blocks is None:
+ key_prefix = f"model.layers.{block_idx}"
+ out_state_dict = _prune_hidden_size_dimension_block(
+ out_state_dict,
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ new_config.block_configs[block_idx],
+ key_prefix,
+ )
+ else:
+ for internal_block_idx in range(
+ len(new_config.block_configs[block_idx].parallel_blocks)
+ ):
+ block_config = new_config.block_configs[block_idx].parallel_blocks[
+ internal_block_idx
+ ]
+ key_prefix = f"model.layers.{block_idx}.parallel_blocks.{internal_block_idx}"
+ out_state_dict = _prune_hidden_size_dimension_block(
+ out_state_dict,
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ block_config,
+ key_prefix,
+ )
+ return out_state_dict
+
+
+def _prune_hidden_size_dimension_block(
+ out_state_dict,
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ block_config,
+ key_prefix,
+):
+ for layer_norm in ["input_layernorm", "post_attention_layernorm"]:
+ for part in ["weight", "bias"]:
+ key = f"{key_prefix}.{layer_norm}.{part}"
+ if key in out_state_dict:
+ out_state_dict[key] = _prune_hidden_size_dimension(
+ out_state_dict[key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=0,
+ )
+ attn_prefix = f"{key_prefix}.self_attn"
+ if block_config.attention.replace_with_linear:
+ linear_attn_key = f"{attn_prefix}.linear_attn.weight"
+ for dim in [0, 1]:
+ out_state_dict[linear_attn_key] = _prune_hidden_size_dimension(
+ out_state_dict[linear_attn_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=dim,
+ )
+ elif block_config.attention.is_mamba:
+ for proj in ["in", "out"]:
+ mamba_key = f"{attn_prefix}.mamba_mixer.{proj}_proj.weight"
+ out_state_dict[mamba_key] = _prune_hidden_size_dimension(
+ out_state_dict[mamba_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1 if proj == "in" else 0,
+ )
+ else:
+ for k in "qkvo":
+ for part in ["weight", "bias"]:
+ if k in "qkv" and part == "bias":
+ continue
+ key = f"{attn_prefix}.{k}_proj.{part}"
+ if key in out_state_dict:
+ out_state_dict[key] = _prune_hidden_size_dimension(
+ out_state_dict[key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1 if part == "weight" and k in "qkv" else 0,
+ )
+ ffn_prefix = f"{key_prefix}.mlp"
+ if block_config.ffn.replace_with_linear:
+ linear_mlp_key = f"{ffn_prefix}.linear_mlp.weight"
+ for dim in [0, 1]:
+ out_state_dict[linear_mlp_key] = _prune_hidden_size_dimension(
+ out_state_dict[linear_mlp_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=dim,
+ )
+ elif block_config.ffn.moe is not None:
+ router_key = f"{ffn_prefix}.router.weight"
+ out_state_dict[router_key] = _prune_hidden_size_dimension(
+ out_state_dict[router_key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1,
+ )
+ _prune_hidden_size_dimension_mlp(
+ f"{ffn_prefix}.shared_expert",
+ out_state_dict,
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ )
+ for expert_idx in range(block_config.ffn.moe.num_local_experts):
+ _prune_hidden_size_dimension_mlp(
+ f"{ffn_prefix}.experts.{expert_idx}",
+ out_state_dict,
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ )
+ else:
+ _prune_hidden_size_dimension_mlp(
+ ffn_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking
+ )
+ return out_state_dict
+
+
+def _prune_hidden_size_dimension_mlp(
+ name_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking
+):
+ for proj in ["gate_proj", "up_proj", "down_proj"]:
+ for part in ["weight", "bias"]:
+ if proj != "down_proj" and part == "bias":
+ continue
+ key = f"{name_prefix}.{proj}.{part}"
+ if key in out_state_dict:
+ out_state_dict[key] = _prune_hidden_size_dimension(
+ out_state_dict[key],
+ new_hidden_size,
+ hidden_size_init_mode,
+ channel_ranking,
+ dim=1 if part == "weight" and proj != "down_proj" else 0,
+ )
+
+
+def _prune_hidden_size_dimension(
+ original_tensor: torch.Tensor,
+ new_hidden_size: int,
+ hidden_size_init_mode: HiddenSizeInitMode,
+ channel_ranking: Optional[list[int]] = None,
+ dim: int = -1,
+) -> torch.Tensor:
+ """
+ Prune a tensor along the specified dimension to match the new hidden size.
+ """
+ original_size = original_tensor.shape[dim]
+
+ if hidden_size_init_mode == HiddenSizeInitMode.Random:
+ # Initialize with random weights
+ new_shape = list(original_tensor.shape)
+ new_shape[dim] = new_hidden_size
+ return torch.randn(new_shape, dtype=original_tensor.dtype, device=original_tensor.device)
+
+ elif hidden_size_init_mode == HiddenSizeInitMode.Truncate:
+ # Simple truncation - take the first new_hidden_size elements
+ if dim == -1:
+ return original_tensor[..., :new_hidden_size]
+ elif dim == 0:
+ return original_tensor[:new_hidden_size, ...]
+ elif dim == 1:
+ return original_tensor[:, :new_hidden_size, ...]
+ else:
+ # Handle other dimensions
+ slices = [slice(None)] * original_tensor.ndim
+ slices[dim] = slice(new_hidden_size)
+ return original_tensor[tuple(slices)]
+
+ elif hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking:
+ if channel_ranking is None:
+ raise ValueError("Channel ranking must be provided for PruneByChannelRanking mode")
+
+ # Use channel ranking to select the most important channels
+ if len(channel_ranking) < new_hidden_size:
+ raise ValueError(
+ f"Channel ranking has {len(channel_ranking)} channels but need {new_hidden_size}"
+ )
+
+ # Take the top new_hidden_size channels according to ranking
+ selected_channels = channel_ranking[:new_hidden_size]
+
+ if dim == -1:
+ return original_tensor[..., selected_channels]
+ elif dim == 0:
+ return original_tensor[selected_channels, ...]
+ elif dim == 1:
+ return original_tensor[:, selected_channels, ...]
+ else:
+ # Handle other dimensions
+ slices = [slice(None)] * original_tensor.ndim
+ slices[dim] = selected_channels
+ return original_tensor[tuple(slices)]
+
+ else:
+ raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}")
+
+
+def _get_head_dim(config) -> int:
+ """Get head dimension from config in a model-agnostic way.
+
+ Some models like Llama have `head_dim` as a direct attribute, while others
+ like Qwen2 don't. This helper computes it from hidden_size and num_attention_heads.
+ """
+ if hasattr(config, "head_dim") and config.head_dim is not None:
+ return config.head_dim
+ return config.hidden_size // config.num_attention_heads
diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py
new file mode 100644
index 0000000000..783d233c3e
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py
@@ -0,0 +1,195 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Initialize child models from parent models using AnyModel approach with deci_x_patcher."""
+
+import json
+import time
+from pathlib import Path
+from typing import Optional
+
+import torch
+import yaml
+from transformers import AutoModelForCausalLM
+
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+from modelopt.torch.puzzletron.tools.bypassed_training.child_init import (
+ GQAInitMode,
+ HiddenSizeInitMode,
+ LinearInitMode,
+ MlpInitMode,
+ create_child_state_dict,
+ update_model_config,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import _save_checkpoint, load_model_config
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config
+
+
+def init_child_from_parent(
+ descriptor: ModelDescriptor,
+ pruning_mixin,
+ parent_checkpoint_dir: str,
+ model_config_overrides_dict: dict | str,
+ output_checkpoint_dir: str,
+ gqa_init_mode: GQAInitMode,
+ mlp_init_mode: MlpInitMode,
+ mlp_init_config_yaml: Optional[str],
+ linear_init_mode: LinearInitMode,
+ hidden_size_init_mode: Optional[HiddenSizeInitMode] = None,
+ channel_importance_path: Optional[str] = None,
+ max_workers: Optional[int] = None, # Auto-calculate optimal workers if None
+ max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None
+) -> None:
+ """
+ Init child models from parent models in the style of bypass training,
+ but without having to run the entire bypass pipeline.
+
+ Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations.
+
+ I/O Optimization Parameters:
+ - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files))
+ - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers))
+ """
+ assert (
+ gqa_init_mode not in [GQAInitMode.RandomKV, GQAInitMode.RandomBlock]
+ and mlp_init_mode != MlpInitMode.Random
+ and linear_init_mode != LinearInitMode.Random
+ ), (
+ "We do not support random init of any subblock in this script to avoid initializing the student model"
+ )
+
+ descriptor = ModelDescriptorFactory.get(descriptor)
+
+ copy_tokenizer(
+ parent_checkpoint_dir,
+ output_checkpoint_dir,
+ trust_remote_code=descriptor.requires_trust_remote_code(),
+ )
+
+ parent_model_config = load_model_config(
+ parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code()
+ )
+ parent_state_dict = load_state_dict(parent_checkpoint_dir)
+
+ # Parse JSON if string
+ if isinstance(model_config_overrides_dict, str):
+ model_config_overrides_dict = json.loads(model_config_overrides_dict)
+
+ # Separate global config overrides from block-level overrides
+ global_config_overrides = {}
+ block_config_overrides = {}
+
+ for key, value in model_config_overrides_dict.items():
+ if key in ["hidden_size"]:
+ global_config_overrides[key] = value
+ else:
+ block_config_overrides[key] = value
+
+ # Load child model config with global overrides
+ child_model_config = load_model_config(
+ parent_checkpoint_dir,
+ model_config_overrides=global_config_overrides,
+ ignore_unexpected_config_keys=True,
+ trust_remote_code=descriptor.requires_trust_remote_code(),
+ )
+
+ # Apply block-level overrides if any
+ if block_config_overrides:
+ child_model_config = update_model_config(
+ model_config=child_model_config,
+ model_config_overrides=block_config_overrides,
+ )
+
+ with torch.device("meta"):
+ # Pass block_configs explicitly so patcher works for VL models where
+ # decoder layers receive nested config (e.g., text_config) without block_configs
+ with deci_x_patcher(
+ model_descriptor=descriptor, block_configs=child_model_config.block_configs
+ ):
+ model_class = _get_model_class_from_config(child_model_config)
+ # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
+ if model_class is AutoModelForCausalLM:
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ child_model = model_class.from_config(
+ child_model_config, trust_remote_code=trust_remote_code
+ )
+ else:
+ child_model = model_class._from_config(child_model_config)
+
+ child_state_dict_with_meta_tensors = child_model.state_dict()
+
+ mlp_init_config = (
+ yaml.safe_load(mlp_init_config_yaml)
+ if isinstance(mlp_init_config_yaml, str)
+ else mlp_init_config_yaml
+ )
+
+ # Profile create_child_state_dict with automatic layer parallelization
+ mprint("Starting create_child_state_dict...")
+ start_time = time.time()
+ child_state_dict = create_child_state_dict(
+ pruning_mixin=pruning_mixin,
+ descriptor=descriptor,
+ original_state_dict=parent_state_dict,
+ new_state_dict=child_state_dict_with_meta_tensors,
+ original_config=parent_model_config,
+ new_config=child_model_config,
+ gqa_init_mode=gqa_init_mode,
+ mlp_init_mode=mlp_init_mode,
+ mlp_init_config=mlp_init_config,
+ linear_init_mode=linear_init_mode,
+ hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs,
+ channel_importance_path=channel_importance_path,
+ max_layer_workers=max_layer_workers,
+ )
+ create_child_state_dict_time = time.time() - start_time
+ mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds")
+
+ # Profile _save_checkpoint with automatic I/O worker calculation
+ mprint("Starting _save_checkpoint...")
+ actual_io_workers = max_workers if max_workers else "auto"
+ mprint(f"I/O Settings: max_workers={actual_io_workers}")
+ start_time = time.time()
+ _save_checkpoint(
+ child_model_config,
+ child_state_dict,
+ output_checkpoint_dir,
+ descriptor,
+ max_workers=max_workers,
+ )
+ save_checkpoint_time = time.time() - start_time
+ mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds")
+
+ # Print profiling summary with actual worker counts used
+ total_core_time = create_child_state_dict_time + save_checkpoint_time
+ actual_layer_workers = max_layer_workers if max_layer_workers else "auto"
+ actual_io_workers = max_workers if max_workers else "auto"
+ mprint(f"\n=== PROFILING SUMMARY ===")
+ mprint(
+ f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)"
+ )
+ mprint(
+ f"_save_checkpoint: {save_checkpoint_time:.2f}s ({save_checkpoint_time / total_core_time * 100:.1f}%)"
+ )
+ mprint(f"Total core processing: {total_core_time:.2f}s")
+ mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}")
+ mprint(f"=========================\n")
diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py
new file mode 100644
index 0000000000..4488898e33
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""Utilities for loading and initializing PyTorch model checkpoints (AnyModel / HF layouts)."""
+
+import concurrent.futures
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import Literal, TypeVar
+
+import torch
+from safetensors.torch import load_file as safe_load_file
+from torch import nn
+from transformers import AutoTokenizer
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
+
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config
+from modelopt.torch.puzzletron.tools.common import infer_weights_dtype
+
+SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors"
+PTH_SUBBLOCKS_DIR_NAME = "subblocks"
+STATE_DICT_FILE_NAME = "model.pth"
+
+
+def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]:
+ checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir)
+
+ if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists():
+ return torch.load(state_dict_path, map_location="cpu", weights_only=True)
+
+ if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists():
+ return _load_state_dict_from_subblocks(safetensors_subblocks_dir)
+
+ if (pth_subblocks_dir := checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME).exists():
+ return _load_state_dict_from_subblocks(pth_subblocks_dir)
+
+ if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or (
+ checkpoint_dir / SAFE_WEIGHTS_NAME
+ ).exists():
+ from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import (
+ load_sharded_state_dict, # local import to avoid circular import
+ )
+
+ return load_sharded_state_dict(checkpoint_dir)
+
+ raise FileNotFoundError(
+ f"Couldn't find state dict path or subblocks dir inside {checkpoint_dir}"
+ )
+
+
+def _normalize_checkpoint_dir(checkpoint_dir: Path | str) -> Path:
+ checkpoint_dir = Path(checkpoint_dir)
+ if checkpoint_dir.is_file():
+ checkpoint_dir = checkpoint_dir.parent
+ return checkpoint_dir
+
+
+def _load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tensor]:
+ torch_paths = list(subblocks_dir.glob("*.pth"))
+ safetensors_paths = list(subblocks_dir.glob("*.safetensors"))
+
+ if len(torch_paths) != 0:
+ load_fn = partial(torch.load, map_location="cpu", weights_only=True)
+ file_paths = torch_paths
+ elif len(safetensors_paths) != 0:
+ load_fn = safe_load_file
+ file_paths = safetensors_paths
+ else:
+ raise ValueError(f"No tensor files found in {subblocks_dir=}")
+
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ state_dict_shards = list(executor.map(load_fn, file_paths))
+
+ state_dict = {k: v for shard in state_dict_shards for k, v in shard.items()}
+ return state_dict
+
+
+NNModule = TypeVar("NNModule", bound=nn.Module)
+
+
+def init_module_with_state_dict(
+ state_dict: dict[str, torch.Tensor],
+ module_cls: type[NNModule],
+ *init_args,
+ **init_kwargs,
+) -> NNModule:
+ weights_dtype = infer_weights_dtype(state_dict)
+ module = init_empty_module(module_cls, weights_dtype, *init_args, **init_kwargs)
+ module.load_state_dict(state_dict)
+ return module
+
+
+def init_empty_module(
+ module_cls: type[NNModule],
+ dtype: torch.dtype,
+ *init_args,
+ **init_kwargs,
+) -> NNModule:
+ default_dtype = torch.get_default_dtype()
+ current_device = torch.ones(1).device
+ torch.set_default_dtype(dtype)
+ module = skip_init(module_cls, *init_args, device=current_device, **init_kwargs)
+ torch.set_default_dtype(default_dtype)
+ return module
+
+
+def skip_init(module_cls, *args, **kwargs) -> nn.Module:
+ """Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg."""
+ if not issubclass(module_cls, torch.nn.Module):
+ raise RuntimeError(f"Expected a Module; got {module_cls}")
+
+ final_device = kwargs.pop("device", "cpu")
+ with torch.device("meta"):
+ module = module_cls(*args, **kwargs)
+
+ module = module.to_empty(device=final_device)
+ return module
+
+
+def is_valid_decilm_checkpoint(checkpoint_dir: Path | str, trust_remote_code: bool = False) -> bool:
+ """True if the checkpoint config loads and defines ``block_configs`` (AnyModel / puzzletron layout).
+
+ Args:
+ checkpoint_dir: Path to checkpoint directory
+ trust_remote_code: If True, allows execution of custom code from the model repository.
+ This is a security risk if the model source is untrusted. Only set to True if you
+ trust the source of the model. Defaults to False for security.
+
+ Returns:
+ True if the config has ``block_configs``, False otherwise
+ """
+ try:
+ model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code)
+ if not hasattr(model_config, "block_configs") or model_config.block_configs is None:
+ warnings.warn(
+ f"Skipping checkpoint '{checkpoint_dir}' - missing block_configs (not an AnyModel-style layout)"
+ )
+ return False
+ return True
+ except Exception as e:
+ warnings.warn(f"Skipping checkpoint '{checkpoint_dir}' - failed to load config: {e}")
+ return False
+
+
+def copy_tokenizer(
+ source_dir_or_tokenizer_name: Path | str,
+ target_dir: Path | str,
+ on_failure: Literal["raise", "warn"] = "raise",
+ trust_remote_code: bool = False,
+) -> None:
+ """Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available)
+ to avoid collision between transformers versions.
+ """
+ source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt"
+ if source_tokenizer_name_path.exists():
+ source_dir_or_tokenizer_name = source_tokenizer_name_path.read_text().strip()
+
+ tokenizer = None
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ source_dir_or_tokenizer_name, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'"
+ if on_failure == "raise":
+ raise FileNotFoundError(message)
+ else:
+ warnings.warn(message)
+
+ if tokenizer is not None:
+ target_dir = Path(target_dir)
+ target_dir.mkdir(exist_ok=True, parents=True)
+ tokenizer.save_pretrained(target_dir)
+
+ target_tokenizer_name_path = target_dir / "tokenizer_name.txt"
+ is_given_tokenizer_name_as_argument = not Path(source_dir_or_tokenizer_name).exists()
+ if is_given_tokenizer_name_as_argument:
+ target_tokenizer_name_path.write_text(source_dir_or_tokenizer_name)
diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
new file mode 100644
index 0000000000..1c6dcb36d8
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
@@ -0,0 +1,408 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""
+Utilities for loading and saving Hugging Face-format checkpoints (``AutoConfig`` + optional ``block_configs``).
+"""
+
+import concurrent.futures
+import contextlib
+import dataclasses
+import fcntl
+import os
+import time
+import warnings
+from collections import defaultdict
+from collections.abc import Callable, Mapping
+from pathlib import Path
+from typing import Any, BinaryIO
+
+import torch
+import transformers
+from safetensors.torch import save_file as safe_save_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import maybe_cast_block_configs
+from modelopt.torch.puzzletron.tools.common import infer_weights_dtype
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.post_init_sparse import SparsityMethod
+from modelopt.torch.puzzletron.tools.robust_json import json_dumps
+
+SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors"
+PTH_SUBBLOCKS_DIR_NAME = "subblocks"
+RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME)
+
+
+# TODO: (esegal) Should ask the model for something like this
+NON_LAYER_MODULE_TO_FILE_TYPE = {
+ "model.embed_tokens": "embeddings",
+ "model.norm": "lm_head",
+ "lm_head": "lm_head",
+}
+MODULE_WITHIN_LAYER_TO_FILE_TYPE = {
+ "input_layernorm": "attention",
+ "self_attn": "attention",
+ "post_attention_layernorm": "ffn",
+ "mlp": "ffn",
+ "parallel_blocks": "multi_block",
+}
+LAYERS_MODULE_NAME = "model.layers"
+
+
+def force_cache_dynamic_modules(
+ config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False
+):
+ has_remote_code = (
+ hasattr(config, "auto_map")
+ and isinstance(config.auto_map, dict)
+ and "AutoConfig" in config.auto_map.keys()
+ )
+ if has_remote_code and trust_remote_code:
+ for class_reference in config.auto_map.values():
+ _ = get_class_from_dynamic_module(class_reference, checkpoint_dir)
+
+
+def load_model_config(
+ checkpoint_dir: Path | str,
+ model_config_overrides: Mapping | None = None,
+ ignore_unexpected_config_keys: bool = False,
+ trust_remote_code: bool = False,
+):
+ """Load model configuration from a checkpoint directory.
+
+ Args:
+ checkpoint_dir: Path to the checkpoint directory (e.g. containing config.json).
+ model_config_overrides: Optional mapping of config overrides.
+ ignore_unexpected_config_keys: If True, ignore unexpected config keys.
+ trust_remote_code: If True, allows execution of custom code from the model repository.
+ This is a security risk if the model source is untrusted. Only set to True if you
+ trust the source of the model. Defaults to False for security.
+
+ Returns:
+ Loaded model configuration (PretrainedConfig).
+ """
+ if not isinstance(checkpoint_dir, Path):
+ checkpoint_dir = Path(checkpoint_dir)
+
+ if model_config_overrides is None:
+ model_config_overrides = {}
+
+ config, unused_kwargs = AutoConfig.from_pretrained(
+ checkpoint_dir,
+ trust_remote_code=trust_remote_code,
+ return_unused_kwargs=True,
+ **model_config_overrides,
+ )
+ if hasattr(config, "block_configs"):
+ config.block_configs = maybe_cast_block_configs(config.block_configs)
+
+ force_cache_dynamic_modules(config, checkpoint_dir, trust_remote_code=trust_remote_code)
+
+ if not ignore_unexpected_config_keys:
+ if unused_kwargs:
+ raise ValueError(f"Unexpected config keys: {unused_kwargs.keys()}")
+
+ return config
+
+
+def _get_model_class_from_config(config: PretrainedConfig) -> type:
+ """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf)."""
+ if hasattr(config, "architectures") and config.architectures:
+ model_class_name = config.architectures[0]
+ if hasattr(transformers, model_class_name):
+ return getattr(transformers, model_class_name)
+ mprint(
+ f"Warning: {model_class_name} not found in transformers, "
+ "falling back to AutoModelForCausalLM"
+ )
+ return AutoModelForCausalLM
+
+
+def init_model_from_config(
+ config: PretrainedConfig,
+ *,
+ trust_remote_code: bool = False,
+ **kwargs,
+) -> PreTrainedModel:
+ """Build a model from config on meta/uninitialized weights (used e.g. for subblock param counts).
+
+ ``trust_remote_code`` defaults to False (only ``AutoModelForCausalLM.from_config`` uses it).
+ Pass True when loading configs that rely on custom modeling code from the checkpoint.
+ """
+ model_class = _get_model_class_from_config(config)
+ if model_class is AutoModelForCausalLM:
+ return model_class.from_config(config, trust_remote_code=trust_remote_code, **kwargs)
+ # Concrete model classes (e.g. GptOssForCausalLM): _from_config forwards kwargs to __init__,
+ # which does not accept trust_remote_code (only AutoModel uses it when loading custom code).
+ return model_class._from_config(config, **kwargs)
+
+
+def save_checkpoint(
+ model: PreTrainedModel,
+ checkpoint_dir: Path | str,
+ descriptor: "ModelDescriptor",
+) -> None:
+ _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor)
+
+
+def _save_checkpoint(
+ model_config: PretrainedConfig,
+ state_dict: dict[str, torch.Tensor],
+ checkpoint_dir: Path | str,
+ descriptor: "ModelDescriptor",
+ max_workers: int | None = None, # Now optional - will auto-calculate if None
+) -> None:
+ from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+
+ if not isinstance(checkpoint_dir, Path):
+ checkpoint_dir = Path(checkpoint_dir)
+
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+ # Phase 1: Save config
+ save_model_config(model_config, checkpoint_dir)
+
+ # Phase 2: Build weight map using descriptor and write index
+ subblock_keys = descriptor.get_weight_groups(
+ layer_names=state_dict.keys(),
+ num_hidden_layers=model_config.num_hidden_layers,
+ )
+
+ weight_map = {}
+ for subblock, layer_keys in subblock_keys.items():
+ weight_map_entries = {
+ key: f"subblocks_safetensors/{subblock}.safetensors" for key in layer_keys
+ }
+ weight_map.update(weight_map_entries)
+
+ # Handle tie_word_embeddings - remove from state_dict and weight_map BEFORE writing index
+ output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight"
+ if getattr(model_config, "tie_word_embeddings", False) and output_emb_weight_name in state_dict:
+ state_dict = {k: v for k, v in state_dict.items() if k != output_emb_weight_name}
+ weight_map = {k: v for k, v in weight_map.items() if k != output_emb_weight_name}
+
+ # Write index (now without tied embedding)
+ index = {"metadata": {"format": "pt"}, "weight_map": weight_map}
+ index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME
+ index_json = json_dumps(index)
+ _write_file_process_safe(index_json, index_path)
+
+ # Phase 3: Save subblocks
+ save_subblocks(
+ state_dict,
+ checkpoint_dir,
+ weight_map=weight_map,
+ multi_threaded=True,
+ max_workers=max_workers,
+ )
+
+
+def save_subblocks(
+ state_dict: dict[str, torch.Tensor],
+ checkpoint_dir: Path | str,
+ weight_map: dict[str, str] | None = None,
+ multi_threaded: bool = True,
+ max_workers: int | None = None, # Now optional - will auto-calculate if None
+) -> None:
+ mprint("=== Starting save_subblocks detailed profiling ===")
+ subblocks_start_time = time.time()
+
+ if not isinstance(checkpoint_dir, Path):
+ checkpoint_dir = Path(checkpoint_dir)
+
+ # Step 1: Build weight map (use provided or build from state_dict)
+ weight_map_start_time = time.time()
+ if weight_map is None:
+ weight_map = _build_safetensors_weight_map(
+ state_dict=state_dict,
+ non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE,
+ module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE,
+ layers_module_name=LAYERS_MODULE_NAME,
+ )
+ weight_name_to_filename = {k: checkpoint_dir / v for k, v in weight_map.items()}
+ weight_map_time = time.time() - weight_map_start_time
+ mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)")
+
+ # Step 2: Create subblocks directory
+ dir_create_start_time = time.time()
+ subblocks_path = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME
+ subblocks_path.mkdir(parents=True, exist_ok=True)
+ dir_create_time = time.time() - dir_create_start_time
+ mprint(f" Step 2 - Create directory: {dir_create_time:.2f}s")
+
+ # Step 3: Organize tensors by file
+ organize_start_time = time.time()
+ filename_to_partial_state_dict = defaultdict(dict)
+ total_tensor_size = 0
+ for weight_name, weight in state_dict.items():
+ if weight_name in weight_map:
+ # Ensure tensor is contiguous and on CPU for faster I/O
+ tensor = (
+ weight.contiguous().cpu() if weight.device.type != "cpu" else weight.contiguous()
+ )
+ filename_to_partial_state_dict[weight_name_to_filename[weight_name]][weight_name] = (
+ tensor
+ )
+ total_tensor_size += weight.numel() * weight.element_size()
+ organize_time = time.time() - organize_start_time
+ mprint(
+ f" Step 3 - Organize tensors: {organize_time:.2f}s ({total_tensor_size / (1024**3):.2f}GB total)"
+ )
+
+ # Step 4: Prepare save arguments and auto-calculate optimal I/O workers
+ prepare_start_time = time.time()
+ safe_save_kwargs = [
+ {"tensors": partial_state_dict, "filename": filename, "metadata": {"format": "pt"}}
+ for filename, partial_state_dict in filename_to_partial_state_dict.items()
+ ]
+
+ # Auto-calculate optimal I/O workers: min(cpu_count, num_files)
+ if max_workers is None:
+ cpu_count = os.cpu_count() or 1
+ num_files = len(safe_save_kwargs)
+ max_workers = min(cpu_count, num_files)
+ mprint(
+ f" Auto-calculated I/O workers: min({cpu_count} CPUs, {num_files} files) = {max_workers}"
+ )
+ else:
+ mprint(f" Using specified I/O workers: {max_workers}")
+
+ prepare_time = time.time() - prepare_start_time
+ mprint(f" Step 4 - Prepare save args: {prepare_time:.2f}s ({len(safe_save_kwargs)} files)")
+
+ # Step 5: Save files with optimal worker count
+ save_start_time = time.time()
+ if multi_threaded:
+ mprint(f" Using multi-threaded saving with {max_workers} workers...")
+
+ def optimized_safe_save(kwargs):
+ try:
+ safe_save_file(**kwargs)
+ return True
+ except Exception as e:
+ mprint(f" Error saving {kwargs['filename']}: {e}")
+ return False
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ results = list(executor.map(optimized_safe_save, safe_save_kwargs))
+
+ # Check for any failures
+ failed_saves = sum(1 for r in results if not r)
+ if failed_saves > 0:
+ mprint(f" Warning: {failed_saves} files failed to save")
+ else:
+ mprint(" Using single-threaded saving...")
+ for kwargs in safe_save_kwargs:
+ safe_save_file(**kwargs)
+
+ save_time = time.time() - save_start_time
+ mprint(f" Step 5 - Save files: {save_time:.2f}s ({max_workers} workers)")
+
+ subblocks_total_time = time.time() - subblocks_start_time
+ mprint(f"=== save_subblocks completed in {subblocks_total_time:.2f}s ===")
+ mprint(
+ f" Breakdown: WeightMap {weight_map_time:.1f}s + DirCreate {dir_create_time:.1f}s + "
+ f"Organize {organize_time:.1f}s + Prepare {prepare_time:.1f}s + Save {save_time:.1f}s"
+ )
+
+ # Calculate effective I/O speed
+ io_speed_gbps = (total_tensor_size / (1024**3)) / save_time if save_time > 0 else 0
+ mprint(f" Effective I/O speed: {io_speed_gbps:.2f} GB/s ({max_workers} workers)")
+ mprint(f" Save operation was {save_time / subblocks_total_time * 100:.1f}% of total time")
+
+
+def _write_text(content: str, f: BinaryIO) -> None:
+ f.write(content.encode("utf-8"))
+
+
+def _write_file_process_safe(
+ content: Any,
+ path: Path | str,
+ write_fn: Callable[[Any, BinaryIO], None] = _write_text,
+) -> None:
+ """
+ Write a file in a multi-process safe way.
+ If another process tries to write the same file using this method, the current process
+ "gives up" and assumes that the matter is being taken care of by another process.
+
+ write_fn is a function that receives file contents and a binary file object,
+ and writes the content to the file. It can be _write_text (defined above), or torch.save,
+ or a similar function (not safetensors.torch.save_file since it expects a path).
+ """
+ with open(path, "wb") as f:
+ # Try to acquire an exclusive, non-blocking lock
+ try:
+ fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
+ except BlockingIOError:
+ return # Exit immediately if the lock is not acquired
+
+ write_fn(content, f) # Write the content if lock is acquired
+ f.flush() # Ensure data is written to disk
+
+ # Release the lock
+ fcntl.flock(f, fcntl.LOCK_UN)
+
+
+def _build_safetensors_weight_map(
+ *,
+ state_dict: dict[str, torch.Tensor],
+ non_layer_module_to_file_type: dict[str, str],
+ module_within_layer_to_file_type: dict[str, str],
+ layers_module_name: str,
+) -> dict[str, Path]:
+ weight_map = {}
+ unmapped_weight_names = []
+ for weight_name in state_dict:
+ found_match = False
+ for module_name, file_type in non_layer_module_to_file_type.items():
+ if weight_name.startswith(f"{module_name}."):
+ weight_map[weight_name] = str(RELATIVE_SUBBLOCKS_DIR / f"{file_type}.safetensors")
+ found_match = True
+ if not found_match:
+ if weight_name.startswith(f"{layers_module_name}."):
+ name_parts = weight_name[len(layers_module_name) + 1 :].split(".")
+ layer_index = name_parts[0]
+ name_within_layer = ".".join(name_parts[1:])
+
+ for module_name, file_type in module_within_layer_to_file_type.items():
+ if name_within_layer.startswith(f"{module_name}."):
+ weight_map[weight_name] = str(
+ RELATIVE_SUBBLOCKS_DIR / f"block_{layer_index}_{file_type}.safetensors"
+ )
+ found_match = True
+
+ if not found_match:
+ unmapped_weight_names.append(weight_name)
+
+ if len(unmapped_weight_names) > 0:
+ raise ValueError(
+ f"Unmapped weight names: {unmapped_weight_names}\n"
+ f"Add them to the `non_layer_module_to_file_type` or "
+ f"`module_within_layer_to_file_type` dictionaries."
+ )
+
+ return weight_map
+
+
+def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None:
+ if hasattr(model_config, "block_configs"):
+ model_config.block_configs = [
+ dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf
+ for conf in model_config.block_configs
+ ]
+ model_config.save_pretrained(checkpoint_dir)
diff --git a/modelopt/torch/puzzletron/tools/common.py b/modelopt/torch/puzzletron/tools/common.py
new file mode 100644
index 0000000000..96db572802
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/common.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def infer_weights_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype:
+ weights_dtype = [p.dtype for p in state_dict.values() if torch.is_floating_point(p)]
+ weights_dtype = weights_dtype[0] if len(weights_dtype) > 0 else torch.get_default_dtype()
+ return weights_dtype
diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py
new file mode 100644
index 0000000000..64c4035656
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/hydra_utils.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utilities for hydra config initialization.
+"""
+
+import datetime
+import random
+from pathlib import Path
+
+from hydra import compose, initialize, initialize_config_dir
+from hydra.utils import get_object
+from omegaconf import DictConfig, OmegaConf
+
+
+def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int:
+ """
+ Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage.
+ Used as a resolver in hydra configs.
+ """
+ steps = (int(tokens) // int(block)) // int(mbs)
+ w = pct * steps
+ return max(1, round(w))
+
+
+def register_hydra_resolvers():
+ OmegaConf.register_new_resolver("to_path", lambda x: Path(x))
+ OmegaConf.register_new_resolver(
+ "random_int", lambda low, high: random.randint(int(low), int(high))
+ )
+ OmegaConf.register_new_resolver(
+ "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None
+ )
+ OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p))
+ OmegaConf.register_new_resolver("get_object", lambda x: get_object(x))
+
+
+def initialize_hydra_config_for_dir(
+ config_dir: str, config_name: str, overrides: list[str]
+) -> DictConfig:
+ """Initialize a hydra config from an absolute path for a config directory
+
+ Args:
+ config_dir (str):
+ config_name (str):
+ overrides (List[str]):
+
+ Returns:
+ DictConfig:
+ """
+
+ with initialize_config_dir(version_base=None, config_dir=config_dir):
+ args = compose(config_name, overrides)
+ args._set_flag("allow_objects", True)
+ OmegaConf.resolve(args) # resolve object attributes
+ OmegaConf.set_struct(args, False)
+
+ return args
+
+
+def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig:
+ with initialize(version_base=None, config_path=config_path):
+ args = compose(config_name, overrides)
+ args._set_flag("allow_objects", True)
+ OmegaConf.resolve(args) # resolve object attributes
+ OmegaConf.set_struct(args, False)
+
+ return args
diff --git a/modelopt/torch/puzzletron/tools/kd_model.py b/modelopt/torch/puzzletron/tools/kd_model.py
new file mode 100644
index 0000000000..8590c3f56c
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/kd_model.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Knowledge distillation loss functions.
+
+Provides normalized_mse_loss and cosine_embedding_loss_batched for comparing
+model outputs. Used by validation.py.
+"""
+# mypy: ignore-errors
+
+from abc import ABCMeta, abstractmethod
+from typing import Callable, List, Literal, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+def normalized_mse_loss(
+ input: Tensor,
+ target: Tensor,
+ reduction: Literal["none", "mean", "sum"] = "mean",
+ epsilon: float = 1e-6,
+) -> Tensor:
+ loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
+ target, torch.zeros_like(target) + epsilon, reduction=reduction
+ )
+ return loss
+
+
+def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor:
+ # inputs are of shape (B,T,H)
+ batch_size = input.size(0)
+ input = input.view(batch_size, -1)
+ target = target.view(batch_size, -1)
+ target_tensor = input.new(input.size(0)).fill_(1)
+ loss = F.cosine_embedding_loss(
+ input1=input, input2=target, target=target_tensor, reduction="none"
+ )
+ return loss
diff --git a/modelopt/torch/puzzletron/tools/logger.py b/modelopt/torch/puzzletron/tools/logger.py
new file mode 100644
index 0000000000..257e55abe3
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/logger.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+import inspect
+import logging
+import os
+import sys
+
+import torch.distributed.launch # noqa: F401
+
+logging.getLogger("fsspec.local").setLevel(logging.ERROR)
+logging.getLogger("websockets.client").setLevel(logging.WARN)
+logging.getLogger("websockets.server").setLevel(logging.WARN)
+logging.getLogger("websockets.server:connection").setLevel(logging.WARN)
+
+
+class LogColors:
+ BLUE = "\033[94m"
+ CYAN = "\033[96m"
+ GREEN = "\033[92m"
+ YELLOW = "\033[93m"
+ RED = "\033[91m"
+
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+ RESET = "\033[0m"
+
+
+class DistributedLogger(logging.Logger):
+ verbosity = logging.ERROR
+
+ def __init__(self, name, level=logging.DEBUG):
+ super().__init__(name, level)
+ self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ self.global_rank = int(os.environ.get("RANK", 0))
+ self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+ def dist_log(self, msg: str, ranks: str = "main"):
+ """Log parameter msg with the given ranks.
+
+ Args:
+ msg: The message to log.
+ ranks: The ranks to log the message to. Choices are:
+ "all": log with all ranks
+ "main": log with only rank 0 in node 0
+ "last": log with only rank -1 in node 0
+ "local_main": log with only rank 0 in all nodes
+ """
+ # print(msg, ranks)
+ if ranks not in ["all", "main", "local_main", "last"]:
+ raise NotImplementedError(
+ f"Could not broadcast msg {msg} - "
+ f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
+ )
+ # All ranks to print
+ if ranks == "all":
+ pass
+
+ # Only main rank at node 0 to print
+ elif (
+ (ranks == "main" and self.global_rank != 0)
+ or (ranks == "last" and self.global_rank != self.world_size - 1)
+ or (ranks == "local_main" and self.local_rank != 0)
+ ):
+ return
+
+ message_source = self.get_caller_location()
+
+ self.info(
+ f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}"
+ )
+
+ # def dist_warning(self, msg):
+ # if self.verbosity <= logging.WARNING:
+ # self.warning(f"[rank-{self.global_rank}] " + msg)
+
+ @staticmethod
+ def get_caller_location() -> str:
+ # Get the caller's stack frame
+ frame = inspect.currentframe()
+
+ # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source
+ caller_frame = frame.f_back.f_back.f_back
+
+ # Get the filename and line number from the caller's stack frame
+ filename = os.path.basename(caller_frame.f_code.co_filename)
+ lineno = caller_frame.f_lineno
+ return f"{filename}:{lineno}"
+
+
+# Initialize logger
+logging.setLoggerClass(DistributedLogger)
+logger = logging.getLogger(__name__)
+logger.propagate = False
+
+formatter = logging.Formatter("[%(asctime)s]%(message)s")
+handler = logging.StreamHandler(sys.stdout)
+handler.setFormatter(formatter)
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+# Manually edit torch logger
+torch_logger = logging.getLogger("torch")
+torch_logger.handlers = logger.handlers
+torch_logger.propagate = False
+
+# Manually edit deepspeed logger
+
+# Show some love to Mac & Windows users who can't easily install deepspeed ;)
+# This is allowing running tests on Mac & Windows and train in non-DDP
+try:
+ from deepspeed.utils import logger as deepspeed_logger
+
+ deepspeed_logger.handlers = logger.handlers
+ deepspeed_logger.propagate = False
+except ImportError:
+ # If deepspeed is not installed - no op
+ pass
+
+# Define a custom function to redirect warnings to logger
+# def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
+# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})')
+
+
+# Use the custom warning handler
+# warnings.showwarning = custom_warning_handler
+
+logger: DistributedLogger
+
+
+def aprint(msg: str | None):
+ """
+ All ranks from all nodes prints
+ """
+ return logger.dist_log(msg=msg, ranks="all")
+
+
+def lmprint(msg: str | None):
+ """
+ All local main ranks prints (rank 0 in each node)
+ """
+ return logger.dist_log(msg=msg, ranks="local_main")
+
+
+def mprint(msg: str | None):
+ """
+ Master prints only (rank 0 in node 0)
+ """
+ return logger.dist_log(msg=msg, ranks="main")
+
+
+def lprint(msg: str | None):
+ """
+ Last rank prints only (rank -1 in node 0)
+ """
+ return logger.dist_log(msg=msg, ranks="last")
diff --git a/modelopt/torch/puzzletron/tools/post_init_sparse.py b/modelopt/torch/puzzletron/tools/post_init_sparse.py
new file mode 100644
index 0000000000..eb20250e68
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/post_init_sparse.py
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+import torch
+from torch import nn
+from torch.nn.utils.prune import custom_from_mask
+
+"""
+Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes)
+into a standard format with sparsified weights.
+"""
+
+
+class SparsityMethod:
+ def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Gets a model state_dict, returns a state_dict-like mask_dict with masks"""
+
+ @staticmethod
+ def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False):
+ sparsity_masks = {}
+ for name in list(state_dict.keys()):
+ original_name = name.replace("_orig", "")
+ mask_name = original_name + "_mask"
+ if name[-4:] == "orig" and mask_name in state_dict:
+ val = state_dict[name]
+ mask = state_dict[name[:-4] + "mask"]
+ val[mask == 0] = 0
+ sparsity = (val == 0).sum() / mask.numel()
+ sparsity_masks[original_name[:-7]] = mask
+ if verbose:
+ print(f"fix_state_dict_inplace: {name} {sparsity=}")
+ del state_dict[mask_name]
+ del state_dict[name]
+ state_dict[original_name] = val
+ if change_dtype:
+ for name in state_dict:
+ state_dict[name] = state_dict[name].to(torch.bfloat16)
+ return state_dict, sparsity_masks
+
+ def filter_function(self):
+ pass
+
+ def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None:
+ for name, module in model.named_modules():
+ if name in mask_dict:
+ custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device))
+ print(name)
+ print(torch.sum(mask_dict[name]) / mask_dict[name].numel())
+
+ def do_sparsity(self, model: nn.Module, mask_dict=None):
+ full_name_layers = []
+ for block_idx, block_config in enumerate(model.config.block_configs):
+ ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx]
+ att_name = block_config.attention.sparsify
+ block = model.model.layers[block_idx]
+ if hasattr(block, "mlp"):
+ for name, m in block.mlp.named_modules():
+ if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names):
+ full_name_layers.append(
+ "model.layers." + str(block_idx) + "." + "mlp." + name
+ )
+ if hasattr(block, "self_attn"):
+ for name, m in block.self_attn.named_modules():
+ if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name):
+ full_name_layers.append(
+ "model.layers." + str(block_idx) + "." + "self_attn." + name
+ )
+
+ if mask_dict is None:
+ state_dict_for_sparsifying = {
+ k.rstrip(".weight"): v
+ for k, v in model.state_dict().items()
+ if k.rstrip(".weight") in full_name_layers
+ }
+ mask_dict = self.calculate_masks(state_dict_for_sparsifying)
+ # print('Apply sparsity')
+ # print(full_name_layers)
+ # print(model.state_dict().keys())
+ # print(list(mask_dict.keys()))
+
+ self.apply_masks(model, mask_dict)
+
+
+class SparsityMethod2o4(SparsityMethod):
+ def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """Gets a model state_dict, returns a state_dict-like mask_dict with masks"""
+ mask_dict = {}
+ for key, val in state_dict.items():
+ orig_size = val.shape
+ scores = val.flatten() ** 2
+ mask = self.create_mask(scores)
+ mask = mask.reshape(orig_size)
+ mask_dict[key] = mask
+ return mask_dict
+
+ def create_mask(self, score, value=0):
+ score = score # .cpu()
+ orig_size = score.shape
+ score = score.view(-1, 4)
+ mask = torch.zeros(score.shape)
+ values, indices = torch.topk(score, 2, dim=1)
+ rows = torch.arange(mask.size(0)).unsqueeze(-1)
+ mask[rows, indices] = 1
+ mask = mask.view(orig_size)
+ return mask # dev = score.device, return mask.to(dev)
+
+ @staticmethod
+ def filter_function(name, modules_to_sparsify_in_block):
+ if modules_to_sparsify_in_block is None:
+ return False
+ return name in modules_to_sparsify_in_block
diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py
new file mode 100644
index 0000000000..3397de6393
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/robust_json.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""
+Provides a robust JSON encoder that can handle various types of objects,
+including dataclasses, paths, enums, namespaces, and functions.
+"""
+
+import argparse
+import dataclasses
+import datetime
+import inspect
+import json
+from enum import Enum
+from pathlib import Path
+from typing import Any
+
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+
+class RobustJSONEncoder(json.JSONEncoder):
+ def default(self, o):
+ if dataclasses.is_dataclass(o):
+ return dataclasses.asdict(o)
+ if isinstance(o, Path):
+ return str(o)
+ if isinstance(o, Enum):
+ return o.name
+ if isinstance(o, argparse.Namespace):
+ return vars(o)
+ if type(o).__name__ == "dtype":
+ return str(o)
+ if isinstance(o, (DictConfig, ListConfig)):
+ return OmegaConf.to_container(o, resolve=True)
+ if inspect.isfunction(o) or inspect.ismethod(o):
+ if o.__module__ == "__main__":
+ # User-defined function in main — fallback to just the name
+ return o.__name__
+ return f"{o.__module__}.{o.__qualname__}"
+ if inspect.isclass(o):
+ return f"{o.__module__}.{o.__qualname__}"
+ if isinstance(o, datetime.timedelta):
+ return str(o)
+ # Fallback for arbitrary objects: return their class path
+ if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"):
+ return f"{o.__class__.__module__}.{o.__class__.__qualname__}"
+ return super().default(o)
+
+
+def json_dumps(obj: Any) -> str:
+ return json.dumps(obj, cls=RobustJSONEncoder, indent=2)
+
+
+def json_dump(obj: Any, path: Path | str) -> None:
+ path = Path(path)
+ path.parent.mkdir(exist_ok=True, parents=True)
+ json_text = json_dumps(obj)
+ path.write_text(json_text)
+
+
+def json_load(path: Path | str) -> dict:
+ path = Path(path)
+ text = path.read_text()
+ return json.loads(text)
diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
new file mode 100644
index 0000000000..c18867a576
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
@@ -0,0 +1,404 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+
+"""
+Provides utilities for distributed loading, saving, and manipulation of
+large language model checkpoints across multiple GPUs/processes.
+
+Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.
+"""
+
+import json
+from collections.abc import Iterable, Mapping
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Literal, Type, cast
+
+import numpy as np
+import torch
+import torch.distributed
+import torch.nn as nn
+import transformers
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file
+from safetensors.torch import save_file as safe_save_file
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
+from transformers.utils.hub import cached_file, get_checkpoint_shard_files
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.utils.dummy_modules import (
+ DummyBlock,
+ DummyLMHead,
+ DummyModule,
+ DummyWTE,
+)
+from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice
+
+
+def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None:
+ """Set a submodule on a model by dotted path."""
+ parts = module_name.split(".")
+ parent_path = ".".join(parts[:-1])
+ attr = parts[-1]
+ parent_module = model.get_submodule(parent_path) if parent_path else model
+ setattr(parent_module, attr, new_submodule)
+
+
+def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime):
+ # Get language model config (handles nested configs like Qwen3-VL's text_config)
+ lm_config = descriptor.get_language_model_config(model.config)
+ all_block_indexes = set(range(lm_config.num_hidden_layers))
+ has_first_block = 0 in owned_block_indexes
+ has_last_block = max(all_block_indexes) in owned_block_indexes
+
+ unowned_block_indexes = all_block_indexes - owned_block_indexes
+ for block_index in unowned_block_indexes:
+ decoder_layer_name = descriptor.layer_block_name(block_index)
+ decoder_layer = model.get_submodule(decoder_layer_name)
+ set_submodule(
+ model,
+ decoder_layer_name,
+ descriptor.create_dummy_block(decoder_layer, block_index=block_index),
+ )
+
+ # If we have the last block with tied embeddings, keep embed_tokens so lm_head works.
+ # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file,
+ # and since they're tied, lm_head.weight gets populated too.
+ if not has_first_block and not (has_last_block and model.config.tie_word_embeddings):
+ set_submodule(
+ model,
+ descriptor.input_embedding_name(),
+ DummyWTE(lm_config.hidden_size, dtype=runtime.dtype),
+ )
+
+ if not has_last_block:
+ set_submodule(model, descriptor.final_norm_name(), nn.Identity())
+ if not (model.config.tie_word_embeddings and has_first_block):
+ set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(lm_config))
+
+ return model
+
+
+def _get_model_class_from_config(config: PretrainedConfig):
+ """
+ Get the model class from config.architectures field.
+ Works for any model registered in transformers (CausalLM, VL models, etc.).
+ Falls back to AutoModelForCausalLM if architectures is not available.
+ """
+ if hasattr(config, "architectures") and config.architectures:
+ model_class_name = config.architectures[0]
+ if hasattr(transformers, model_class_name):
+ return getattr(transformers, model_class_name)
+ mprint(
+ f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM"
+ )
+ return AutoModelForCausalLM
+
+
+def load_and_shard_model(
+ descriptor,
+ checkpoint_path: str | Path,
+ owned_block_indexes: set[int] | Literal["auto"] = "auto",
+ model_config: PretrainedConfig | None = None,
+):
+ checkpoint_path = Path(checkpoint_path)
+ runtime = SimpleNamespace(
+ device=torch.device(dist.local_rank()),
+ dtype=torch.bfloat16,
+ global_rank=dist.rank(),
+ world_size=dist.size(),
+ is_main_process=dist.is_master(),
+ is_last_process=dist.is_last_process(),
+ use_autocast=True, # Default: use autocast; descriptor can override
+ )
+
+ with runtime.device:
+ if model_config is None:
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ model_config = load_model_config(checkpoint_path, trust_remote_code=trust_remote_code)
+
+ num_hidden_layers = descriptor.get_language_model_config(model_config).num_hidden_layers
+ if owned_block_indexes == "auto":
+ owned_block_indexes = set(
+ np.array_split(np.arange(num_hidden_layers), runtime.world_size)[
+ runtime.global_rank
+ ]
+ )
+
+ mprint("Initializing model shards")
+ # Pass block_configs explicitly so patcher works for VL models where
+ # decoder layers receive nested config (e.g., text_config) without block_configs
+ from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
+ with deci_x_patcher(
+ model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None)
+ ):
+ model_shard = create_sharded_model(
+ runtime=runtime,
+ descriptor=descriptor,
+ model_config=model_config,
+ owned_block_indexes=owned_block_indexes,
+ )
+
+ if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or (
+ checkpoint_path / SAFE_WEIGHTS_INDEX_NAME
+ ).exists():
+ mprint("Loading shard state_dict from safetensors")
+ shard_keys = [
+ *[name for name, _ in model_shard.named_parameters()],
+ *[name for name, _ in model_shard.named_buffers()],
+ ]
+ shard_state_dict = load_sharded_state_dict(
+ model_name_or_path=str(checkpoint_path),
+ keys_to_load=shard_keys,
+ device=runtime.device,
+ )
+
+ new_names = set(shard_state_dict.keys())
+ mprint(f"{new_names=}")
+ # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B)
+ model_shard.load_state_dict(shard_state_dict, strict=False, assign=True)
+
+ del shard_state_dict
+
+ # Re-tie weights after load_state_dict with assign=True, which severs the tie.
+ # Needed on first rank (owns embed_tokens) and last rank (owns lm_head).
+ has_first_block = 0 in owned_block_indexes
+ has_last_block = (num_hidden_layers - 1) in owned_block_indexes
+ if model_config.tie_word_embeddings and (has_first_block or has_last_block):
+ model_shard.tie_weights()
+
+ # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_()
+ # just to load the weight and tie it to lm_head. Now replace it with a dummy so it
+ # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens).
+ if model_config.tie_word_embeddings and has_last_block and not has_first_block:
+ set_submodule(
+ model_shard,
+ descriptor.input_embedding_name(),
+ DummyWTE(model_config.hidden_size, dtype=runtime.dtype),
+ )
+ else:
+ mprint("Loading state_dict in main process")
+ state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None
+
+ mprint("Distributing model to shards")
+ load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict)
+ del state_dict
+
+ descriptor.init_rotary_embedding(model_shard, runtime)
+
+ model_shard.type(runtime.dtype)
+
+ # Configure autocast based on model descriptor (some models like Qwen3-VL MoE
+ # have dtype bugs under autocast)
+ runtime.use_autocast = descriptor.uses_autocast()
+
+ params_on_meta_device = [
+ param_name
+ for param_name, param in model_shard.named_parameters()
+ if param.device == torch.device("meta")
+ ]
+ assert len(params_on_meta_device) == 0, (
+ f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}"
+ )
+
+ return model_shard
+
+
+def create_sharded_model(
+ runtime,
+ descriptor,
+ model_config: PretrainedConfig,
+ owned_block_indexes: set[int],
+ device: str | torch.device | None = "meta",
+ dtype: torch.dtype | None = torch.float32,
+):
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ dist.barrier()
+
+ with EmptyInitOnDevice(device="meta", dtype=dtype):
+ # Get model class from config.architectures (works for CausalLM, VL models, etc.)
+ model_class = _get_model_class_from_config(model_config)
+ # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
+ if model_class is AutoModelForCausalLM:
+ trust_remote_code = descriptor.requires_trust_remote_code()
+ model = model_class.from_config(model_config, trust_remote_code=trust_remote_code)
+ else:
+ model = model_class._from_config(model_config)
+ create_local_shard_(
+ model=model,
+ owned_block_indexes=owned_block_indexes,
+ descriptor=descriptor,
+ runtime=runtime,
+ )
+
+ if device != torch.device("meta"):
+ local_shard_state_dict = {
+ k: torch.empty_like(v, device=device) for k, v in model.state_dict().items()
+ }
+ model.load_state_dict(local_shard_state_dict, assign=True)
+
+ return model
+
+
+def load_state_dict_to_shards(
+ model_shard: torch.nn.Module, loaded_state_dict: dict | None = None
+) -> None:
+ from modelopt.torch.puzzletron.sewing_kit.utils import (
+ distributed_isend_obj,
+ distributed_recv_obj,
+ )
+
+ model_shard.to("meta")
+ local_state_dict_keys = list(model_shard.state_dict().keys())
+
+ if dist.is_master():
+ gathered_state_dict_keys = [None] * dist.size()
+ torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys)
+
+ assert loaded_state_dict is not None
+ loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()}
+
+ works: list[torch.distributed.Work] = []
+ for i, shard_keys in enumerate(gathered_state_dict_keys[1:]):
+ process_id = i + 1
+ shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in shard_keys}
+ process_works = distributed_isend_obj(shard_state_dict, process_id)
+ works.extend(process_works)
+
+ for work in works:
+ work.wait()
+
+ shard_state_dict = {
+ k: v for k, v in loaded_state_dict.items() if k in local_state_dict_keys
+ }
+ else:
+ torch.distributed.gather_object(local_state_dict_keys)
+ shard_state_dict = distributed_recv_obj()
+
+ print(f"{dist.rank()} loaded state_dict shard")
+
+ missing_keys, unexpected_keys = model_shard.load_state_dict(
+ shard_state_dict, strict=False, assign=True
+ )
+ assert len(unexpected_keys) == 0
+ assert all("dummy_param" in key for key in missing_keys)
+
+ model_shard.cuda(dist.local_rank())
+
+ dist.barrier()
+
+
+def save_sharded_model(
+ model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path
+):
+ """
+ out_path is usually output_checkpoint_path / "model.safetensors"
+ """
+ dist.barrier()
+
+ if isinstance(model_shard, torch.nn.Module):
+ shard_state_dict = model_shard.state_dict()
+ elif isinstance(model_shard, dict):
+ shard_state_dict = model_shard
+ else:
+ raise ValueError(f"Unrecognized model shard type: {type(model_shard)}")
+
+ shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()}
+ total_shard_size = sum(
+ weight.numel() * weight.element_size() for weight in shard_state_dict.values()
+ )
+
+ num_shards = dist.size()
+ idx = dist.rank()
+
+ out_path = Path(out_path)
+ shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}")
+
+ shard_metadata = {
+ "total_shard_size": total_shard_size,
+ "shard_keys": list(shard_state_dict.keys()),
+ "shard_file": str(shard_file),
+ }
+
+ if dist.is_master():
+ shard_metadatas = [{} for _ in range(dist.size())]
+ torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0)
+ total_size = sum(x["total_shard_size"] for x in shard_metadatas)
+ metadata = {"total_size": total_size}
+ weight_map: dict[str, str] = {}
+ for shard_metadata in shard_metadatas:
+ weight_map.update(
+ {k: Path(shard_metadata["shard_file"]).name for k in shard_metadata["shard_keys"]}
+ )
+
+ index = {"metadata": metadata, "weight_map": weight_map}
+ index_path = Path(str(out_path) + ".index.json")
+ index_path.write_text(json.dumps(index, indent=2))
+
+ else:
+ torch.distributed.gather_object(shard_metadata, dst=0)
+
+ if out_path.suffix == ".safetensors":
+ safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"})
+ else:
+ torch.save(shard_state_dict, shard_file)
+
+ dist.barrier()
+
+
+def load_sharded_state_dict(
+ model_name_or_path: str | Path,
+ keys_to_load: Iterable[str] | None = None,
+ device: torch.device | str = "cpu",
+) -> dict[str, torch.Tensor]:
+ """
+ keys_to_load: entire state_dict if None, else partial state_dict containing only these keys
+ """
+ shard_paths = _resolve_shard_paths(model_name_or_path)
+ # print(f"shard_paths: {shard_paths}")
+ partial_state_dict = {}
+ for safetensors_path in shard_paths:
+ if keys_to_load is None:
+ shard = safe_load_file(safetensors_path)
+ partial_state_dict.update(shard)
+ else:
+ with safe_open(safetensors_path, framework="pt", device=str(device)) as f:
+ for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
+ if key in keys_to_load:
+ partial_state_dict[key] = f.get_tensor(key)
+ return partial_state_dict
+
+
+def _resolve_shard_paths(model_name_or_path: str) -> list[str]:
+ try:
+ unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME)
+ return [unsharded_path]
+ except OSError:
+ index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+ shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path)
+ return shard_paths
+
+
+def is_in_safetensors_format(checkpoint_dir: Path) -> bool:
+ return len(list(checkpoint_dir.glob("*.safetensors"))) > 0
diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py
new file mode 100644
index 0000000000..4a300fcd0b
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/validate_model.py
@@ -0,0 +1,262 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+"""
+Provides a function to validate a model. Runs a model forward pass on a dataset and calculates
+the loss, and optionally registers hooks to capture the inputs and the outputs
+of pytorch modules that are used for activation scoring for pruning.
+
+TODO: Consider moving this a separate module dedicated for scoring
+
+Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.
+"""
+
+import textwrap
+from pathlib import Path
+from typing import Type
+
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import (
+ register_activation_hooks,
+)
+from modelopt.torch.puzzletron.anymodel.model_descriptor import (
+ ModelDescriptor,
+ ModelDescriptorFactory,
+)
+from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import Same
+from modelopt.torch.puzzletron.tools.logger import aprint, mprint
+from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import (
+ load_and_shard_model,
+ set_submodule,
+)
+from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader
+from modelopt.torch.puzzletron.utils.parsing import (
+ simple_parse_args_string, # noqa: F401 (kept for backwards compat)
+)
+from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
+ HiddenStatesAndLMHead,
+ calculate_losses_pipeline,
+)
+
+"""
+Two goals:
+1) Calculate lm loss and token accuracy for a model.
+May raise lots of NCCL warnings when it finishes, don't be alarmed.
+Can be used to validate a HuggingFace model.
+Automatically uses pipeline parallelism via device_map="auto".
+
+2) Register hooks to capture the inputs and the outputs of pytorch modules.
+For example, to collect activations scores for various layers (ffn, layer_norm, etc.)
+that are used for pruning (ffn_hidden_size, embedding_pruning, etc).
+See activations_log_dir and activation_hooks_kwargs arguments.
+"""
+
+
+@torch.no_grad()
+def validate_model(
+ args: DictConfig,
+ model: PreTrainedModel | None = None,
+ tokenizer: PreTrainedTokenizerBase | None = None,
+ target_hidden_states_per_batch: list[torch.Tensor] | None = None,
+ return_hidden_states: bool = False,
+ calculate_full_score_ablations: bool = False,
+ val_dataloader: DataLoader | None = None,
+) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]:
+ """Validate a language model on a dataset by calculating loss and optionally capturing activations.
+
+ Args:
+ args: Configuration object containing the following attributes:
+
+ Model Configuration:
+ - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. Required unless model is passed directly.
+ - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16).
+ - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision.
+
+ Dataset Configuration:
+ - dataset_path (str): Path to the validation dataset.
+ - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified.
+ - data_column (str): Column name in dataset containing text data.
+ - block_size (int): Maximum sequence length for tokenization.
+ - eval_samples (int, optional): Number of samples to evaluate. Uses all if None.
+ - val_dataset_name (str): Name of validation dataset split.
+ - source_datasets_to_discard (list[str], optional): List of source datasets to exclude.
+ - load_dataset_fn (callable, optional): Custom function to load the dataset.
+
+ Data Processing:
+ - micro_batch_size (int): Batch size for evaluation.
+ - seed (int): Random seed for reproducibility.
+ - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None.
+ - varlen (bool): Enable variable-length sequences.
+ - bos_rate (float): Rate of adding BOS token.
+ - fim_rate (float): Fill-in-the-middle rate for code completion tasks.
+ - fim_spm_rate (float): SPM-based fill-in-the-middle rate.
+
+ Activation Hooks:
+ - activations_log_dir (str, optional): Directory to log activation scores. If provided, hooks will be registered to capture activations.
+ - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. If string, comma-separated format: "arg1=val1,arg2=val2".
+
+ Execution Options:
+ - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended.
+ - write_results (bool): Write validation results to file.
+
+ model: Pre-loaded model. If None, will be loaded from args.model_name_or_path.
+ tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args.
+ target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation.
+ return_hidden_states: Whether to return hidden states from the model.
+ calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. False calculates only a small suite for efficiency.
+ val_dataloader: Pre-created validation dataloader. If None, will be created from args.
+
+ Returns:
+ A tuple containing:
+ - losses: Dictionary mapping loss names to loss statistics (avg, per_sample).
+ - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None.
+
+ Returns (None, None) if not on master rank.
+ """
+ descriptor = ModelDescriptorFactory.get(args.descriptor)
+
+ if val_dataloader is None:
+ val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None
+ validation_full_iters = (
+ args.eval_samples // args.micro_batch_size
+ ) # model pipeline, single data rank
+
+ model = prepare_model(args, descriptor=descriptor, model=model)
+
+ just_model_forward = False
+ checkpoint_manager = None
+ activation_hooks = None
+
+ if args.activations_log_dir is not None:
+ activation_hooks_kwargs = args.activation_hooks_kwargs or {}
+ activation_hooks_kwargs["validation_full_iters"] = validation_full_iters
+ hook_class = args.hook_class
+
+ # Create activation hooks using pruning mixin
+ activation_hooks = register_activation_hooks(
+ model=model,
+ activation_hooks_kwargs=activation_hooks_kwargs,
+ hook_class=hook_class,
+ pruning_mixin=args.pruning_mixin,
+ )
+
+ # Create checkpoint manager with hooks
+ from modelopt.torch.puzzletron.utils.checkpoint_manager import ScoringCheckpointManager
+
+ mprint(
+ f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}"
+ )
+ checkpoint_manager = ScoringCheckpointManager(
+ checkpoint_dir=args.activations_log_dir,
+ activation_hooks=activation_hooks,
+ checkpoint_interval=50, # Save every 50 batches
+ )
+
+ # Load existing checkpoint if available
+ mprint("Attempting to load existing checkpoint...")
+ checkpoint_data = checkpoint_manager.load_checkpoint()
+ if checkpoint_data:
+ mprint(f"Checkpoint loaded successfully: {checkpoint_data}")
+ else:
+ mprint("No checkpoint found, starting fresh")
+ just_model_forward = True
+ set_submodule(model, descriptor.output_embedding_name(), Same())
+
+ losses, hidden_states_per_batch = calculate_losses_pipeline(
+ stitched_model=model,
+ dataloader=val_dataloader,
+ target_hidden_states_per_batch=target_hidden_states_per_batch,
+ return_hidden_states=return_hidden_states,
+ calculate_full_score_ablations=calculate_full_score_ablations,
+ calc_on_cpu=args.calc_losses_on_cpu,
+ just_model_forward=just_model_forward,
+ checkpoint_manager=checkpoint_manager,
+ autocast_dtype=getattr(
+ torch, getattr(args, "autocast_dtype", "torch.bfloat16").strip("torch.")
+ ),
+ descriptor=descriptor,
+ use_autocast=descriptor.uses_autocast(),
+ )
+
+ if losses is not None:
+ avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()}
+
+ results_str = f"""
+ validate_model:
+ {args.model_name_or_path=}
+ Average losses = {avg_losses}
+ Actual num samples = {len(next(iter(losses.values()))["per_sample"])}
+ {args=}
+ """
+ results_str = textwrap.dedent(results_str)
+ aprint(results_str)
+ if args.write_results:
+ Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str)
+
+ if activation_hooks is not None:
+ hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args)
+
+ return losses, hidden_states_per_batch
+
+
+def prepare_model(
+ args: DictConfig,
+ descriptor: Type[ModelDescriptor],
+ model: PreTrainedModel | None = None,
+) -> nn.Module:
+ if model is None:
+ assert args.model_name_or_path is not None
+ model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path)
+
+ model.eval()
+ return model
+
+
+def prepare_dataloader(
+ args: DictConfig, tokenizer: PreTrainedTokenizerBase | None = None
+) -> DataLoader:
+ if tokenizer is None:
+ tokenizer_name = getattr(args, "tokenizer_name", None)
+ assert (tokenizer_name is not None) or (args.model_name_or_path is not None)
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name or args.model_name_or_path, trust_remote_code=True
+ )
+
+ val_dataloader = create_validation_dataloader(
+ accelerator=None,
+ seed=args.seed,
+ tokenizer=tokenizer,
+ block_size=args.block_size,
+ dataset=args.dataset_path,
+ content_field=args.data_column,
+ fim_rate=args.fim_rate,
+ fim_spm_rate=args.fim_spm_rate,
+ micro_batch_size=args.micro_batch_size,
+ eval_samples=args.eval_samples,
+ dataset_name=args.val_dataset_name,
+ source_datasets_to_discard=args.source_datasets_to_discard,
+ bos_rate=args.bos_rate,
+ varlen=args.varlen,
+ shuffle_seed=args.shuffle_seed,
+ load_dataset_fn=args.load_dataset_fn,
+ )
+
+ return val_dataloader
diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
new file mode 100644
index 0000000000..f647cd3f89
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
@@ -0,0 +1,290 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Validates puzzle solutions by applying layer replacements and evaluating model performance.
+
+TODO: Consider moving this a separate module dedicated for scoring
+"""
+
+# mypy: ignore-errors
+
+import json
+import shutil
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import Optional
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.anymodel.converter import Converter
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
+from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary
+from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement
+from modelopt.torch.puzzletron.tools import validate_model
+from modelopt.torch.puzzletron.tools.checkpoint_utils import (
+ SAFETENSORS_SUBBLOCKS_DIR_NAME,
+ copy_tokenizer,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint
+from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model
+from modelopt.torch.puzzletron.tools.validation_utils import (
+ validate_model_and_extract_hidden_states,
+ validate_model_with_teacher_similarity_metrics,
+)
+from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_path
+from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches
+
+"""
+Usage Example:
+==============
+
+Validate single_block_replacement_solutions by calling validate_puzzle_solutions() directly
+with an args object containing the required attributes. See the function docstring for details.
+
+"""
+
+
+@torch.no_grad()
+def validate_puzzle_solutions(args: DictConfig) -> None:
+ """Validate puzzle solutions by applying layer replacements and evaluating model performance.
+
+ Args:
+ args: Configuration object containing the following attributes:
+
+ Puzzle Configuration (Required):
+ - replacement_library_path (Path): Path to the replacement library JSON file.
+ - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files.
+ - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None.
+ - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation.
+ - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by.
+ - skip_validation (bool): If True, skip model validation and only save models if requested.
+ - save_models (bool): If True, save realized model checkpoints for each solution.
+
+ Teacher/Tokenizer Configuration:
+ - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided.
+ - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified.
+
+ Model Configuration (Required if skip_validation=False):
+ - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16).
+ - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision.
+
+ Dataset Configuration (Required if skip_validation=False):
+ - dataset_path (str): Path to the validation dataset.
+ - data_column (str): Column name in dataset containing text data.
+ - block_size (int): Maximum sequence length for tokenization.
+ - eval_samples (int, optional): Number of samples to evaluate.
+ - val_dataset_name (str): Name of validation dataset split.
+ - source_datasets_to_discard (list[str], optional): List of source datasets to exclude.
+ - load_dataset_fn (callable, optional): Custom function to load the dataset.
+
+ Data Processing (Required if skip_validation=False):
+ - micro_batch_size (int): Batch size for evaluation.
+ - seed (int): Random seed for reproducibility.
+ - shuffle_seed (int, optional): Seed for shuffling data.
+ - varlen (bool): Enable variable-length sequences.
+ - bos_rate (float): Rate of adding BOS token.
+ - fim_rate (float): Fill-in-the-middle rate for code completion tasks.
+ - fim_spm_rate (float): SPM-based fill-in-the-middle rate.
+
+ Output Configuration:
+ - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided.
+
+ Execution Options (Optional if skip_validation=False):
+ - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM.
+ - write_results (bool): Write validation results to file.
+ - activations_log_dir (str, optional): Directory to log activation scores.
+ - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks.
+
+ Returns:
+ None. Saves validation results and optionally model checkpoints to disk.
+ """
+ descriptor = ModelDescriptorFactory.get(args.descriptor)
+
+ puzzle_solutions = load_puzzle_solutions(
+ args.solutions_path, args.sort_solutions_by, args.bigger_is_better
+ )
+ if args.solutions_to_validate is None:
+ args.solutions_to_validate = list(range(len(puzzle_solutions)))
+ puzzle_solutions = [puzzle_solutions[i] for i in args.solutions_to_validate]
+
+ tokenizer = _load_tokenizer(args, trust_remote_code=descriptor.requires_trust_remote_code())
+ if not args.skip_validation:
+ val_dataloader = (
+ validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None
+ )
+
+ output_dir = (
+ args.output_dir
+ if getattr(args, "output_dir", None) is not None
+ else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation")
+ )
+
+ replacement_library = ReplacementLibrary(
+ args.replacement_library_path,
+ descriptor=descriptor,
+ model_config_overrides={"use_cache": False},
+ )
+
+ teacher_hidden_states = None
+ if (args.teacher_dir is not None) and (not args.skip_validation):
+ teacher_model = load_and_shard_model(
+ checkpoint_path=args.teacher_dir, descriptor=descriptor
+ )
+ teacher_model.cuda(dist.local_rank())
+ stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor)
+ teacher_hidden_states = validate_model_and_extract_hidden_states(
+ args,
+ stitched_model,
+ tokenizer,
+ output_dir,
+ model_name="teacher",
+ val_dataloader=val_dataloader,
+ )
+
+ # Properly release CUDA memory after teacher validation
+ teacher_model.cpu()
+ stitched_model.cpu()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ dist.barrier()
+
+ for i_solution, puzzle_solution in tqdm(
+ list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions"
+ ):
+ layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution)
+ realizable_as_symlinks = can_realize_as_symlinks(layer_replacements)
+ # realizable_as_symlinks = False
+ model_config = replacement_library.create_model_config(layer_replacements)
+ if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation):
+ model = replacement_library.load_model(layer_replacements)
+ model_config = model.config
+
+ if args.save_models:
+ checkpoint_dir = (
+ args.solutions_path.with_name(f"{args.solutions_path.stem}--checkpoints")
+ / f"solution_{i_solution}"
+ )
+
+ model_config.dtype = getattr(args, "model_dtype", "torch.bfloat16")
+ Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir)
+ if realizable_as_symlinks:
+ if dist.is_master():
+ # TODO: Loo into internal Puzzleron code to see how to save as symlinks
+ # save_checkpoint_as_symlinks is currently not supported
+ pass
+ save_checkpoint(model, checkpoint_dir, descriptor)
+
+ copy_tokenizer(
+ args.tokenizer_name,
+ checkpoint_dir,
+ trust_remote_code=descriptor.requires_trust_remote_code(),
+ )
+
+ dist.barrier()
+
+ if not args.skip_validation:
+ model.cuda(dist.local_rank())
+ stitched_model = perform_pipeline_stitches(model, descriptor=descriptor)
+ validate_model_with_teacher_similarity_metrics(
+ args,
+ stitched_model,
+ tokenizer,
+ teacher_hidden_states,
+ output_dir,
+ model_name=f"solution_{i_solution}",
+ extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution},
+ val_dataloader=val_dataloader,
+ )
+
+ # Properly release CUDA memory after solution validation
+ model.cpu()
+ stitched_model.cpu()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ dist.barrier()
+
+
+def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool:
+ for layer_replacement in layer_replacements:
+ num_parent_layers = len(layer_replacement["parent_layer_indices"])
+ num_child_layers = len(layer_replacement["child_block_configs"])
+ if num_parent_layers != num_child_layers or num_parent_layers != 1:
+ return False
+ return True
+
+
+def _load_tokenizer(args: DictConfig, trust_remote_code: bool = False) -> PreTrainedTokenizerBase:
+ tokenizer = None
+ if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None:
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name, trust_remote_code=trust_remote_code
+ )
+ elif args.teacher_dir is not None:
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.teacher_dir, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ pass
+ if tokenizer is None:
+ warnings.warn("Couldn't find a tokenizer, trying to continue without one")
+ return tokenizer
+
+
+def _extract_layer_replacements_from_puzzle_solution(
+ puzzle_solution: dict,
+) -> list[dict]:
+ puzzle_solution = puzzle_solution.get("puzzle_solution", puzzle_solution)
+ layer_replacements = [
+ parse_layer_replacement(rep) for rep in puzzle_solution["chosen_replacements"]
+ ]
+ return layer_replacements
+
+
+def load_puzzle_solutions(
+ solutions_path: Path,
+ sort_solutions_by: Optional[str],
+ bigger_is_better: bool,
+) -> list[dict]:
+ assert solutions_path.exists(), f"{solutions_path=} does not exist"
+
+ if solutions_path.is_file():
+ puzzle_solutions = json.loads(solutions_path.read_text())
+ if isinstance(puzzle_solutions, dict):
+ puzzle_solutions = [puzzle_solutions]
+ else:
+ puzzle_solutions = [
+ json.loads(p.read_text()) for p in solutions_path.glob("*solution*.json")
+ ]
+
+ if len(puzzle_solutions) == 0:
+ raise ValueError(f"No solutions under {solutions_path=}")
+
+ if sort_solutions_by is not None:
+ puzzle_solutions = sorted(
+ puzzle_solutions, key=partial(get_nested_key, field=sort_solutions_by)
+ )
+ if bigger_is_better:
+ puzzle_solutions = puzzle_solutions[::-1]
+ vals = [get_nested_key(sol, sort_solutions_by) for sol in puzzle_solutions]
+ print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}")
+
+ return puzzle_solutions
diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py
new file mode 100644
index 0000000000..d7197e8abf
--- /dev/null
+++ b/modelopt/torch/puzzletron/tools/validation_utils.py
@@ -0,0 +1,114 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for validating models and extracting hidden states and similarity metrics.
+
+TODO: Consider moving this a separate module dedicated for scoring.
+"""
+
+# mypy: ignore-errors
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from transformers import PreTrainedTokenizerBase
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools import validate_model
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.tools.robust_json import json_dump
+from modelopt.torch.puzzletron.utils.validation import LowMemorySparseTensor
+
+if TYPE_CHECKING:
+ from modelopt.torch.puzzletron.sewing_kit import StitchedModule
+
+
+def validate_model_and_extract_hidden_states(
+ args: DictConfig,
+ model: "nn.Module | StitchedModule",
+ tokenizer: PreTrainedTokenizerBase,
+ output_dir: str | Path,
+ model_name: str,
+ extra_payload: Optional[dict[str, Any]] = None,
+ val_dataloader=None,
+) -> list[torch.Tensor | LowMemorySparseTensor]:
+ mprint(f"""
+
+################################################################
+validate_model_and_extract_token_probs({model_name=})
+################################################################
+
+""")
+ losses, hidden_states_per_batch = validate_model.validate_model(
+ args,
+ model,
+ tokenizer,
+ return_hidden_states=True,
+ val_dataloader=val_dataloader,
+ )
+ if dist.is_last_process():
+ output_dir = output_dir if (output_dir is not None) else args.bypass_dir
+ extra_payload = extra_payload if (extra_payload is not None) else dict()
+ write_results(output_dir, model_name, args, {**losses, **extra_payload})
+ return hidden_states_per_batch
+
+
+def validate_model_with_teacher_similarity_metrics(
+ args: DictConfig,
+ model: "nn.Module | StitchedModule",
+ tokenizer: PreTrainedTokenizerBase,
+ target_hidden_states_per_batch: list[torch.Tensor],
+ output_dir: str | Path,
+ model_name: str,
+ extra_payload: Optional[dict[str, Any]] = None,
+ calculate_full_score_ablations: bool = False,
+ val_dataloader=None,
+) -> None:
+ is_calc_kl_div = target_hidden_states_per_batch is not None
+ mprint(f"""
+
+################################################################
+validate_model_with_kl_div({model_name=}, {is_calc_kl_div=})
+################################################################
+
+""")
+ losses, _ = validate_model.validate_model(
+ args,
+ model,
+ tokenizer,
+ target_hidden_states_per_batch=target_hidden_states_per_batch,
+ calculate_full_score_ablations=calculate_full_score_ablations,
+ val_dataloader=val_dataloader,
+ )
+ if dist.is_last_process():
+ extra_payload = extra_payload if (extra_payload is not None) else dict()
+ write_results(output_dir, model_name, args, {**losses, **extra_payload})
+
+
+def write_results(
+ output_dir: str | Path, result_name: str, args: DictConfig, payload: dict[str, Any]
+) -> None:
+ output_path = Path(output_dir) / f"{result_name}.json"
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ results = {
+ **payload,
+ "args": OmegaConf.to_container(args, resolve=True)
+ if isinstance(args, DictConfig)
+ else args.__dict__,
+ }
+ json_dump(results, output_path)
diff --git a/modelopt/torch/puzzletron/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py
new file mode 100644
index 0000000000..a1347deaea
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py
@@ -0,0 +1,259 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Checkpoint manager for activation hook scoring with periodic saves and resume support."""
+
+import json
+import time
+from pathlib import Path
+from typing import Any
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.logger import aprint, mprint
+
+
+class ScoringCheckpointManager:
+ """Manages checkpointing for activation hook scoring with periodic saves."""
+
+ def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100):
+ """Initialize checkpoint manager.
+
+ Args:
+ checkpoint_dir: Directory to save checkpoints
+ activation_hooks: Dictionary of activation hooks to manage
+ checkpoint_interval: Save checkpoint every N batches
+ """
+ self.checkpoint_dir = Path(checkpoint_dir)
+ self.activation_hooks = activation_hooks
+ self.checkpoint_interval = checkpoint_interval
+ self.rank = dist.rank()
+ self.is_main_process = dist.is_master()
+
+ # Debug: Log checkpoint manager initialization
+ hook_count = len(activation_hooks) if activation_hooks else 0
+ aprint(
+ f"[Rank {self.rank}] Checkpoint manager initialized: {hook_count} hooks, dir: {checkpoint_dir}"
+ )
+
+ # Checkpoint files
+ self.progress_file = self.checkpoint_dir / "scoring_progress.json"
+ self.hook_states_file = self.checkpoint_dir / f"hook_states_rank_{self.rank}.pth"
+
+ # Progress tracking
+ self.current_batch = 0
+ self.total_batches = 0
+ self.start_time = time.time()
+
+ # Ensure directory exists
+ if self.is_main_process:
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+ def load_checkpoint(self) -> dict[str, Any] | None:
+ """Load existing checkpoint if available, including hook states.
+
+ Returns:
+ Dict with checkpoint info or None if no checkpoint exists
+ """
+ aprint(f"[Rank {self.rank}] Looking for checkpoint at: {self.progress_file}")
+ if not self.progress_file.exists():
+ aprint(f"[Rank {self.rank}] No checkpoint file found at {self.progress_file}")
+ return None
+
+ try:
+ with open(self.progress_file) as f:
+ checkpoint_data = json.load(f)
+
+ # Validate checkpoint
+ if "current_batch" in checkpoint_data and "total_batches" in checkpoint_data:
+ self.current_batch = checkpoint_data["current_batch"]
+ self.total_batches = checkpoint_data["total_batches"]
+
+ mprint(
+ f"Found checkpoint: batch {self.current_batch}/{self.total_batches} ({checkpoint_data.get('progress', 0.0):.1%})"
+ )
+ mprint(
+ f"Will resume from batch {self.current_batch}, skipping batches 0-{self.current_batch - 1}"
+ )
+
+ # Load hook states if hooks are available
+ if self.activation_hooks is not None:
+ success = self.load_hook_states(self.activation_hooks)
+ if success:
+ aprint(
+ f"[Rank {self.rank}] Successfully loaded hook states from checkpoint"
+ )
+ else:
+ aprint(f"[Rank {self.rank}] Failed to load hook states - starting fresh")
+
+ return checkpoint_data
+ else:
+ aprint(
+ f"[Rank {self.rank}] Invalid checkpoint format (missing current_batch/total_batches): {checkpoint_data}"
+ )
+ return None
+
+ except (json.JSONDecodeError, KeyError) as e:
+ mprint(f"Error loading checkpoint: {e}")
+
+ return None
+
+ def load_hook_states(self, activation_hooks) -> bool:
+ """Load hook states from checkpoint files.
+
+ Args:
+ activation_hooks: Hook objects to load states into
+
+ Returns:
+ bool: True if hook states were successfully loaded, False otherwise
+ """
+ import os
+
+ # Each rank loads only its own hook states
+ current_rank = int(os.environ.get("RANK", 0))
+ hook_states_path = self.checkpoint_dir / f"hook_states_rank_{current_rank}.pth"
+
+ if hook_states_path.exists():
+ aprint(f"[Rank {current_rank}] Loading hook states from {hook_states_path}")
+ try:
+ import torch
+
+ hook_states = torch.load(hook_states_path, map_location="cpu")
+
+ # Load states into corresponding hooks
+ loaded_count = 0
+ for module_name, hook in activation_hooks.items():
+ if module_name in hook_states:
+ hook.load_state_dict(hook_states[module_name])
+ loaded_count += 1
+
+ # Log progress info if available (only for a few hooks to avoid spam)
+ if loaded_count <= 3: # Only log first few hooks
+ progress_info = hook.get_progress_info()
+ if progress_info:
+ aprint(f"[Rank {current_rank}] {module_name}: {progress_info}")
+ else:
+ aprint(
+ f"[Rank {current_rank}] Warning: No saved state found for hook: {module_name}"
+ )
+
+ aprint(
+ f"[Rank {current_rank}] Successfully loaded states for {loaded_count}/{len(activation_hooks)} hooks"
+ )
+ return True
+
+ except Exception as e:
+ aprint(f"[Rank {current_rank}] Error loading hook states: {e}")
+ return False
+ else:
+ aprint(f"[Rank {current_rank}] No hook states file found at {hook_states_path}")
+ return False
+
+ def should_skip_batch(self, batch_idx: int) -> bool:
+ """Check if we should skip this batch (already processed in previous run)."""
+ should_skip = batch_idx < self.current_batch
+ if should_skip and batch_idx % 10 == 0: # Log every 10th skipped batch to avoid spam
+ mprint(f"Skipping batch {batch_idx} (resume from batch {self.current_batch})")
+ return should_skip
+
+ def update_progress(self, batch_idx: int, total_batches: int):
+ """Update progress and potentially save checkpoint.
+
+ Args:
+ batch_idx: Current batch index
+ total_batches: Total number of batches
+ """
+ self.current_batch = batch_idx
+ self.total_batches = total_batches
+
+ # Save checkpoint periodically or on completion
+ should_save = (
+ (batch_idx % self.checkpoint_interval == 0) # Periodic save
+ or (batch_idx == total_batches - 1) # Final batch
+ )
+
+ if should_save:
+ # All ranks save their hook states
+ if self.activation_hooks is not None:
+ try:
+ from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook
+
+ ForwardHook.save_hook_states(self.activation_hooks, self.checkpoint_dir)
+ except Exception as e:
+ mprint(f"Warning: Failed to save hook states: {e}")
+
+ # Only main process saves progress info
+ if self.is_main_process:
+ self.save_checkpoint()
+
+ # Synchronize all ranks after checkpointing
+ dist.barrier()
+
+ def save_checkpoint(self):
+ """Save current checkpoint to disk (progress info only).
+ Hook states are saved separately in update_progress.
+ """
+ try:
+ # Save progress
+ progress_data = {
+ "current_batch": self.current_batch,
+ "total_batches": self.total_batches,
+ "progress": self.current_batch / self.total_batches
+ if self.total_batches > 0
+ else 0.0,
+ "timestamp": time.time(),
+ "elapsed_time": time.time() - self.start_time,
+ "rank": self.rank,
+ }
+
+ # Write progress atomically
+ temp_file = self.progress_file.with_suffix(".tmp")
+ with open(temp_file, "w") as f:
+ json.dump(progress_data, f, indent=2)
+ temp_file.replace(self.progress_file)
+
+ # Hook states are saved at a higher level to ensure all ranks participate
+
+ if self.current_batch % (self.checkpoint_interval) == 0:
+ progress_pct = progress_data["progress"] * 100
+ elapsed = progress_data["elapsed_time"]
+ mprint(
+ f"Checkpoint saved: batch {self.current_batch}/{self.total_batches} ({progress_pct:.1f}%), elapsed: {elapsed:.1f}s"
+ )
+
+ except Exception as e:
+ mprint(f"Error saving checkpoint: {e}")
+
+ def finalize(self):
+ """Mark scoring as completed."""
+ # All ranks save their final hook states
+ if self.activation_hooks is not None:
+ try:
+ from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook
+
+ saved_path = ForwardHook.save_hook_states(
+ self.activation_hooks, self.checkpoint_dir
+ )
+ mprint(f"Final hook states saved to {saved_path}")
+ except Exception as e:
+ mprint(f"Warning: Failed to save final hook states: {e}")
+
+ # Only main process saves progress info
+ if self.is_main_process:
+ self.current_batch = self.total_batches
+ self.save_checkpoint()
+ mprint(f"Scoring completed and finalized: {self.total_batches} batches processed")
+
+ # Synchronize all ranks after finalization
+ dist.barrier()
diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py
new file mode 100644
index 0000000000..892d1f3c2c
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py
@@ -0,0 +1,203 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DataLoader utilities for language model training and validation."""
+
+from collections.abc import Callable, Mapping, Sequence
+from functools import partial
+from typing import Protocol, TypeVar
+
+import datasets
+import torch
+import torch.distributed
+from accelerate import Accelerator
+from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data._utils.collate import collate, default_collate_fn_map
+from tqdm import tqdm
+from transformers import PreTrainedTokenizerBase
+
+from modelopt.torch.puzzletron.tools.logger import mprint
+from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset
+
+
+def collate_none_fn(
+ batch, *, collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None
+):
+ return None
+
+
+collate_fn_map_with_none_support = {**default_collate_fn_map, type(None): collate_none_fn}
+collate_fn_with_none_support = partial(collate, collate_fn_map=collate_fn_map_with_none_support)
+
+
+class LoadDatasetFn(Protocol):
+ def __call__(
+ self, dataset_path: str, content_field: str, keep_in_memory: bool = False
+ ) -> Mapping[str, Dataset]: ...
+
+
+def load_from_disk_fn(
+ dataset_path: str, content_field: str, keep_in_memory: bool = False
+) -> Mapping[str, Dataset]:
+ return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory)
+
+
+def load_streaming_fn(
+ dataset_path: str, content_field: str, keep_in_memory: bool = False
+) -> Mapping[str, Dataset]:
+ dataset = datasets.load_dataset(
+ dataset_path,
+ streaming=True,
+ features=datasets.Features(
+ {
+ content_field: datasets.Value(dtype="string"),
+ }
+ ),
+ keep_in_memory=keep_in_memory,
+ )
+
+ return dataset
+
+
+def create_validation_dataloader(
+ accelerator: Accelerator | None,
+ seed: int,
+ tokenizer: PreTrainedTokenizerBase,
+ block_size: int,
+ dataset: str | Mapping[str, Dataset],
+ content_field: str,
+ fim_rate: float,
+ fim_spm_rate: float,
+ micro_batch_size: int,
+ eval_samples: int | None = None,
+ load_dataset_fn: LoadDatasetFn = load_from_disk_fn,
+ dataset_name: str = "__auto__",
+ keep_in_memory: bool = False,
+ source_datasets_to_discard: Sequence[str] = (),
+ bos_rate: float = 1.0,
+ varlen: bool = True,
+ shuffle_seed: int | None = None,
+):
+ if accelerator is None:
+ accelerator = Printer()
+
+ if accelerator.is_main_process:
+ if isinstance(dataset, str):
+ dataset = load_dataset_fn(dataset, content_field, keep_in_memory)
+
+ if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset):
+ valid_data = dataset
+ mprint(
+ "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####"
+ )
+ else:
+ assert isinstance(dataset, datasets.DatasetDict)
+ if dataset_name == "__auto__":
+ val_split_options = []
+ for val_key_prefix in ("val", "test"):
+ if len(val_split_options) == 0:
+ val_split_options = [
+ split
+ for split in dataset # DatasetDict is dict-like and supports direct iteration
+ if split.lower().startswith(val_key_prefix)
+ ]
+ assert len(val_split_options) == 1, (
+ f"Expected exactly one validation split, got {val_split_options=} ({dataset.keys()=})"
+ )
+ val_split = val_split_options[0]
+ mprint(f"Inferred validation split automatically: '{val_split}'")
+ else:
+ val_split = dataset_name
+ mprint(f"Validation split explicitly chosen: '{val_split}'")
+ valid_data = dataset[val_split]
+
+ if shuffle_seed is not None:
+ mprint(f"Shuffling with {shuffle_seed=}")
+ valid_data = valid_data.shuffle(seed=shuffle_seed)
+
+ valid_dataset = ConstantLengthDataset(
+ tokenizer,
+ valid_data,
+ infinite=False,
+ seq_length=block_size * micro_batch_size if varlen else block_size,
+ content_field=content_field,
+ fim_rate=fim_rate,
+ fim_spm_rate=fim_spm_rate,
+ seed=seed,
+ source_datasets_to_discard=source_datasets_to_discard,
+ bos_rate=bos_rate,
+ # return_cu_seqlens=varlen,
+ # seqlen_cap=block_size if varlen else None
+ )
+ if varlen and eval_samples is not None:
+ eval_samples = eval_samples // micro_batch_size
+ val_offloaded_dataset = realize_dataset_in_memory(valid_dataset, eval_samples)
+
+ valid_data_len = len(val_offloaded_dataset)
+ mprint(f"num validation examples = {valid_data_len}")
+ else:
+ val_offloaded_dataset = None
+
+ if not isinstance(accelerator, Printer):
+ obj_list = [val_offloaded_dataset]
+ torch.distributed.broadcast_object_list(obj_list)
+ val_offloaded_dataset = obj_list[0]
+
+ # let accelerate prepare to handle distributed sampling
+ val_dataloader = DataLoader(
+ val_offloaded_dataset,
+ batch_size=1 if varlen else micro_batch_size,
+ pin_memory=True,
+ collate_fn=collate_fn_with_none_support,
+ )
+
+ return val_dataloader
+
+
+def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None) -> list[dict]:
+ tqdm_desc = f"realize_dataset_in_memory({eval_samples=})"
+ if eval_samples is None:
+ offloaded_dataset = list(tqdm(dataset, desc=tqdm_desc))
+ else:
+ val_iter = iter(dataset)
+ offloaded_dataset = [next(val_iter) for _ in tqdm(range(eval_samples), desc=tqdm_desc)]
+ return offloaded_dataset
+
+
+TensorT = TypeVar("TensorT", bound=torch.Tensor)
+
+
+@torch.no_grad()
+def create_padded_tensor(
+ tensor: TensorT, desired_shape: Sequence[int], padding_value: float = 0
+) -> TensorT:
+ if tensor.shape == torch.Size(desired_shape):
+ return tensor
+
+ padded_tensor = torch.full(
+ desired_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device
+ )
+ indices = torch.where(torch.ones_like(tensor, dtype=torch.bool))
+ padded_tensor[indices] = tensor.view(-1)
+ return padded_tensor
+
+
+class Printer:
+ is_main_process = True
+ process_index = None
+
+ @staticmethod
+ def print(*args, **kwargs) -> None:
+ print(*args, **kwargs)
diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py
new file mode 100644
index 0000000000..fffc2a3a1d
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/data/dataset.py
@@ -0,0 +1,314 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+import functools
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import IterableDataset
+
+FIM_TOKEN_START = "", "middle>", "suffix>", "pad>"]
+CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""]
+
+
+class ConstantLengthDataset(IterableDataset):
+ """Iterable dataset that returns constant length chunks of tokens from stream of text files.
+
+ Args:
+ tokenizer (Tokenizer): The processor used for proccessing the data.
+ dataset (dataset.Dataset): Dataset with text files.
+ infinite (bool): If True the iterator is reset after dataset reaches end else stops.
+ seq_length (int): Length of token sequences to return.
+ num_of_sequences (int): Number of token sequences to keep in buffer.
+ chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
+ fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.
+ fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM.
+ seed (int): Seed for random number generator.
+ label_shift (bool): Whether to shift labels by 1 or not.
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ dataset,
+ infinite=False,
+ seq_length=1024,
+ num_of_sequences=1024,
+ chars_per_token=3.6,
+ content_field="content",
+ fim_rate=0.5,
+ fim_spm_rate=0.5,
+ seed=0,
+ label_shift=True,
+ max_sample_length=200_000,
+ tokens_field="token_ids",
+ source_datasets_to_discard: Sequence[str] | None = tuple(),
+ bos_rate: float = 1.0,
+ return_cu_seqlens: bool = False,
+ seqlen_cap: int | None = None,
+ ):
+ self.tokenizer = tokenizer
+ self.concat_token_id = tokenizer.eos_token_id
+ # self.concat_token_id = tokenizer.eos_id # for lit-lamma tokenizer
+ self.dataset = dataset
+ self.is_dataset_already_tokenized = tokens_field in self.dataset.column_names
+ self.seq_length = seq_length
+ self.infinite = infinite
+ self.current_size = 0
+ if not self.is_dataset_already_tokenized:
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+ self.max_sample_length = max_sample_length
+ else:
+ self.max_buffer_size = seq_length * num_of_sequences
+ # self.max_sample_length = int(max_sample_length / chars_per_token)
+ self.max_sample_length = max_sample_length # we don't know the exact chars_per_token
+ self.content_field = content_field
+ self.tokens_field = tokens_field
+ self.fim_rate = fim_rate
+ self.fim_spm_rate = fim_spm_rate
+ self.seed = seed
+ self.max_sample_length = max_sample_length
+
+ self.fim_token_ids = get_fim_token_ids(self.tokenizer)
+ if None in self.fim_token_ids.values() and self.fim_rate > 0:
+ self.fim_rate = 0
+ self.label_shift = label_shift
+ self.bos_rate = bos_rate
+ self.source_datasets_to_discard = (
+ source_datasets_to_discard if source_datasets_to_discard is not None else tuple()
+ )
+ self.return_cu_seqlens = return_cu_seqlens
+ self.seqlen_cap = seqlen_cap
+ self.np_rng = np.random.RandomState(seed=self.seed)
+
+ def __iter__(self) -> dict[str, torch.Tensor]:
+ iterator = iter(self.dataset)
+ more_examples = True
+ while more_examples:
+ buffer, buffer_len = [], 0
+ while True:
+ if buffer_len >= self.max_buffer_size:
+ break
+ try:
+ sample = next(iterator)
+ if (
+ len(self.source_datasets_to_discard) > 0
+ and sample["dataset_name"] in self.source_datasets_to_discard
+ ):
+ continue
+ if not self.is_dataset_already_tokenized:
+ sample = sample[self.content_field]
+ if (
+ isinstance(sample, list)
+ and isinstance(sample[0], dict)
+ and {"content", "role"}.issubset(sample[0])
+ ):
+ if len(sample) > 1:
+ sample = self.tokenizer.apply_chat_template(sample, tokenize=False)
+ else:
+ sample = sample[0]["content"]
+ else:
+ sample = sample[self.tokens_field]
+ sample = sample[: self.max_sample_length]
+ buffer.append(sample)
+ buffer_len += len(sample)
+ except StopIteration:
+ if self.infinite:
+ iterator = iter(self.dataset)
+ else:
+ more_examples = False
+ break
+
+ if not self.is_dataset_already_tokenized:
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+ else:
+ tokenized_inputs = buffer
+
+ all_token_ids = []
+
+ for tokenized_input in tokenized_inputs:
+ if (
+ self.bos_rate < 1.0
+ and not self.np_rng.binomial(1, self.bos_rate)
+ and self.tokenizer.bos_token_id is not None
+ and tokenized_input[0] == self.tokenizer.bos_token_id
+ ):
+ tokenized_input = tokenized_input[1:]
+ # optionally do FIM permutations
+ if self.fim_rate > 0:
+ tokenized_input, np_rng = permute(
+ sample=tokenized_input,
+ np_rng=self.np_rng,
+ fim_token_ids=self.fim_token_ids,
+ fim_rate=self.fim_rate,
+ fim_spm_rate=self.fim_spm_rate,
+ truncate_or_pad=False,
+ )
+
+ all_token_ids.extend(tokenized_input + [self.concat_token_id])
+
+ examples = []
+ # cuts code snippets in the middle to yield constant length instances
+ for i in range(0, len(all_token_ids), self.seq_length):
+ input_ids = all_token_ids[i : i + self.seq_length]
+ labels = all_token_ids[
+ i + int(self.label_shift) : i + int(self.label_shift) + self.seq_length
+ ]
+ # ignores last short example in the buffer
+ if len(labels) == self.seq_length:
+ examples.append((input_ids, labels))
+
+ shuffling_indices = self.np_rng.permutation(len(examples))
+ examples = [examples[i] for i in shuffling_indices]
+
+ for input_ids, labels in examples:
+ self.current_size += 1
+ input_ids = torch.LongTensor(input_ids)
+ if self.return_cu_seqlens:
+ cu_seqlens = self.prepare_cu_seqlens(input_ids)
+ yield {
+ "input_ids": input_ids,
+ "targets": torch.LongTensor(labels),
+ "cu_seqlens": cu_seqlens,
+ }
+ else:
+ yield {
+ "input_ids": input_ids,
+ "targets": torch.LongTensor(labels),
+ }
+
+ def prepare_cu_seqlens(self, input_ids):
+ if not self.return_cu_seqlens:
+ return None
+ # seqlens is of shape (num_seqs+1,) and with the property that
+ # the i-th sequnce is input_ids[seqlens[i-1]:seqlens[i]]
+ cu_seqlens = (input_ids == self.concat_token_id).nonzero().squeeze(-1).int() + 1
+ cu_seqlens = torch.cat(
+ (
+ torch.IntTensor([0]),
+ cu_seqlens,
+ torch.IntTensor([len(input_ids)]),
+ )
+ )
+ if self.seqlen_cap is not None:
+ i = 1
+ while i < len(cu_seqlens):
+ curr_seqlen = cu_seqlens[i] - cu_seqlens[i - 1]
+ if curr_seqlen > self.seqlen_cap:
+ cu_seqlens = torch.cat(
+ (cu_seqlens[:i], cu_seqlens[[i - 1]] + self.seqlen_cap, cu_seqlens[i:])
+ )
+ i += 1
+ if cu_seqlens[-1] == cu_seqlens[-2]:
+ cu_seqlens = cu_seqlens[:-1]
+ return cu_seqlens
+
+
+## Adapted from https://github.com/NVIDIA/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
+def permute(
+ sample,
+ np_rng,
+ fim_token_ids,
+ fim_rate=0.5,
+ fim_spm_rate=0.5,
+ truncate_or_pad=False,
+):
+ """Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:
+ PSM and SPM (with a probability of fim_spm_rate).
+ """
+ if np_rng.binomial(1, fim_rate):
+ boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
+ boundaries.sort()
+
+ prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
+ middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
+ suffix = np.array(sample[boundaries[1] :], dtype=np.int64)
+
+ if truncate_or_pad:
+ raise NotImplementedError
+
+ if "" in fim_token_ids: # use codegen FIM pattern
+ assert fim_spm_rate == 0
+ new_sample = np.concatenate(
+ [
+ prefix,
+ [fim_token_ids[""]],
+ suffix,
+ [fim_token_ids["<|endoftext|>"]],
+ [fim_token_ids[""]],
+ [fim_token_ids[""]],
+ middle,
+ ]
+ )
+ elif np_rng.binomial(1, fim_spm_rate):
+ # SPM (variant 2 from FIM paper)
+ new_sample = np.concatenate(
+ [
+ [fim_token_ids["prefix_tok_id"], fim_token_ids["suffix_tok_id"]],
+ suffix,
+ [fim_token_ids["middle_tok_id"]],
+ prefix,
+ middle,
+ ]
+ )
+ else:
+ # PSM
+ new_sample = np.concatenate(
+ [
+ [fim_token_ids["prefix_tok_id"]],
+ prefix,
+ [fim_token_ids["suffix_tok_id"]],
+ suffix,
+ [fim_token_ids["middle_tok_id"]],
+ middle,
+ ]
+ )
+ else:
+ # don't do FIM preproc
+ new_sample = sample
+
+ return list(new_sample), np_rng
+
+
+# this is expensive so we cache it
+@functools.lru_cache(maxsize=None)
+def get_fim_token_ids(tokenizer):
+ # ugly fix for Salesforce/codegen25-7b-multi tokenizer
+ if hasattr(tokenizer, "encoder"):
+ search_vocab = tokenizer.encoder._special_tokens
+ fim_token_ids = {tok: search_vocab.get(tok, None) for tok in CODEGEN_FIM_TOKENS}
+ else:
+ search_vocab = tokenizer.vocab
+ if (FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + FIM_TOKEN_END_LIST[0]) in search_vocab:
+ prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = (
+ search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + tok, None)
+ for tok in FIM_TOKEN_END_LIST
+ )
+ else:
+ prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = (
+ search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_SANTA + tok, None)
+ for tok in FIM_TOKEN_END_LIST
+ )
+ fim_token_ids = {
+ "suffix_tok_id": suffix_tok_id,
+ "prefix_tok_id": prefix_tok_id,
+ "middle_tok_id": middle_tok_id,
+ "pad_tok_id": pad_tok_id,
+ }
+ return fim_token_ids
diff --git a/modelopt/torch/puzzletron/utils/dummy_modules.py b/modelopt/torch/puzzletron/utils/dummy_modules.py
new file mode 100644
index 0000000000..c9eaa2bc6c
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/dummy_modules.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+from typing_extensions import override
+
+
+class DummyModule(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.register_load_state_dict_post_hook(self.load_state_dict_post_hook)
+
+ @staticmethod
+ def load_state_dict_post_hook(
+ module: torch.nn.Module,
+ incompatible_keys: torch.nn.modules.module._IncompatibleKeys,
+ ) -> None:
+ incompatible_keys.missing_keys.clear()
+ incompatible_keys.unexpected_keys.clear()
+
+
+class DummyBlock(DummyModule):
+ def __init__(self, block_index: int):
+ super().__init__()
+ self.block_index = block_index
+
+ @override
+ def forward(
+ self,
+ x: torch.Tensor,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor | tuple[torch.Tensor, None]:
+ return x
+
+
+class DummyWTE(DummyModule):
+ def __init__(self, hidden_size: int, dtype: Optional[torch.dtype] = None):
+ super().__init__()
+ self.n_embd = hidden_size
+ self.dtype = dtype
+
+ @override
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+ B, T = input_ids.shape
+ result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device)
+ return result
+
+
+class DummyLMHead(DummyModule):
+ def __init__(self, config: PretrainedConfig):
+ super().__init__()
+ self.vocab_size = config.vocab_size
+
+ @override
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, T, C = x.shape
+ result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device)
+ return result
diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py
new file mode 100644
index 0000000000..ff5bb6963a
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/parsing.py
@@ -0,0 +1,455 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Parsing and formatting utilities for configuration handling in model compression.
+
+This module provides utilities for:
+- Parsing command-line arguments and configuration strings
+- Formatting and displaying model configurations (block configs, attention, FFN)
+- Formatting loss metrics for logging and visualization
+"""
+# mypy: ignore-errors
+
+import json
+from pathlib import Path
+from typing import Any
+
+import torch
+from omegaconf import DictConfig
+
+
+def handle_arg_string(arg):
+ if arg.lower() == "true":
+ return True
+ elif arg.lower() == "false":
+ return False
+ elif arg.isnumeric():
+ return int(arg)
+ try:
+ return float(arg)
+ except ValueError:
+ return arg
+
+
+def simple_parse_args_string(args_string):
+ """
+ Parses something like
+ args1=val1,arg2=val2
+ Into a dictionary
+ """
+ if args_string is None:
+ return {}
+ args_string = args_string.strip()
+ if not args_string:
+ return {}
+ arg_list = [arg for arg in args_string.split(",") if arg]
+ args_dict = {k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]}
+ return args_dict
+
+
+def parse_json(s: str | None) -> Any:
+ if s is None:
+ return None
+ return json.loads(s)
+
+
+def parse_path(s: str | None) -> Path | None:
+ if s is None or s == "":
+ return None
+ return Path(s)
+
+
+def parse_dtype(dtype_name: str) -> torch.dtype:
+ dtype = {
+ "bf16": torch.bfloat16,
+ "bfloat16": torch.bfloat16,
+ "fp32": torch.float32,
+ "float32": torch.float32,
+ "fp16": torch.float16,
+ "float16": torch.float16,
+ }[dtype_name]
+ return dtype
+
+
+def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any:
+ """
+ If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"]
+ """
+ value = dictionary
+ for key in nested_key.split("."):
+ value = value[key]
+ return value
+
+
+def format_block_configs(config) -> str:
+ """
+ Formats block_configs from a model configuration into a beautiful, readable string.
+
+ Each line represents a layer with attention and FFN configuration.
+
+ Args:
+ config: PretrainedConfig object containing block_configs
+
+ Returns:
+ Formatted string with layer configurations
+
+ Example output:
+ ╭─────────────────────── Model Architecture ────────────────────────╮
+ │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │
+ │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │
+ │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │
+ ╰────────────────────────────────────────────────────────────────────╯
+ """
+ if not hasattr(config, "block_configs") or not config.block_configs:
+ return "❌ No block configs found"
+
+ lines = []
+
+ # Header
+ header = "╭─────────────────────────────────────── Model Architecture ────────────────────────────────────────╮"
+ lines.append(header)
+
+ # Format each layer
+ for i, block in enumerate(config.block_configs, 1):
+ attention_info = _format_attention_config(block.attention)
+ ffn_info = _format_ffn_config(block.ffn)
+
+ # Create formatted line with proper padding
+ layer_str = f"Layer {i:2d}"
+ attention_str = f"Attention: {attention_info}"
+ ffn_str = f"FFN: {ffn_info}"
+
+ line = f"│ {layer_str:8s} │ {attention_str:30s} │ {ffn_str:18s} │"
+ lines.append(line)
+
+ # Footer
+ footer = "╰────────────────────────────────────────────────────────────────────────────────────────────────────╯"
+ lines.append(footer)
+
+ return "\n".join(lines)
+
+
+def _format_attention_config(attention_config) -> str:
+ """Format attention configuration for display with visual indicators."""
+ if not attention_config:
+ return "default"
+
+ if attention_config.no_op:
+ return "❌ no_op"
+
+ num_kv_heads = attention_config.num_key_value_heads
+ if num_kv_heads is not None:
+ return f"{num_kv_heads} kv heads"
+
+ if attention_config.replace_with_linear:
+ return "linear replacement"
+
+ # Check for other attention types
+ if attention_config.mamba:
+ return "🐍 mamba"
+ if attention_config.llama4:
+ return "🦙 llama4"
+
+ window_length = attention_config.window_length
+ if window_length is not None:
+ return f"windowed ({window_length})"
+
+ if attention_config.sparsify:
+ return "sparse"
+
+ return "default"
+
+
+def _format_ffn_config(ffn_config) -> str:
+ """Format FFN configuration for display with visual indicators."""
+ if not ffn_config:
+ return "default"
+
+ if ffn_config.no_op:
+ return "❌ no_op"
+
+ if ffn_config.replace_with_linear:
+ return "linear"
+
+ ffn_intermediate = ffn_config.intermediate_size
+ if ffn_intermediate is not None:
+ return f"ffn_intermediate = {ffn_intermediate}"
+
+ # Check for MoE configuration
+ moe_config = ffn_config.moe
+ if moe_config:
+ return "MoE"
+
+ if ffn_config.sparsify:
+ return "sparse"
+
+ return "default"
+
+
+def format_global_config(config: DictConfig, title: str = "Global Configuration") -> str:
+ """
+ Pretty prints a global DictConfig with nice formatting and visual indicators.
+
+ Args:
+ config: DictConfig object to format
+ title: Title to display at the top of the formatted output
+
+ Returns:
+ Formatted string with configuration details
+
+ Example output:
+ ╭─────────────────── Global Configuration ────────────────────╮
+ │ Training │
+ │ • learning_rate: 1e-4 │
+ │ • batch_size: 32 │
+ │ • epochs: 100 │
+ │ Model │
+ │ • hidden_dim: 512 │
+ │ • num_layers: 6 │
+ │ Data │
+ │ • dataset_path: /path/to/data │
+ │ • block_size: 2048 │
+ ╰──────────────────────────────────────────────────────────────╯
+ """
+ if not config:
+ return "❌ No configuration found"
+
+ lines = []
+
+ # Calculate box width based on title
+ box_width = max(60, len(title) + 10)
+ title_padding = (box_width - len(title) - 2) // 2
+
+ # Header
+ header = f"\n╭{'─' * (box_width - 2)}╮"
+ title_line = (
+ f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│"
+ )
+ lines.extend([header, title_line])
+
+ def _format_value(value: Any, indent: int = 0) -> str:
+ """Format a value with appropriate type indicators."""
+ prefix = " " * indent
+
+ if isinstance(value, (bool, int, float)):
+ return f"{prefix} {value}"
+ elif isinstance(value, str):
+ # Show truncated long strings
+ if len(value) > 50:
+ return f"{prefix} {value[:47]}..."
+ return f"{prefix} {value}"
+ elif isinstance(value, (list, tuple)):
+ if not value:
+ return f"{prefix} []"
+ elif len(value) <= 3:
+ return f"{prefix} {list(value)}"
+ else:
+ return f"{prefix} [{len(value)} items]"
+ elif value is None:
+ return f"{prefix} None"
+ else:
+ return f"{prefix} {value!s}"
+
+ def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0):
+ """Recursively add configuration sections."""
+ if section_name:
+ indent_str = " " * indent
+ section_line = f"│ {indent_str}{section_name}"
+ # Pad to box width
+ padding_needed = box_width - len(section_line) - 1
+ section_line += " " * padding_needed + "│"
+ lines.append(section_line)
+
+ for key, value in cfg.items():
+ if isinstance(value, DictConfig):
+ # Nested configuration section
+ _add_config_section(value, f"{key}", indent + 1)
+ else:
+ # Regular key-value pair
+ indent_str = " " * (indent + 1)
+ value_str = _format_value(value).replace(" " * 0, "").strip()
+ line = f"│ {indent_str} {key}: {value_str}"
+ # Pad to box width
+ if len(line) >= box_width - 1:
+ # Truncate long lines
+ line = line[: box_width - 4] + "..."
+ padding_needed = box_width - len(line) - 1
+ line += " " * padding_needed + "│"
+ lines.append(line)
+
+ # Add configuration sections
+ _add_config_section(config)
+
+ # Footer
+ footer = f"╰{'─' * (box_width - 2)}╯"
+ lines.append(footer)
+
+ return "\n".join(lines)
+
+
+def format_stitched_losses(
+ losses_dict: dict[str, float],
+ best_steps_dict: dict[str, int] | None = None,
+ best_values_dict: dict[str, float] | None = None,
+ step_number: int | None = None,
+ title: str = "Stitched Module Losses",
+) -> str:
+ """
+ Pretty prints stitched module losses with comprehensive tracking and visual indicators.
+
+ Args:
+ losses_dict: Dictionary with block names as keys and current loss values as floats
+ best_steps_dict: Optional dictionary with block names as keys and best step numbers as values
+ best_values_dict: Optional dictionary with block names as keys and best loss values as floats
+ step_number: Optional current step number to include in summary
+ title: Title to display at the top of the formatted output
+
+ Returns:
+ Formatted string with loss values in a comprehensive table format
+
+ Example output:
+ ╭─────────────────── Stitched Module Losses ──────────────────╮
+ │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │
+ │───────┼────────────┼───────────┼────────────┼──────────────────│
+ │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │
+ │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │
+ │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │
+ ╰──────────────────────────────────────────────────────────────╯
+ """
+ if not losses_dict:
+ return "❌ No losses found"
+
+ lines = []
+
+ # Calculate statistics
+ loss_values = list(losses_dict.values())
+ max_loss = max(loss_values)
+ min_loss = min(loss_values)
+ avg_loss = sum(loss_values) / len(loss_values)
+
+ # Calculate box width for new layout (removed Bar column)
+ box_width = 74
+ title_padding = (box_width - len(title) - 2) // 2
+
+ # Header
+ header = f"╭{'─' * (box_width - 2)}╮"
+ title_line = (
+ f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│"
+ )
+ separator = (
+ f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ "
+ f"{'Best Value':<12} │ {'Change from avg':<18} │"
+ )
+ divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│"
+
+ lines.extend([header, title_line, separator, divider])
+
+ # Format each loss
+ for block_name, loss_value in losses_dict.items():
+ # Format current loss value
+ loss_str = f"{loss_value:.2e}"
+
+ # Format best step
+ if best_steps_dict and block_name in best_steps_dict:
+ best_step_str = f"Step {best_steps_dict[block_name]}"
+ else:
+ best_step_str = " --"
+
+ # Format best value
+ if best_values_dict and block_name in best_values_dict:
+ best_value = best_values_dict[block_name]
+ best_value_str = f"{best_value:.2e}"
+ else:
+ best_value = loss_value # Assume current is best if no history
+ best_value_str = f"{best_value:.2e}"
+
+ # Calculate change from average
+ change_from_avg = loss_value - avg_loss
+ if abs(change_from_avg) > 1e-8: # Only show if meaningful
+ change_str = f"{abs(change_from_avg):.1e}"
+ if change_from_avg > 0:
+ # Current is above average (worse for loss)
+ change_display = f"↑ +{change_str}"
+ else:
+ # Current is below average (better for loss)
+ change_display = f"↓ -{change_str}"
+ else:
+ # At average value
+ change_display = "↔ 0.0e+00"
+
+ # Format the line
+ block_display = block_name.replace("block_", "").zfill(2)
+
+ line = (
+ f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ "
+ f"{best_value_str:<12} │ {change_display:<18} │"
+ )
+ lines.append(line)
+
+ # Add summary statistics
+ lines.append(divider)
+
+ # Build summary string with optional step number
+ summary_parts = []
+ if step_number is not None:
+ summary_parts.append(f"Step {step_number}")
+ summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"])
+
+ summary_text = ", ".join(summary_parts)
+ summary = f"│ Summary: {summary_text}"
+
+ # Pad summary to box width
+ padding_needed = box_width - len(summary) - 1
+ summary += " " * padding_needed + "│"
+ lines.append(summary)
+
+ # Add best step summary if we have best step data
+ if best_steps_dict and best_values_dict:
+ # Find the most common best step (modal step)
+ step_counts = {}
+ for step in best_steps_dict.values():
+ step_counts[step] = step_counts.get(step, 0) + 1
+
+ if step_counts:
+ modal_best_step = max(step_counts, key=step_counts.get)
+
+ # Get values at the modal best step for blocks that have it as their best
+ best_step_values = []
+ for block_name, best_step in best_steps_dict.items():
+ if best_step == modal_best_step and block_name in best_values_dict:
+ best_step_values.append(best_values_dict[block_name])
+
+ if best_step_values:
+ best_step_avg = sum(best_step_values) / len(best_step_values)
+ best_step_max = max(best_step_values)
+ best_step_min = min(best_step_values)
+
+ best_step_summary_text = (
+ f"Best: Step {modal_best_step}, Avg={best_step_avg:.2e}, "
+ f"Max={best_step_max:.2e}, Min={best_step_min:.2e}"
+ )
+ best_step_summary = f"│ {best_step_summary_text}"
+
+ # Pad best step summary to box width
+ padding_needed = box_width - len(best_step_summary) - 1
+ best_step_summary += " " * padding_needed + "│"
+ lines.append(best_step_summary)
+
+ # Footer
+ footer = f"╰{'─' * (box_width - 2)}╯"
+ lines.append(footer)
+
+ return "\n".join(lines)
diff --git a/modelopt/torch/puzzletron/utils/utils.py b/modelopt/torch/puzzletron/utils/utils.py
new file mode 100644
index 0000000000..77a13609aa
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/utils.py
@@ -0,0 +1,253 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import json
+import os
+from copy import deepcopy
+from typing import Any
+
+import torch
+
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
+ AttentionConfig,
+ BlockConfig,
+ FFNConfig,
+)
+
+
+def calculate_kv_dim(num_key_value_heads: int, n_head: int, n_embd: int) -> int:
+ """Calculate the key-value dimension for grouped-query attention.
+
+ Args:
+ num_key_value_heads: Number of key-value heads.
+ n_head: Total number of attention heads.
+ n_embd: Embedding dimension.
+
+ Returns:
+ Combined dimension for key and value tensors (2 * num_key_value_heads * head_size).
+ """
+ if num_key_value_heads is None:
+ return 0
+ head_size = n_embd // n_head
+ kv_dim = 2 * num_key_value_heads * head_size
+ return kv_dim
+
+
+def raise_unknown_subblock_config_error(subblock_config: Any) -> None:
+ """Raise an error for invalid subblock configuration types.
+
+ TODO: Consider a better place for this function.
+ Args:
+ subblock_config: The invalid subblock configuration object.
+
+ Raises:
+ ValueError: Always raised with a message indicating the expected types.
+ """
+ raise ValueError(
+ f"subblock_config should be an instance of FFNConfig or AttentionConfig, instead got {type(subblock_config)}"
+ )
+
+
+def sizeof_dtype(dtype: torch.dtype) -> int | float:
+ """Return the size in bytes of the given data type.
+
+ TODO: Consider a better place for this function.
+ Args:
+ dtype: PyTorch data type or custom type string (e.g., 'nvfp4').
+
+ Returns:
+ Size in bytes of the data type. Special case: 'nvfp4' returns ~0.588 bytes.
+ """
+ if dtype == "nvfp4":
+ return 1 / 1.7
+ return torch.tensor([], dtype=dtype).element_size()
+
+
+def load_json(file_path: str):
+ """Load and parse a JSON file.
+
+ TODO: Consider a better place for this function.
+
+ Args:
+ file_path: Path to the JSON file to load.
+
+ Returns:
+ Parsed JSON data as a Python object, or None if the file doesn't exist.
+ """
+ if not os.path.exists(file_path):
+ print("file does not exist {file_path}")
+ return None
+
+ with open(file=file_path) as f:
+ return json.load(f)
+
+
+def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str:
+ """Convert a list of block configurations to a human-readable string representation.
+
+ TODO: Consider a better place for this function.
+ Better place for this and subsequent related function would be in __repr__ function in class
+ BlockConfig so when we print it or do str(block_config), it automatically
+ prints in this custom formatted string
+
+ Args:
+ block_configs: List of BlockConfig dataclasses or dicts containing layer configurations.
+
+ Returns:
+ Multi-line string with each block's configuration on a separate line.
+ """
+ block_configs = deepcopy(block_configs)
+ reps = []
+ for block_idx, block_config in enumerate(block_configs):
+ rep = f"block_{block_idx}:".ljust(9)
+ rep += block_config_to_str(block_config)
+ reps.append(rep)
+ rep = "\n".join(reps) + "\n"
+ return rep
+
+
+def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None:
+ """
+ Convert a BlockConfig to a human-readable string representation.
+
+ TODO: Consider a better place for this function.
+ Args:
+ block_config: BlockConfig dataclass or dict containing attention and ffn configs.
+
+ Returns:
+ Formatted string with attention and FFN information, or None if input is None.
+ """
+ if block_config is None:
+ return None
+ rep = ""
+ if dataclasses.is_dataclass(block_config):
+ block_config = dataclasses.asdict(block_config)
+ for subblock_name in ["attention", "ffn"]:
+ subblock_config = block_config[subblock_name]
+ rep += subblock_config_to_str(subblock_config, subblock_name)
+ return rep
+
+
+def subblock_config_to_str(
+ subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None,
+ subblock_name: None | str = None,
+) -> str | None:
+ """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string.
+
+ TODO: Consider a better place for this function.
+ Args:
+ subblock_config: FFNConfig, AttentionConfig dataclass or dict.
+ subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe').
+ Auto-detected if subblock_config is a dataclass.
+
+ Returns:
+ Formatted string showing subblock type and key parameters (e.g., intermediate_size,
+ num_key_value_heads), or None if input is None.
+ """
+ if subblock_config is None:
+ return None
+ subblock_name = (
+ "ffn"
+ if isinstance(subblock_config, FFNConfig)
+ else "mamba"
+ if isinstance(subblock_config, AttentionConfig) and subblock_config.is_mamba
+ else "attention"
+ if isinstance(subblock_config, AttentionConfig)
+ else subblock_name
+ )
+ assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict."
+
+ if dataclasses.is_dataclass(subblock_config):
+ subblock_config = dataclasses.asdict(subblock_config)
+
+ if subblock_name == "attention" and subblock_config.get("mamba") is not None:
+ subblock_name = "mamba"
+
+ if subblock_name == "ffn" and subblock_config.get("moe") is not None:
+ subblock_name = "moe"
+
+ rep = f" {subblock_name}"
+ if subblock_config.get("no_op"):
+ rep += " no_op".ljust(8)
+ elif subblock_config.get("replace_with_linear"):
+ rep += " linear".ljust(8)
+ elif subblock_name == "ffn":
+ intermediate_size = subblock_config["intermediate_size"]
+ rep += f" intermediate_{intermediate_size}".ljust(8)
+ elif subblock_name == "attention":
+ num_key_value_heads = subblock_config["num_key_value_heads"]
+ rep += f" kv_heads_{num_key_value_heads}".ljust(8)
+ elif subblock_name == "mamba":
+ mamba_num_heads = subblock_config["mamba"]["num_heads"]
+ mamba_head_dim = subblock_config["mamba"]["head_dim"]
+ rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8)
+ elif subblock_name == "moe":
+ moe_num_local_experts = subblock_config["moe"]["num_local_experts"]
+ moe_expert_intermediate_dim = subblock_config["moe"]["expert_intermediate_dim"]
+ shared_expert_intermediate_dim = subblock_config["moe"]["shared_expert_intermediate_dim"]
+ num_experts_per_tok = subblock_config["moe"]["num_experts_per_tok"]
+ rep += f" num_experts_{moe_num_local_experts} expert_intermediate_dim_{moe_expert_intermediate_dim} shared_expert_intermediate_dim_{shared_expert_intermediate_dim} num_experts_per_tok_{num_experts_per_tok}".ljust(
+ 8
+ )
+ else:
+ raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.")
+
+ return rep
+
+
+class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
+ def __init__(self, device=None, dtype=None):
+ """
+ Create tensors with given device and dtype and don't run initialization
+ (but instead use "empty tensors", i.e. uninitialized memory).
+
+ device: `torch.device` to work with
+ dtype: `torch.dtype` to work with
+
+ Example::
+ with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
+ model = LLaMA(model_config)
+ model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))"""
+
+ self.device = device
+ self.dtype = dtype
+
+ def __enter__(self):
+ return super().__enter__()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return super().__exit__(exc_type, exc_val, exc_tb)
+
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if getattr(func, "__module__", None) == "torch.nn.init":
+ if "tensor" in kwargs:
+ return kwargs["tensor"]
+ else:
+ return args[0]
+ if (
+ self.device is not None
+ and func in torch.utils._device._device_constructors()
+ and kwargs.get("device") is None
+ ):
+ kwargs["device"] = self.device
+ if (
+ self.dtype is not None
+ and func in torch.utils._device._device_constructors()
+ and kwargs.get("dtype") is None
+ ):
+ kwargs["dtype"] = self.dtype
+ return func(*args, **kwargs)
diff --git a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py
new file mode 100644
index 0000000000..90fea13c56
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py
@@ -0,0 +1,302 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode.
+
+Coordinates forward passes and loss computation through model shards distributed across GPUs
+using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation.
+
+Used by validate_model.py during activation scoring for sharded models.
+"""
+# mypy: ignore-errors
+
+import traceback
+from contextlib import nullcontext
+from typing import Type
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
+from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import LMHead
+from modelopt.torch.puzzletron.sewing_kit import (
+ ExternalTarget,
+ InputArgs,
+ ModuleTarget,
+ Needle,
+ RemoteTarget,
+ StitchedModule,
+)
+from modelopt.torch.puzzletron.sewing_kit.core import InputReducer
+from modelopt.torch.puzzletron.sewing_kit.utils import (
+ distributed_recv_obj,
+ distributed_send_obj,
+ fake_tensor,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils import init_module_with_state_dict
+from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import DummyBlock
+from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs
+
+
+def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None:
+ """Log detailed error info for distributed forward pass failures.
+
+ When one rank crashes during distributed forward, others may hang waiting for communication.
+ This logging helps diagnose which rank failed and why.
+ """
+ error_msg = (
+ f"\n{'=' * 60}\n"
+ f"[Rank {rank}] ERROR in stitched_model forward (batch {batch_idx}/{num_batches})\n"
+ f"Error: {type(e).__name__}: {e}\n"
+ f"{'=' * 60}\n"
+ f"{traceback.format_exc()}"
+ f"{'=' * 60}\n"
+ )
+ print(error_msg, flush=True)
+
+
+class HiddenStatesAndLMHead(list):
+ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor):
+ super().__init__(hidden_states)
+ self.lm_head_weights = lm_head_weights
+
+
+@torch.no_grad()
+def calculate_losses_pipeline(
+ stitched_model: StitchedModule,
+ dataloader: DataLoader | None,
+ target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None,
+ return_hidden_states: bool = False,
+ calculate_full_score_ablations: bool = False,
+ calc_on_cpu: bool = False,
+ just_model_forward: bool = False,
+ checkpoint_manager=None,
+ autocast_dtype: torch.dtype = torch.bfloat16,
+ descriptor: Type[ModelDescriptor] = None,
+ use_autocast: bool = True,
+) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]:
+ """
+ Do model forward on each batch and calculate LM loss.
+ Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch.
+ Optionally return hidden states per batch.
+ Does not support data-parallel.
+ just_model_forward: skip loss calculation, just forward the model. Useful for activation hooks.
+
+
+ Returns:
+ losses: dict = {
+ "lm_loss": {
+ "avg": float,
+ "per_sample": list[float]
+ }
+ more metrics if provided with target_hidden_states_per_batch
+ }
+ target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True
+
+ """
+ if not isinstance(stitched_model, StitchedModule):
+ stitched_model = perform_pipeline_stitches(stitched_model, descriptor)
+
+ params = list(stitched_model.parameters())
+ model_device = params[0].device if params else "cpu"
+
+ # Pre-populate outputs with dummy values for skipped batches
+ start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0
+ if dist.is_last_process():
+ outputs = [{"lm_loss": [0.0]}] * start_batch
+ else:
+ outputs = None
+
+ if dist.is_master():
+ all_input_ids, all_targets = zip(
+ *[(batch["input_ids"], batch["targets"]) for batch in dataloader]
+ )
+ if dist.size() > 1:
+ distributed_send_obj(all_targets, dst=dist.size() - 1)
+
+ if dist.is_last_process():
+ if dist.size() > 1:
+ all_targets = distributed_recv_obj(src=0)
+
+ lm_head: LMHead = next(
+ module
+ for module_name, module in stitched_model.named_modules()
+ if "lm_head" in module_name
+ )
+
+ if target_hidden_states_per_batch is not None:
+ lm_head_weights = target_hidden_states_per_batch.lm_head_weights
+ with torch.device(model_device):
+ target_lm_head = init_module_with_state_dict(
+ {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False
+ )
+
+ if dist.is_master():
+ num_batches = len(all_input_ids)
+ seq_len = all_input_ids[0].shape[1]
+ if dist.size() > 1:
+ torch.distributed.broadcast_object_list([num_batches, seq_len])
+
+ # Create progress bar with sliced range starting from checkpoint position
+ desc = (
+ f"[rank {dist.rank()}] calculate_losses_pipeline("
+ f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})"
+ )
+ progress_bar = tqdm(range(start_batch, num_batches), desc=desc)
+ else:
+ obj_list = [None, None]
+ if dist.size() > 1:
+ torch.distributed.broadcast_object_list(obj_list)
+ num_batches, seq_len = obj_list
+ progress_bar = range(start_batch, num_batches)
+
+ stitched_model.eval()
+
+ # Use autocast for mixed precision, or nullcontext if disabled
+ # (some models like Qwen3-VL MoE have dtype bugs under autocast)
+ autocast_ctx = (
+ torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext()
+ )
+ with autocast_ctx:
+ fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device)
+ for i_batch in progress_bar:
+ if dist.is_master():
+ input_ids = all_input_ids[i_batch].to(model_device)
+ else:
+ input_ids = fake_input_ids
+
+ try:
+ output = stitched_model({}, {}, input_ids)
+ except Exception as e:
+ _log_forward_error(e, dist.rank(), i_batch, num_batches)
+ raise
+
+ if dist.is_last_process():
+ logits = output.captured_outputs.get("model_output")
+ logits = getattr(logits, "logits", logits)
+ hidden_states = output.captured_outputs.get("hidden_states")
+ targets = all_targets[i_batch].to(model_device)
+
+ target_hidden_states = None
+ target_logits = None
+ if target_hidden_states_per_batch is not None:
+ target_hidden_states = target_hidden_states_per_batch[i_batch]
+ target_hidden_states = target_hidden_states.to(hidden_states.device)
+ target_logits = target_lm_head(target_hidden_states)
+
+ if just_model_forward:
+ batch_outputs = {"lm_loss": [-1.0] * len(targets)}
+ else:
+ batch_outputs = calculate_batch_outputs(
+ hidden_states,
+ target_hidden_states,
+ logits,
+ target_logits,
+ targets,
+ return_hidden_states,
+ calculate_full_score_ablations,
+ calc_on_cpu,
+ )
+
+ outputs.append(batch_outputs)
+
+ # Free GPU memory after processing each batch
+ del logits, hidden_states, targets
+ if target_hidden_states is not None:
+ del target_hidden_states
+ if target_logits is not None:
+ del target_logits
+
+ # Free output tensor memory on all ranks
+ del output
+
+ # Update checkpoint progress periodically
+ if checkpoint_manager:
+ checkpoint_manager.update_progress(i_batch + 1, num_batches)
+
+ losses, hidden_states_per_batch = (
+ _organize_outputs(outputs) if outputs is not None else (None, None)
+ )
+
+ if hidden_states_per_batch is not None:
+ hidden_states_per_batch = HiddenStatesAndLMHead(
+ hidden_states_per_batch, lm_head.weight.cpu()
+ )
+
+ dist.barrier()
+ return losses, hidden_states_per_batch
+
+
+def perform_pipeline_stitches(
+ model,
+ descriptor: Type[ModelDescriptor],
+) -> StitchedModule:
+ """Create pipeline stitches for distributed model evaluation.
+
+ Args:
+ model: The model to stitch (any HuggingFace model with AnyModel descriptor).
+ descriptor: ModelDescriptor for layer naming.
+ """
+ target = ModuleTarget("module", model)
+ stitcher = Needle()
+
+ num_layers = model.config.num_hidden_layers
+
+ is_real_block = np.flatnonzero(
+ [
+ not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock)
+ for i in range(num_layers)
+ ]
+ )
+
+ first_block, last_block = is_real_block.min(), is_real_block.max()
+
+ if dist.rank() != 0:
+ # receive activations from previous rank
+ stitcher.stitch(
+ RemoteTarget(peer_rank=dist.rank() - 1).value(
+ name="activations", adapter=lambda x: InputArgs(x)
+ ),
+ target.input(
+ name=descriptor.layer_block_name(first_block),
+ reducer=InputReducer(
+ lambda acc, override, orig, *args: override + orig.drop_args(0)
+ ),
+ ),
+ )
+
+ if not dist.is_last_process():
+ # send activations to next rank
+ stitcher.stitch(
+ target.output(descriptor.layer_block_name(last_block)),
+ RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"),
+ )
+ else:
+ # register model output
+ stitcher.stitch(
+ target.output(name=descriptor.output_embedding_name()),
+ ExternalTarget().output("model_output"),
+ )
+ stitcher.stitch(
+ target.output(name=descriptor.final_norm_name()),
+ ExternalTarget().output("hidden_states"),
+ )
+
+ stitched_module = stitcher.knot(ignore_extra_overrides=True)
+ return stitched_module
diff --git a/modelopt/torch/puzzletron/utils/validation.py b/modelopt/torch/puzzletron/utils/validation.py
new file mode 100644
index 0000000000..0fff907549
--- /dev/null
+++ b/modelopt/torch/puzzletron/utils/validation.py
@@ -0,0 +1,550 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model validation and loss calculation utilities for single-GPU and multi-GPU setups.
+
+Also provides helper functions for loss metrics, KL divergence, JS divergence,
+and similarity losses for knowledge distillation.
+"""
+
+# mypy: ignore-errors
+import functools
+import math
+from enum import Enum
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper
+from typing_extensions import Self
+
+from modelopt.torch.puzzletron.tools import kd_model
+
+
+class UnshardedLowMemorySparseTensor:
+ def __init__(self, x: torch.Tensor):
+ inds_dtype = self._infer_inds_dtype(x)
+ x_sparse = x.to_sparse_coo()
+ self._values = x_sparse.values()
+ self._indices = x_sparse.indices().to(inds_dtype)
+ self._size = x_sparse.size()
+
+ @staticmethod
+ def _infer_inds_dtype(x: torch.Tensor) -> torch.dtype:
+ max_dim = max(x.shape)
+ for inds_dtype in [torch.int16, torch.int32, torch.int64]:
+ if torch.iinfo(inds_dtype).max >= max_dim:
+ return inds_dtype
+
+ def to_sparse_coo(self) -> torch.Tensor:
+ return torch.sparse_coo_tensor(values=self._values, indices=self._indices, size=self._size)
+
+ def to_dense(self) -> torch.Tensor:
+ return self.to_sparse_coo().to_dense()
+
+ def to(self, *args) -> Self:
+ self._values = self._values.to(*args)
+ for arg in args:
+ if isinstance(arg, torch.device) or isinstance(arg, str):
+ self._indices = self._indices.to(arg)
+ return self
+
+
+class LowMemorySparseTensor:
+ _max_sparse_size = torch.iinfo(torch.int32).max
+
+ def __init__(self, x: torch.Tensor):
+ num_chunks = math.ceil(x.numel() / self._max_sparse_size)
+ self._chunk_dim = np.argmax(x.shape)
+ self._chunks = [
+ UnshardedLowMemorySparseTensor(chunk)
+ for chunk in torch.chunk(x, num_chunks, dim=self._chunk_dim)
+ ]
+
+ def to(self, *args) -> Self:
+ for chunk in self._chunks:
+ chunk.to(*args)
+ return self
+
+ def to_dense(self) -> torch.Tensor:
+ return torch.concat([chunk.to_dense() for chunk in self._chunks], dim=self._chunk_dim)
+
+
+@torch.no_grad()
+def calculate_losses(
+ model: nn.Module,
+ dataloader: DataLoader,
+ target_probs: None = None,
+ return_probs: bool = False,
+ checkpoint_manager=None,
+) -> tuple[dict[str, dict], None] | tuple[None, None]:
+ """Do model forward on each batch and calculate LM loss.
+ Works on lit-llama models (single gpu) and huggingface models (can be multi gpu).
+ Does not support data-parallel.
+
+ ### Anything related to probs and hidden states is not supported currently! ###
+ calculate_losses() isn't updated according to the major refactor in
+ calculate_losses_pipeline() regarding hidden states.
+
+ Returns:
+ outputs = {
+ "lm_loss": list[float],
+ "token_accuracy_top_1": list[float],
+ "token_accuracy_top_5": list[float],
+ "token_accuracy_top_10": list[float],
+ }
+ """
+ if (target_probs is not None) or return_probs:
+ raise NotImplementedError(
+ "calculate_losses() isn't updated according to the major refactor in "
+ "calculate_losses_pipeline() regarding hidden states."
+ )
+
+ model_device = next(model.parameters()).device
+ outputs = []
+
+ try:
+ num_batches = len(dataloader)
+ except:
+ num_batches = None
+
+ # Adjust progress bar for resume
+ start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0
+ progress_bar = tqdm(
+ enumerate(dataloader),
+ total=num_batches,
+ desc=f"calculate_losses({(target_probs is None)=}, {return_probs=})",
+ )
+ if start_batch > 0:
+ progress_bar.update(start_batch)
+
+ for i_batch, batch in progress_bar:
+ # Skip batch if resuming from checkpoint
+ if checkpoint_manager and checkpoint_manager.should_skip_batch(i_batch):
+ continue
+
+ input_ids = batch["input_ids"].to(model_device)
+ logits = model(input_ids)
+ if hasattr(logits, "logits"):
+ logits = logits.logits
+ # logits = logits.float()
+
+ targets = batch["targets"].to(model_device)
+
+ batch_outputs = calculate_batch_outputs(
+ hidden_states=None,
+ target_hidden_states=None,
+ logits=logits,
+ target_logits=None,
+ targets=targets,
+ return_hidden_states=False,
+ calculate_full_score_ablations=False,
+ calc_on_cpu=False,
+ )
+ outputs.append(batch_outputs)
+
+ # Update checkpoint progress periodically
+ if checkpoint_manager:
+ checkpoint_manager.update_progress(i_batch + 1, num_batches)
+
+ losses, _ = _organize_outputs(outputs)
+ return losses, None
+
+
+def calculate_batch_outputs(
+ hidden_states: torch.Tensor | None,
+ target_hidden_states: torch.Tensor | None,
+ logits: torch.Tensor,
+ target_logits: torch.Tensor | None,
+ targets: torch.Tensor,
+ return_hidden_states: bool,
+ calculate_full_score_ablations: bool,
+ calc_on_cpu: bool,
+) -> dict:
+ if calc_on_cpu:
+ if hidden_states is not None:
+ hidden_states = hidden_states.cpu()
+ if target_hidden_states is not None:
+ target_hidden_states = target_hidden_states.cpu()
+ if logits is not None:
+ logits = logits.cpu()
+ if target_logits is not None:
+ target_logits = target_logits.cpu()
+ if targets is not None:
+ targets = targets.cpu()
+
+ batch_outputs = _calculate_ground_truth_based_scores(logits, targets)
+
+ if (target_hidden_states is not None) or (target_logits is not None):
+ batch_outputs.update(
+ _calculate_teacher_similarity_scores(
+ hidden_states,
+ target_hidden_states,
+ logits,
+ target_logits,
+ calculate_full_score_ablations,
+ )
+ )
+
+ if return_hidden_states:
+ batch_outputs["hidden_states_per_batch"] = hidden_states.cpu()
+
+ return batch_outputs
+
+
+def _organize_outputs(
+ outputs_per_batch: list[dict],
+) -> tuple[dict[str, dict], list[torch.Tensor] | None]:
+ outputs = _concatenate_batch_outputs(outputs_per_batch)
+ hidden_states_per_batch = outputs.pop("hidden_states_per_batch", None)
+ losses = {
+ loss_name: {
+ "avg": sum(loss_per_sample) / len(loss_per_sample),
+ "per_sample": loss_per_sample,
+ }
+ for loss_name, loss_per_sample in outputs.items()
+ }
+ return losses, hidden_states_per_batch
+
+
+def _concatenate_batch_outputs(outputs_per_batch: list[dict]) -> dict[str, list]:
+ outputs = {}
+ for output_name in outputs_per_batch[0]: # Regular dict is directly iterable
+ item_list = []
+ for batch_outputs in outputs_per_batch:
+ batch_items = batch_outputs[output_name]
+ if isinstance(batch_items, list | tuple):
+ item_list.extend(batch_items)
+ else:
+ item_list.append(batch_items)
+ outputs[output_name] = item_list
+ return outputs
+
+
+def _calculate_per_sample_lm_loss(
+ logits: torch.Tensor,
+ targets: torch.Tensor,
+) -> list[float]:
+ per_sample_lm_loss = (
+ torch.nn.functional.cross_entropy(
+ logits.transpose(1, 2), targets, ignore_index=-1, reduction="none"
+ )
+ .mean(dim=-1)
+ .tolist()
+ )
+ return per_sample_lm_loss
+
+
+def _calculate_ground_truth_based_scores(
+ logits: torch.Tensor,
+ targets: torch.Tensor,
+) -> dict[str, list[float]]:
+ scores = {"lm_loss": _calculate_per_sample_lm_loss(logits, targets)}
+
+ for top_k in (1, 5, 10):
+ top_k_predictions = logits.topk(top_k, dim=-1).indices # [b, t, top_k]
+ is_target_in_predictions = (targets.unsqueeze(-1) == top_k_predictions).any(
+ dim=-1
+ ) # [b, t]
+ fraction_model_predicted_target = is_target_in_predictions.float().mean(dim=-1) # [b]
+ scores[f"token_accuracy_top_{top_k}"] = fraction_model_predicted_target.tolist()
+
+ return scores
+
+
+def cosine_embedding_loss(
+ hidden_states: torch.Tensor,
+ target_hidden_states: torch.Tensor,
+) -> list[float]:
+ return kd_model.cosine_embedding_loss_batched(hidden_states, target_hidden_states).tolist()
+
+
+def normalized_mse_loss(
+ hidden_states: torch.Tensor,
+ target_hidden_states: torch.Tensor,
+) -> list[float]:
+ return [
+ kd_model.normalized_mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item()
+ for i_sample in range(hidden_states.shape[0])
+ ]
+
+
+def mse_loss(
+ hidden_states: torch.Tensor,
+ target_hidden_states: torch.Tensor,
+) -> list[float]:
+ return [
+ F.mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item()
+ for i_sample in range(hidden_states.shape[0])
+ ]
+
+
+def mae_loss(
+ hidden_states: torch.Tensor,
+ target_hidden_states: torch.Tensor,
+) -> list[float]:
+ return [
+ F.l1_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item()
+ for i_sample in range(hidden_states.shape[0])
+ ]
+
+
+def _calculate_teacher_similarity_scores(
+ hidden_states: torch.Tensor,
+ target_hidden_states: torch.Tensor,
+ logits: torch.Tensor,
+ target_logits: torch.Tensor,
+ calculate_full_score_ablations: bool,
+) -> dict[str, list[float]]:
+ """hidden_states: [batch, tokens, n_embd]
+ target_hidden_states: [batch, tokens, n_embd]
+ logits: [batch, tokens, vocab]
+ target_logits: [batch, tokens, vocab]
+ """
+
+ def calc_per_sample(func, logits, target_probs):
+ return [
+ func(logits=logits[i_sample], target_probs=target_probs[i_sample])
+ for i_sample in range(logits.shape[0])
+ ]
+
+ score_ablations = {}
+
+ if (target_hidden_states is not None) and (hidden_states.shape == target_hidden_states.shape):
+ for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss):
+ score_name = f"{func.__name__}_hidden_states"
+ score_ablations[score_name] = func(hidden_states, target_hidden_states)
+
+ if target_logits is not None:
+ for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss):
+ score_name = f"{func.__name__}_logits"
+ score_ablations[score_name] = func(logits, target_logits)
+
+ for top_p in (0.99, 0.95, None) if calculate_full_score_ablations else (None,):
+ transformed_logits = (
+ logits if (top_p is None) else top_p_top_k(logits, top_p=top_p, top_k=None)
+ )
+ transformed_target_logits = (
+ target_logits
+ if (top_p is None)
+ else top_p_top_k(target_logits, top_p=top_p, top_k=None)
+ )
+ target_probs = transformed_target_logits.softmax(-1)
+
+ for func in (kl_div, js_div, tv_dist):
+ for clip_epsilon in (
+ (
+ ClipEpsilon.NO_CLIP,
+ ClipEpsilon.CLIP_NO_RENORMALIZE,
+ ClipEpsilon.CLIP_RENORMALIZE,
+ )
+ if calculate_full_score_ablations
+ else (ClipEpsilon.NO_CLIP,)
+ ):
+ epsilon_factors = (
+ (1.0, 0.1, 0.01) if not clip_epsilon == ClipEpsilon.NO_CLIP else (None,)
+ )
+
+ for epsilon_factor in epsilon_factors:
+ score_name = (
+ f"{func.__name__}--top_p_{top_p}--clip_epsilon_{clip_epsilon.name}"
+ f"--epsilon_factor_{epsilon_factor}"
+ )
+ func_with_args = functools.partial(
+ func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor
+ )
+ score_ablations[score_name] = calc_per_sample(
+ func_with_args, transformed_logits, target_probs
+ )
+ if (top_p is None) and (clip_epsilon == ClipEpsilon.NO_CLIP):
+ short_score_name = func.__name__
+ score_ablations[short_score_name] = score_ablations[score_name]
+
+ for top_k in (1, 5, 10):
+ teacher_greedy_prediction = target_logits.argmax(dim=-1, keepdim=True) # [b,t,1]
+ student_top_k_predictions = logits.topk(top_k, dim=-1).indices # [b,t,k]
+ is_teacher_prediction_in_student_predictions = (
+ teacher_greedy_prediction == student_top_k_predictions
+ ).any(dim=-1) # [b,t]
+ fraction_student_predicted_teacher = (
+ is_teacher_prediction_in_student_predictions.float().mean(dim=-1)
+ ) # [b]
+ score_ablations[f"greedy_teacher_prediction_in_student_top_{top_k}"] = (
+ fraction_student_predicted_teacher.tolist()
+ )
+
+ if calculate_full_score_ablations:
+ for top_p in (0.99, 0.95, 0.50, None):
+ # student
+ transformed_logits = logits.clone()
+
+ # teacher
+ transformed_target_logits = (
+ target_logits.clone()
+ if (top_p is None)
+ else top_p_top_k(target_logits, top_p=top_p, top_k=None)
+ )
+
+ target_probs = transformed_target_logits.softmax(-1)
+ mask = transformed_target_logits == -1000
+ if torch.any(mask):
+ transformed_logits[mask] = 0
+ transformed_target_logits[mask] = 0
+ target_probs[mask] = 0
+
+ for func in (mse_loss, mae_loss):
+ score_name = f"{func.__name__}_logits_top_p_{top_p}"
+ score_ablations[score_name] = func(
+ transformed_logits, transformed_target_logits
+ )
+
+ if top_p is not None and top_p > 0.9:
+ func = kl_div
+ clip_epsilon = ClipEpsilon.NO_CLIP
+ score_name = (
+ f"{func.__name__}--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered"
+ )
+ func_with_args = functools.partial(
+ func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor
+ )
+ score_ablations[score_name] = calc_per_sample(
+ func_with_args, logits, target_probs
+ )
+ # score_name = f"{func.__name__}_abs--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered"
+ # score_ablations[score_name] = [s.abs() for s in score_ablations[score_name]]
+
+ return score_ablations
+
+
+class ClipEpsilon(Enum):
+ NO_CLIP = "NO_CLIP"
+ CLIP_RENORMALIZE = "CLIP_RENORMALIZE"
+ CLIP_NO_RENORMALIZE = "CLIP_NO_RENORMALIZE"
+
+
+def _logits_to_logprobs(
+ logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float
+) -> torch.Tensor:
+ """logits: [tokens, vocab]"""
+ logprobs = logits.log_softmax(
+ -1
+ ) # must normalize logits before clipping otherwise log(1/voacb) means nothing
+ if clip_epsilon == ClipEpsilon.NO_CLIP:
+ return logprobs
+ vocab_size = logprobs.shape[-1]
+ epsilon = math.log(epsilon_factor * 1 / vocab_size)
+ logprobs = torch.clip(logprobs, min=epsilon)
+ if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE:
+ logprobs = logprobs.log_softmax(
+ -1
+ ) # we do log_softmax again to retain legitimate distributions
+ return logprobs
+
+
+def kl_div(
+ logits: torch.Tensor,
+ target_probs: torch.Tensor,
+ clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP,
+ epsilon_factor: float = 1.0,
+) -> float:
+ """Kullback-Leibler Divergence for a single sample.
+ logits: [tokens, vocab]
+ target_probs: [tokens, vocab]
+ """
+ num_tokens = logits.shape[0]
+ logprobs = _logits_to_logprobs(logits, clip_epsilon, epsilon_factor)
+
+ _kl_div = (
+ F.kl_div(logprobs, target_probs, reduction="sum", log_target=False).item() / num_tokens
+ )
+ return _kl_div
+
+
+def js_div(
+ logits: torch.Tensor,
+ target_probs: torch.Tensor,
+ clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP,
+ epsilon_factor: float = 1.0,
+) -> float:
+ """Jensen-Shannon Divergence for a single sample.
+ logits: [tokens, vocab]
+ target_probs: [tokens, vocab]
+ """
+ probs = logits.softmax(-1)
+ mixture_probs = (probs + target_probs) / 2
+ mixture_logprobs = mixture_probs.log().clip(min=-1000)
+
+ pred_kl_div = kl_div(mixture_logprobs, probs, clip_epsilon, epsilon_factor)
+ target_kl_div = kl_div(mixture_logprobs, target_probs, clip_epsilon, epsilon_factor)
+ _js_div = 0.5 * (pred_kl_div + target_kl_div)
+ return _js_div
+
+
+def tv_dist(
+ logits: torch.Tensor,
+ target_probs: torch.Tensor,
+ clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP,
+ epsilon_factor: float = 1.0,
+) -> float:
+ """Total Variation Distance (L1-loss) for a single sample.
+ logits: [tokens, vocab]
+ target_probs: [tokens, vocab]
+ """
+ num_tokens, vocab_size = logits.shape
+ probs = logits.softmax(-1)
+
+ if clip_epsilon != ClipEpsilon.NO_CLIP:
+ epsilon = epsilon_factor * 1 / vocab_size
+ probs = probs.clip(min=epsilon)
+ target_probs = target_probs.clip(min=epsilon)
+ if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE:
+ probs = probs / probs.sum(-1, keepdim=True)
+ target_probs = target_probs / target_probs.sum(-1, keepdim=True)
+
+ _tv_dist = 0.5 * (probs - target_probs).abs().sum().item() / num_tokens
+ return _tv_dist
+
+
+DEFAULT_TOP_P = 0.999
+# WestLake model:
+# 700 = percentile 0.9 for top_p=0.99
+# 1700 = percentile 0.95 for top_p=0.99 and percentile 0.75 for top_p=0.999
+# For top_p=0.999 and top_k=1700 you take about 75 GB for 2048*8192 tokens
+DEFAULT_TOP_K = 1000
+
+
+def top_p_top_k(
+ logits: torch.Tensor,
+ top_p: float | None = DEFAULT_TOP_P,
+ top_k: int | None = DEFAULT_TOP_K,
+ filter_value=-1000,
+) -> torch.Tensor:
+ logit_warpers = []
+ if top_p is not None:
+ logit_warpers.append(TopPLogitsWarper(top_p=top_p, filter_value=filter_value))
+ if top_k is not None:
+ logit_warpers.append(TopKLogitsWarper(top_k=top_k, filter_value=filter_value))
+
+ warped_logits = []
+ for sample_logits in logits:
+ for warper in logit_warpers:
+ sample_logits = warper(input_ids=None, scores=sample_logits)
+ warped_logits.append(sample_logits)
+ warped_logits = torch.stack(warped_logits)
+
+ return warped_logits
diff --git a/modelopt/torch/utils/robust_json.py b/modelopt/torch/utils/robust_json.py
index c4a72fde83..23a3091637 100644
--- a/modelopt/torch/utils/robust_json.py
+++ b/modelopt/torch/utils/robust_json.py
@@ -55,8 +55,13 @@ def default(self, o):
# User-defined function in main — fallback to just the name
return o.__name__
return f"{o.__module__}.{o.__qualname__}"
+ if inspect.isclass(o):
+ return f"{o.__module__}.{o.__qualname__}"
if isinstance(o, datetime.timedelta):
return str(o)
+ # Fallback for arbitrary objects (e.g. mixins injected into Hydra configs)
+ if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"):
+ return f"{o.__class__.__module__}.{o.__class__.__qualname__}"
return super().default(o)
diff --git a/pyproject.toml b/pyproject.toml
index 96490dff0a..a03e2029dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,6 +84,16 @@ hf = [
"transformers>=4.56,<5.0", # Should match modelopt/torch/__init__.py and tox.ini
"wonderwords",
]
+
+puzzletron = [ # Dependedencies for modelopt.torch.puzzletron subpackage
+ "fire",
+ "hydra-core==1.3.2",
+ "immutabledict",
+ "lru-dict",
+ "mip",
+ "pandas",
+ "typeguard",
+]
dev-lint = [
"bandit[toml]==1.7.9", # security/compliance checks
"mypy==1.17.1",
@@ -113,7 +123,7 @@ dev-test = [
"tox-current-env>=0.0.12",
]
# Compound extras via self-references
-all = ["nvidia-modelopt[onnx,hf]"]
+all = ["nvidia-modelopt[onnx,hf,puzzletron]"]
dev = ["nvidia-modelopt[all,dev-docs,dev-lint,dev-test]"]
[project.urls]
@@ -203,6 +213,17 @@ extend-ignore = [
"D",
"E501",
] # Ignore missing docstrings or line length for Jupyter notebooks
+"modelopt/torch/puzzletron/*" = [
+ "C4",
+ "D",
+ "E",
+ "F",
+ "N",
+ "PERF",
+ "RUF",
+ "SIM",
+ "UP",
+] # TODO: Disabled for now, will enable later, once all puzzletron code is migrated
"modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style
"modelopt/torch/sparsity/attention_sparsity/kernels/*" = [
"N803",
@@ -271,7 +292,7 @@ markers = [
[tool.coverage.run]
branch = false
include = ["modelopt/*"]
-omit = ["*/plugins/*", "*/export/*"]
+omit = ["*/plugins/*", "*/export/*", "modelopt/torch/puzzletron/*"]
[tool.coverage.report]
diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py
new file mode 100644
index 0000000000..fc6d6d5c16
--- /dev/null
+++ b/tests/_test_utils/torch/puzzletron/utils.py
@@ -0,0 +1,217 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+
+import torch
+from _test_utils.torch.transformers_models import get_tiny_tokenizer
+from datasets import Dataset, DatasetDict
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers
+
+
+def setup_test_model_and_data(
+ tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None
+) -> tuple[Path, Path, Path]:
+ """
+ Setup the test model and data for the puzzletron NAS search.
+
+ Args:
+ tmp_path: the temporary path to use for the test
+ rank: the rank of the process
+ hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
+ hybrid_override_pattern: For NemotronH models, the layer type pattern
+
+ Returns:
+ tuple[Path, Path, Path]: the puzzle_dir, hf_checkpoint_path, dataset_path
+ """
+ # Register Hydra custom resolvers (needed for config resolution)
+ register_hydra_resolvers()
+
+ puzzle_dir = tmp_path / hf_model_name
+ hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_model_name}"
+ dataset_path = puzzle_dir / "dummy_dataset"
+
+ if rank == 0:
+ save_dummy_dataset(dataset_path)
+
+ # Create a small HF model
+ tokenizer = get_tiny_tokenizer()
+ create_and_save_small_hf_model(
+ output_path=str(hf_checkpoint_path),
+ tokenizer=tokenizer,
+ hf_model_name=hf_model_name,
+ hybrid_override_pattern=hybrid_override_pattern,
+ )
+ dist.barrier()
+
+ return puzzle_dir, hf_checkpoint_path, dataset_path
+
+
+def create_and_save_small_hf_model(
+ output_path: str,
+ tokenizer: PreTrainedTokenizerBase,
+ hf_model_name: str,
+ hybrid_override_pattern: str | None = None,
+):
+ """
+ Create and save a small HuggingFace model for testing the conversion pipeline.
+ Uses real HuggingFace config to preserve model-specific settings (like tie_word_embeddings),
+ but shrinks size parameters for fast testing.
+
+ Args:
+ output_path: Where to save the model
+ tokenizer: Tokenizer to save alongside the model
+ hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
+ hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP,
+ "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models.
+ """
+ # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.)
+ config = AutoConfig.from_pretrained(hf_model_name, trust_remote_code=True)
+
+ # Override size-related params to make it small for testing
+ # Note: intermediate_size must be divisible by 256 per DeciLM config requirements
+ # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility
+
+ # VL models have nested configs (text_config, vision_config)
+ if hasattr(config, "text_config") and hasattr(config, "vision_config"):
+ config.text_config.vocab_size = tokenizer.vocab_size
+ config.text_config.hidden_size = 256
+ config.text_config.intermediate_size = 512
+ config.text_config.num_hidden_layers = 2
+ config.text_config.num_attention_heads = 32
+ config.text_config.num_key_value_heads = 8
+ config.text_config.num_experts = 16 # Reduce from 128
+ config.text_config.moe_intermediate_size = 256
+ config.text_config.max_position_embeddings = 512
+ config.vision_config.depth = 2 # Reduce from 27
+ config.vision_config.hidden_size = 256
+ config.vision_config.intermediate_size = 512
+ config.vision_config.out_hidden_size = 256
+ # TODO: this is hack, redesign converter to not read config.num_hidden_layers directly.
+ # set top-level num_hidden_layers for converter compatibility
+ config.num_hidden_layers = config.text_config.num_hidden_layers
+ else:
+ # Regular models have flat config
+ config.vocab_size = tokenizer.vocab_size
+ config.hidden_size = 256
+ config.intermediate_size = 512
+ config.num_hidden_layers = max(2, dist.size())
+ config.num_attention_heads = 32
+ config.num_key_value_heads = 8
+ config.max_position_embeddings = 512
+
+ # Fix layer_types to match num_hidden_layers (newer transformers validates this)
+ if hasattr(config, "layer_types") and config.layer_types is not None:
+ config.layer_types = config.layer_types[: config.num_hidden_layers]
+
+ # Fix rope_scaling to be consistent with max_position_embeddings
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ config.rope_scaling["original_max_position_embeddings"] = 256
+
+ # NemotronH requires hybrid_override_pattern to match num_hidden_layers
+ if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None:
+ config.hybrid_override_pattern = hybrid_override_pattern
+
+ # Ensure pad_token_id is within vocab_size (nn.Embedding requires padding_idx < num_embeddings)
+ if (
+ getattr(config, "pad_token_id", None) is not None
+ and config.pad_token_id >= tokenizer.vocab_size
+ ):
+ config.pad_token_id = 0
+
+ # Set seed for reproducible weight initialization
+ torch.manual_seed(42)
+
+ # Create and save the model
+ # Force CPU initialization for deterministic behavior (prevents NaN on RTX GPUs)
+ original_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+ # TODO: Consider using AutoModel.from_config instead.
+ if hasattr(config, "text_config") and hasattr(config, "vision_config"):
+ from transformers import Qwen3VLMoeForConditionalGeneration
+
+ model = Qwen3VLMoeForConditionalGeneration._from_config(config)
+ else:
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+
+ # Initialize weights to ensure all parameters are properly initialized
+ # This prevents NaN values in uninitialized parameters (e.g., backbone.layers.1.mixer.gate.weight
+ # in nemotron-3-nano-30b-a3b-base-bf16) that can occur with from_config on RTX GPU cards (not on H100)
+ model.initialize_weights()
+
+ # Fix any remaining NaN/Inf values that initialize_weights() might have missed
+ for param in model.parameters():
+ if torch.isnan(param).any() or torch.isinf(param).any():
+ nan_inf_mask = torch.isnan(param) | torch.isinf(param)
+ param.data = torch.where(nan_inf_mask, torch.zeros_like(param), param)
+
+ # Restore CUDA_VISIBLE_DEVICES after model creation and initialization
+ if original_cuda_visible is not None:
+ os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible
+ else:
+ os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+
+ model.to(dtype=torch.bfloat16).save_pretrained(output_path)
+
+ # Save tokenizer
+ tokenizer.save_pretrained(output_path)
+
+ # Save config
+ config.save_pretrained(output_path)
+
+
+def save_dummy_dataset(dataset_path: Path | str):
+ """
+ Save a dummy dataset for testing purposes.
+ """
+ # dummy sample
+ sample = [
+ {"role": "user", "content": "please cite Lorem Ipsum?"},
+ {
+ "role": "assistant",
+ "content": (
+ "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. "
+ "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, "
+ "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, "
+ "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, "
+ "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. "
+ "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, "
+ "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. "
+ "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, "
+ "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. "
+ "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, "
+ "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. "
+ "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. "
+ "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. "
+ "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. "
+ "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. "
+ "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. "
+ "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. "
+ "Donec mollis convallis massa quis iaculis."
+ ),
+ },
+ ]
+
+ # Prepare train and val splits with sample repeated, 2500 samples are for
+ # 128 samples with block-size 8192 and LLama3 tokenizer
+ data = [{"conversation": sample}] * 2500
+
+ # For train-val splits
+ data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)})
+ data_dict.save_to_disk(str(dataset_path))
diff --git a/tests/_test_utils/torch/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/tokenizer/special_tokens_map.json
new file mode 100644
index 0000000000..02ee80b619
--- /dev/null
+++ b/tests/_test_utils/torch/tokenizer/special_tokens_map.json
@@ -0,0 +1,16 @@
+{
+ "bos_token": {
+ "content": "<|begin_of_text|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|eot_id|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/tests/_test_utils/torch/tokenizer/tokenizer.json b/tests/_test_utils/torch/tokenizer/tokenizer.json
new file mode 100644
index 0000000000..83592e2494
--- /dev/null
+++ b/tests/_test_utils/torch/tokenizer/tokenizer.json
@@ -0,0 +1,212 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [],
+ "normalizer": null,
+ "pre_tokenizer": {
+ "type": "Sequence",
+ "pretokenizers": [
+ {
+ "type": "Split",
+ "pattern": {
+ "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ },
+ "behavior": "Isolated",
+ "invert": false
+ },
+ {
+ "type": "ByteLevel",
+ "add_prefix_space": false,
+ "trim_offsets": true,
+ "use_regex": false
+ }
+ ]
+ },
+ "post_processor": {
+ "type": "Sequence",
+ "processors": [
+ {
+ "type": "ByteLevel",
+ "add_prefix_space": true,
+ "trim_offsets": false,
+ "use_regex": true
+ },
+ {
+ "type": "TemplateProcessing",
+ "single": [
+ {
+ "SpecialToken": {
+ "id": "<|begin_of_text|>",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ }
+ ],
+ "pair": [
+ {
+ "SpecialToken": {
+ "id": "<|begin_of_text|>",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "<|begin_of_text|>",
+ "type_id": 1
+ }
+ },
+ {
+ "Sequence": {
+ "id": "B",
+ "type_id": 1
+ }
+ }
+ ],
+ "special_tokens": {
+ "<|begin_of_text|>": {
+ "id": "<|begin_of_text|>",
+ "ids": [
+ 100
+ ],
+ "tokens": [
+ "<|begin_of_text|>"
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "decoder": {
+ "type": "ByteLevel",
+ "add_prefix_space": true,
+ "trim_offsets": true,
+ "use_regex": true
+ },
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": null,
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": null,
+ "fuse_unk": false,
+ "byte_fallback": false,
+ "ignore_merges": true,
+ "vocab": {
+ "!": 0,
+ "\"": 1,
+ "#": 2,
+ "$": 3,
+ "%": 4,
+ "&": 5,
+ "'": 6,
+ "(": 7,
+ ")": 8,
+ "*": 9,
+ "+": 10,
+ ",": 11,
+ "-": 12,
+ ".": 13,
+ "/": 14,
+ "0": 15,
+ "1": 16,
+ "2": 17,
+ "3": 18,
+ "4": 19,
+ "5": 20,
+ "6": 21,
+ "7": 22,
+ "8": 23,
+ "9": 24,
+ ":": 25,
+ ";": 26,
+ "<": 27,
+ "=": 28,
+ ">": 29,
+ "?": 30,
+ "@": 31,
+ "A": 32,
+ "B": 33,
+ "C": 34,
+ "D": 35,
+ "E": 36,
+ "F": 37,
+ "G": 38,
+ "H": 39,
+ "I": 40,
+ "J": 41,
+ "K": 42,
+ "L": 43,
+ "M": 44,
+ "N": 45,
+ "O": 46,
+ "P": 47,
+ "Q": 48,
+ "R": 49,
+ "S": 50,
+ "T": 51,
+ "U": 52,
+ "V": 53,
+ "W": 54,
+ "X": 55,
+ "Y": 56,
+ "Z": 57,
+ "[": 58,
+ "\\": 59,
+ "]": 60,
+ "^": 61,
+ "_": 62,
+ "`": 63,
+ "a": 64,
+ "b": 65,
+ "c": 66,
+ "d": 67,
+ "e": 68,
+ "f": 69,
+ "g": 70,
+ "h": 71,
+ "i": 72,
+ "j": 73,
+ "k": 74,
+ "l": 75,
+ "m": 76,
+ "n": 77,
+ "o": 78,
+ "p": 79,
+ "q": 80,
+ "r": 81,
+ "s": 82,
+ "t": 83,
+ "u": 84,
+ "v": 85,
+ "w": 86,
+ "x": 87,
+ "y": 88,
+ "z": 89,
+ "{": 90,
+ "|": 91,
+ "}": 92,
+ "~": 93,
+ "¡": 94,
+ "¢": 95,
+ "£": 96,
+ "¤": 97,
+ "¥": 98,
+ "¦": 99,
+ "<|begin_of_text|>": 100,
+ "<|eot_id|>": 101
+ },
+ "merges": []
+ }
+}
diff --git a/tests/_test_utils/torch/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/tokenizer/tokenizer_config.json
new file mode 100644
index 0000000000..754d9e8db5
--- /dev/null
+++ b/tests/_test_utils/torch/tokenizer/tokenizer_config.json
@@ -0,0 +1,13 @@
+{
+ "bos_token": "<|begin_of_text|>",
+ "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
+ "clean_up_tokenization_spaces": true,
+ "eos_token": "<|eot_id|>",
+ "extra_special_tokens": {},
+ "model_input_names": [
+ "input_ids",
+ "attention_mask"
+ ],
+ "model_max_length": 131072,
+ "tokenizer_class": "PreTrainedTokenizer"
+}
diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py
index 01e8fa4d38..54bc10a562 100644
--- a/tests/_test_utils/torch/transformers_models.py
+++ b/tests/_test_utils/torch/transformers_models.py
@@ -39,6 +39,12 @@
SEED = 1234
+TINY_TOKENIZER_PATH = Path(__file__).parent / "tokenizer"
+
+
+def get_tiny_tokenizer() -> "transformers.PreTrainedTokenizerBase":
+ return AutoTokenizer.from_pretrained(TINY_TOKENIZER_PATH)
+
##### Qwen3 #####
def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel:
@@ -66,9 +72,7 @@ def create_tiny_qwen3_dir(
) -> Path | tuple[Path, PreTrainedModel]:
qwen3_dir = Path(tmp_path) / "tiny_qwen3"
if with_tokenizer:
- tokenizer = AutoTokenizer.from_pretrained(
- "hf-internal-testing/tiny-random-LlamaForCausalLM"
- )
+ tokenizer = get_tiny_tokenizer()
tokenizer.save_pretrained(qwen3_dir)
config_kwargs["vocab_size"] = tokenizer.vocab_size
tiny_qwen3 = get_tiny_qwen3(**config_kwargs)
@@ -149,9 +153,7 @@ def create_tiny_llama_dir(
) -> Path:
llama_dir = Path(tmp_path) / "tiny_llama"
if with_tokenizer:
- tokenizer = AutoTokenizer.from_pretrained(
- "hf-internal-testing/tiny-random-LlamaForCausalLM"
- )
+ tokenizer = get_tiny_tokenizer()
tokenizer.save_pretrained(llama_dir)
config_kwargs["vocab_size"] = tokenizer.vocab_size
diff --git a/tests/conftest.py b/tests/conftest.py
index 53a2330c22..a4e65ff2ae 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -115,7 +115,7 @@ def enable_hf_checkpointing():
mto.enable_huggingface_checkpointing()
-@pytest.fixture
+@pytest.fixture(scope="session")
def project_root_path(request: pytest.FixtureRequest) -> Path:
"""Fixture providing the project root path for tests."""
return Path(request.config.rootpath)
diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py
new file mode 100644
index 0000000000..6ca0ac0dd9
--- /dev/null
+++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py
@@ -0,0 +1,142 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for distill_hf.py script."""
+
+from pathlib import Path
+
+import torch
+from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
+from _test_utils.torch.distributed.utils import get_free_port
+from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model
+from _test_utils.torch.transformers_models import get_tiny_tokenizer
+from transformers import AutoModelForCausalLM
+
+from modelopt.torch.puzzletron.anymodel import convert_model
+
+
+def test_distill_hf(project_root_path: Path, tmp_path: Path):
+ """Integration test for distill_hf.py.
+
+ Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel),
+ and runs mbridge distillation. The models are created with reduced size for faster testing.
+ Models are converted to include block_configs.
+ """
+ # Prepare student and teacher models
+ student_hf_dir, student_anymodel_dir, _, teacher_anymodel_dir = (
+ _prepare_student_and_teacher_models(project_root_path, tmp_path)
+ )
+
+ output_dir = tmp_path / "distill_output"
+ hf_export_dir = tmp_path / "hf_export"
+
+ # Build command-line arguments for distill_hf.py
+ nproc_per_node = torch.cuda.device_count()
+ tp_size = nproc_per_node
+ train_iters = 5
+
+ cmd_parts = [
+ "torchrun",
+ f"--nproc_per_node={nproc_per_node}",
+ "--master-addr",
+ "127.0.0.1",
+ "--master-port",
+ str(get_free_port()),
+ "distill_hf.py",
+ "--use_mock_data",
+ ]
+ extend_cmd_parts(
+ cmd_parts,
+ student_hf_path=student_anymodel_dir,
+ teacher_hf_path=teacher_anymodel_dir,
+ output_dir=output_dir,
+ tp_size=tp_size,
+ pp_size=1,
+ seq_length=128,
+ split="99,1,0",
+ mbs=1,
+ gbs=4,
+ train_iters=train_iters,
+ lr=0.0001,
+ min_lr=1e-5,
+ lr_warmup_iters=2,
+ eval_interval=100,
+ eval_iters=0,
+ log_interval=5,
+ hf_export_path=hf_export_dir,
+ hf_model=student_hf_dir,
+ )
+
+ run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation")
+
+ # Check that distillation checkpoint contains run_config.yaml
+ run_config_path = output_dir / "checkpoints" / f"iter_{train_iters:07d}" / "run_config.yaml"
+ assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}"
+
+ # Verify that the distilled model can be loaded in HuggingFace format
+ model = AutoModelForCausalLM.from_pretrained(hf_export_dir)
+ assert model is not None, "Failed to load distilled model with AutoModelForCausalLM"
+
+
+def _prepare_student_and_teacher_models(
+ project_root_path: Path, tmp_path: Path
+) -> tuple[Path, Path, Path, Path]:
+ """Prepare student and teacher models for distillation.
+
+ Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel),
+ and returns the paths to the converted checkpoints.
+
+ """
+
+ # Create temporary directories for models
+ student_hf_dir = tmp_path / "student_hf"
+ teacher_hf_dir = tmp_path / "teacher_hf"
+
+ # Create tokenizer (uses local tokenizer from test resources)
+ tokenizer = get_tiny_tokenizer()
+
+ # Create student model using utility function (loads config from Hub).
+ # TODO: Make the student model using different ffn sizes across layers.
+ create_and_save_small_hf_model(
+ output_path=str(student_hf_dir),
+ tokenizer=tokenizer,
+ hf_model_name="Qwen/Qwen3-0.6B",
+ hybrid_override_pattern=None,
+ )
+
+ # Create teacher model (same as student for testing)
+ create_and_save_small_hf_model(
+ output_path=str(teacher_hf_dir),
+ tokenizer=tokenizer,
+ hf_model_name="Qwen/Qwen3-0.6B",
+ hybrid_override_pattern=None,
+ )
+
+ # Convert models to AnyModel format BEFORE distillation
+ # This is needed as converted checkpoints will be used as input for distillation later
+ student_anymodel_dir = tmp_path / "student_anymodel"
+ teacher_anymodel_dir = tmp_path / "teacher_anymodel"
+
+ convert_model(
+ input_dir=str(student_hf_dir), output_dir=str(student_anymodel_dir), converter="qwen3"
+ )
+
+ convert_model(
+ input_dir=str(teacher_hf_dir), output_dir=str(teacher_anymodel_dir), converter="qwen3"
+ )
+ print("Models converted to AnyModel format:")
+ print(f" Student AnyModel: {student_anymodel_dir}")
+ print(f" Teacher AnyModel: {teacher_anymodel_dir}")
+
+ return student_hf_dir, student_anymodel_dir, teacher_hf_dir, teacher_anymodel_dir
diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
new file mode 100644
index 0000000000..25991f1c74
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
@@ -0,0 +1,142 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+
+import torch
+from _test_utils.torch.distributed.utils import spawn_multiprocess_job
+from _test_utils.torch.puzzletron.utils import setup_test_model_and_data
+
+import modelopt.torch.nas as mtn
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
+
+
+def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path):
+ spawn_multiprocess_job(
+ size=torch.cuda.device_count(),
+ job=partial(_test_nas_convert_ffn_pruning_multiprocess_job, project_root_path, tmp_path),
+ backend="nccl",
+ )
+
+
+def _test_nas_convert_ffn_pruning_multiprocess_job(
+ project_root_path: Path, tmp_path: Path, rank: int, size: int
+):
+ dist.setup(timeout=timedelta(10))
+ # Setup the test model and data.
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+ tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct"
+ )
+ hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs"
+ hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct"
+
+ #
+ # Run the mnt.convert() step
+ #
+ input_model = PuzzletronModel()
+ mtn.convert(
+ input_model,
+ mode=[
+ (
+ "puzzletron",
+ {
+ "puzzle_dir": str(puzzle_dir),
+ "input_model_path": str(llama_checkpoint_path),
+ "hydra_config_dir": str(hydra_config_dir),
+ "hydra_config_name": hydra_config_name,
+ "dataset_path": str(dataset_path),
+ },
+ )
+ ],
+ )
+
+ #
+ # Check assertions
+ #
+ if rank == 0:
+ # assertions for the score_pruning_activations step
+ rank = int(os.environ["RANK"])
+ rank_filepath = (
+ f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
+ )
+ assert (puzzle_dir / rank_filepath).is_file()
+
+ # assertions for the pruning_ckpts step
+ assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
+
+ dist.cleanup()
+
+
+def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path):
+ spawn_multiprocess_job(
+ size=torch.cuda.device_count(),
+ job=partial(_test_nas_convert_attn_pruning_multiprocess_job, project_root_path, tmp_path),
+ backend="nccl",
+ )
+
+
+def _test_nas_convert_attn_pruning_multiprocess_job(
+ project_root_path: Path, tmp_path: Path, rank: int, size: int
+):
+ dist.setup(timeout=timedelta(10))
+ # Setup the test model and data.
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+ tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct"
+ )
+ hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs"
+ hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning"
+
+ #
+ # Run the mnt.convert() step
+ #
+ input_model = PuzzletronModel()
+ mtn.convert(
+ input_model,
+ mode=[
+ (
+ "puzzletron",
+ {
+ "puzzle_dir": str(puzzle_dir),
+ "input_model_path": str(llama_checkpoint_path),
+ "hydra_config_dir": str(hydra_config_dir),
+ "hydra_config_name": hydra_config_name,
+ "dataset_path": str(dataset_path),
+ },
+ )
+ ],
+ )
+
+ #
+ # Check assertions
+ #
+ if rank == 0:
+ # assertions for the score_pruning_activations step
+ rank = int(os.environ["RANK"])
+ rank_filepath = (
+ f"pruning/pruning_scores/attn_independent_kv_head_contribution/"
+ f"100samples_diverse_mini/rank_{rank}.pth"
+ )
+ assert (puzzle_dir / rank_filepath).is_file()
+
+ # assertions for the pruning_ckpts step
+ assert (puzzle_dir / "ckpts/n_heads_in_group8").exists()
+ assert (puzzle_dir / "ckpts/n_heads_in_group16").exists()
+ assert (puzzle_dir / "ckpts/n_heads_in_group32").exists()
+
+ dist.cleanup()
diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
new file mode 100644
index 0000000000..aede36bded
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+
+import torch
+from _test_utils.torch.distributed.utils import spawn_multiprocess_job
+from _test_utils.torch.puzzletron.utils import setup_test_model_and_data
+
+import modelopt.torch.nas as mtn
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
+
+
+def test_nas_search(project_root_path: Path, tmp_path: Path):
+ spawn_multiprocess_job(
+ size=torch.cuda.device_count(),
+ job=partial(_test_nas_search_multiprocess_job, project_root_path, tmp_path),
+ backend="nccl",
+ )
+
+
+def _test_nas_search_multiprocess_job(
+ project_root_path: Path, tmp_path: Path, rank: int, size: int
+):
+ dist.setup(timeout=timedelta(10))
+ # Setup the test model and data.
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
+ tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct"
+ )
+ hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs"
+ hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct"
+
+ #
+ # Run the mnt.convert() step
+ #
+ input_model = PuzzletronModel()
+ converted_model = mtn.convert(
+ input_model,
+ mode=[
+ (
+ "puzzletron",
+ {
+ "puzzle_dir": str(puzzle_dir),
+ "input_model_path": str(llama_checkpoint_path),
+ "hydra_config_dir": str(hydra_config_dir),
+ "hydra_config_name": hydra_config_name,
+ "dataset_path": str(dataset_path),
+ },
+ )
+ ],
+ )
+
+ #
+ # Run the mnt.search() step
+ #
+ mtn.search(
+ converted_model,
+ constraints={}, # this is not used as the search space is defined in the hydra config
+ dummy_input=None, # Not used
+ config={}, # this is not used as the search space is defined in the hydra config
+ )
+
+ #
+ # Check assertions for mtn.search() step
+ #
+ if rank == 0:
+ # assertions for the build_library_and_stats step
+ assert (puzzle_dir / "replacement_library.json").is_file()
+ assert (puzzle_dir / "subblock_stats.json").is_file()
+
+ # assertions for the scoring step
+ solution_0_filepath = (
+ puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
+ )
+
+ assert solution_0_filepath.exists()
+
+ # assertions for the mip_and_realize_models step
+ solution_0_ckpt_config_path = (
+ puzzle_dir
+ / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
+ )
+
+ assert solution_0_ckpt_config_path.exists()
+ assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists()
+
+ dist.cleanup()
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml
new file mode 100644
index 0000000000..2843f0b97a
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+defaults:
+ - /Qwen/Qwen2.5-7B-Instruct/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: qwen2
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..cf6201080c
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml
new file mode 100644
index 0000000000..cd82a47271
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml
@@ -0,0 +1,112 @@
+# @package _global_
+defaults:
+ - /Qwen/Qwen3-8B/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: qwen3
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..6bfeec715c
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml
new file mode 100644
index 0000000000..00b21ea979
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+defaults:
+ - /Qwen/Qwen3-VL-30B-A3B-Instruct/pruning@pruning: expert_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: qwen3_vl
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+ - stats.num_local_experts
+
+ human_constraints:
+
+ mip_constraints:
+ - stats.num_local_experts: 1472 # same constraint as nemotron-3-nano for test consistency
+ use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
new file mode 100644
index 0000000000..4e0786dc7a
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+
+eval_samples: 10
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor.Qwen3VLExpertRemovalLayerDescriptor
+ target_name: "mlp"
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.Qwen3VLRemoveExpertsIndependentHook}
+activation_hooks_kwargs:
+
+# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist)
+num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8
+mlp_init_mode: "ExpertRemoval"
+mlp_init_config_yaml:
+ expert_scores_key: "expert_ranks_mse"
+ layer_prefix_template: "model.language_model.layers.{layer_idx}.mlp"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml
new file mode 100644
index 0000000000..57051431a1
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+ - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: attn_pruning
+ - _self_
+
+descriptor: llama
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+dataset_path: ???
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml
new file mode 100644
index 0000000000..8e2e0786b3
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml
@@ -0,0 +1,106 @@
+# @package _global_
+defaults:
+ - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+descriptor: llama
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..6e8af1f651
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/attn_pruning@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaKVHeadsLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..b30f4a17d9
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml
new file mode 100644
index 0000000000..78cb6bd73c
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml
@@ -0,0 +1,106 @@
+# @package _global_
+defaults:
+ - /meta-llama/Llama-3.2-3B-Instruct/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+descriptor: llama
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..b30f4a17d9
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml
new file mode 100644
index 0000000000..e042c4bb62
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml
@@ -0,0 +1,112 @@
+# @package _global_
+defaults:
+ - /mistralai/Mistral-Small-24B-Instruct-2501/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: mistral_small
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..37c21fd638
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor
diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml
new file mode 100644
index 0000000000..ab2b09e679
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml
@@ -0,0 +1,115 @@
+# @package _global_
+defaults:
+ - /nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning@pruning: expert_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: nemotron_h
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+ runtime_stats:
+ backend: trt_torch
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}/valid
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+ - stats.num_local_experts
+
+ human_constraints:
+ mip_constraints:
+ - stats.num_local_experts: 1472 # teacher has: 23 moe-blocks * 128 experts = 2944 total experts use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ seed: 42
+ shuffle_seed: 444
+ dataset_path: ${dataset_path}/valid
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml
new file mode 100644
index 0000000000..ae20b6d7d2
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+
+eval_samples: 10
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHExpertRemovalLayerDescriptor
+ target_name: "mixer"
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.NemotronHRemoveExpertsIndependentHook}
+activation_hooks_kwargs: # Additional kwargs to pass to the hook init
+
+num_experts_to_keep_list: [96, 64, 32, 16, 8] # num_experts in teacher is 128
+mlp_init_mode: "ExpertRemoval"
+mlp_init_config_yaml:
+ expert_scores_key: "expert_ranks_mse"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..abc501287d
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml
@@ -0,0 +1,14 @@
+defaults:
+ - /pruning/pruning_defaults
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn/${pruning.experiment_id}
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs: # Additional kwargs to pass to the hook init
+
+intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml
new file mode 100644
index 0000000000..906b7338d8
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+defaults:
+ - /nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning@pruning: ffn_pruning
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - _self_
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+descriptor: nemotron_h_v2
+
+build_replacement_library:
+ add_ffn_no_ops: true
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+ num_solutions: 1
+ minimal_diversity: 2
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ use_greedy_search: false
+ is_multi_layer_puzzle: true
+ metric_overrides:
+ constrain_search_func:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml
new file mode 100644
index 0000000000..f68068c3ac
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml
@@ -0,0 +1,12 @@
+defaults:
+ - /pruning/ffn_pruning_base@_here_
+ - _self_
+
+pruning_mixin:
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor
+
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mixer.down_proj"
+ layer_input_descriptors_path:
diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml
new file mode 100644
index 0000000000..2b77516174
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml
@@ -0,0 +1,109 @@
+# @package _global_
+defaults:
+ - /openai/gpt-oss-20b/pruning@pruning: expert_removal # TODO: Note: Works for unquantized test models, not MXFP4 quantized production models
+ - /validate_solutions_defaults@scoring
+ - /validate_solutions_defaults@realize_model
+ - bypass:
+ - override /hydra/hydra_logging: disabled
+ - _self_
+
+descriptor: gpt_oss
+
+puzzle_dir: ???
+teacher_dir: ${puzzle_dir}/ckpts/teacher/
+replacement_library_path: ${puzzle_dir}/replacement_library.json
+dataset_path: ??? # path to v0.4_mini
+
+skip_realize_model: false
+
+build_replacement_library:
+ add_ffn_no_ops: true # TODO: Works for unquantized test models
+ add_attention_no_ops: true
+
+calc_subblock_stats:
+ batch_sizes: [64, 96, 128]
+ prefill_seq_len: 4096
+ generation_seq_len: 4096
+ num_active_tokens_override: # Optional override for sequence lengths
+ prefill_queue_size: 0
+ allocate_prefill_query: false
+ benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+ merge_with_existing_stats: false
+ subblock_stats_filename: "subblock_stats.json"
+ moe_stats_filename: "moe_stats.json"
+
+scoring:
+ descriptor: ${descriptor}
+ solutions_to_validate:
+ skip_existing_solutions: true
+
+ replacement_library_path: ${replacement_library_path}
+ solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+ teacher_dir: ${to_path:${teacher_dir}}
+ output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+mip:
+ single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+ subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+ output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+ gathered_metrics_path:
+ puzzle_profile:
+
+ # puzzle_profile:
+ objective: metrics.cosine_embedding_loss_hidden_states
+ bigger_is_better: false
+
+ subblock_stats_args:
+ - batch_size: 96
+ weights_dtype: torch.bfloat16
+ activations_dtype: torch.bfloat16
+ kv_cache_dtype: torch.bfloat16
+
+ report_additional_costs:
+ - stats.memory_mib
+ - stats.num_params
+ - stats.num_kv_heads
+ - stats.has_attention
+ - stats.has_ffn
+ - stats.kv_cache_memory_mib
+ - stats.attention_memory_mib
+ - stats.ffn_memory_mib
+ - stats.ffn_num_params
+ - stats.attention_num_params
+
+ human_constraints:
+ target_memory: 780_000 # 78_000
+
+ mip_constraints:
+ - stats.num_local_experts: 48 # teacher has: 2 layers * 32 experts = 64 total experts
+ metric_overrides:
+ max_seconds_per_solution: 60
+
+realize_model:
+ descriptor: ${descriptor}
+ teacher_dir: ${to_path:${teacher_dir}}
+ tokenizer_name: ${to_path:${teacher_dir}}
+ replacement_library_path: ${replacement_library_path}
+ save_models: true
+ solutions_path: # Filled dynamically
+
+ # Validate params
+ skip_validation: false # To enable validation of the model solution set `skip_validation` as False
+ eval_samples: 2
+ micro_batch_size: 1
+ dataset_path: ${dataset_path}/valid
+ seed: 42
+ shuffle_seed: 444
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+ run:
+ dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml
new file mode 100644
index 0000000000..5a4761886f
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+
+eval_samples: 10
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
+ layer_descriptor:
+ _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor
+ target_name: "mlp.router"
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook}
+activation_hooks_kwargs: # Additional kwargs to pass to the hook init
+
+num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128
+mlp_init_mode: "ExpertRemoval"
+mlp_init_config_yaml:
+ expert_scores_key: "expert_ranks"
+ layer_prefix_template: "model.layers.{layer_idx}.mlp.router"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml
new file mode 100644
index 0000000000..0dadc20134
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+ - _self_
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn
+ layer_descriptor:
+ _target_: ???
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook}
+activation_hooks_kwargs:
+ method: independent_kv_head_contribution
+ optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
+ target_layer: "self_attn.o_proj"
+ layer_input_descriptors_path:
+
+# n_heads_in_group: 4
+# num_attention_heads: 32 # num query heads
+# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
+n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
+gqa_init_mode: "PruneKVHeads"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml
new file mode 100644
index 0000000000..c1c951984f
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+ - _self_
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+pruning_mixin:
+ _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
+ layer_descriptor:
+ _target_: ???
+
+hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
+activation_hooks_kwargs:
+ method: iterative
+ target_layer: "mlp.down_proj"
+ layer_input_descriptors_path:
+
+intermediate_size_list: [256]
+mlp_init_mode: "PruneByActivationsLog"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml
new file mode 100644
index 0000000000..4033fedf3a
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - /pruning/pruning_defaults@_here_
+
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}
+
+activation_hooks_kwargs:
+ method: layer_norm_contribution
+ target_layer: "layernorm"
+
+# Hidden dimension pruning specific settings
+hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
+hidden_size_init_mode: "PruneByChannelRanking"
+mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
+gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
+linear_init_mode: "FromTeacher"
diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml
new file mode 100644
index 0000000000..f00a86da66
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - /validate_model_defaults@_here_
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+
+descriptor: ${descriptor}
+
+# Data:
+eval_samples: 100
+micro_batch_size: 4
+dataset_path: ${dataset_path}
+val_dataset_name: train
+
+# Prune ckpts
+pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id}
+
+## FFN pruning
+ffn_list:
+mlp_init_mode: "Truncate"
+
+## KV-heads pruning
+n_heads_in_group_list:
+gqa_init_mode: "AverageKV"
+
+## Hidden dimension pruning
+hidden_size_list:
+hidden_size_init_mode: "PruneByChannelRanking"
+linear_init_mode: "FromTeacher"
+
+mlp_init_config_yaml:
+ activations_log_dir: ${pruning.activations_log_dir}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml
new file mode 100644
index 0000000000..9dabef7413
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml
@@ -0,0 +1,15 @@
+block_size: 8192
+bos_rate: 0.5
+data_column: conversation
+val_dataset_name: train
+shuffle_seed: 81436
+seed: 42
+fim_rate: 0
+fim_spm_rate: 0
+source_datasets_to_discard:
+varlen: false
+write_results: false
+calc_losses_on_cpu: false
+activations_log_dir:
+model_name_or_path:
+load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn}
diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml
new file mode 100644
index 0000000000..ec13902379
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - /validate_model_defaults
+ - _self_
+
+solutions_to_validate:
+skip_validation: false
+save_models: false
+bigger_is_better: false
+sort_solutions_by:
+calculate_full_score_ablations: false
diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py
new file mode 100644
index 0000000000..45c438ec0d
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/test_puzzletron.py
@@ -0,0 +1,345 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import warnings
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+
+import pytest
+import torch
+from _test_utils.torch.distributed.utils import spawn_multiprocess_job
+from _test_utils.torch.misc import set_seed
+from _test_utils.torch.puzzletron.utils import setup_test_model_and_data
+
+import modelopt.torch.utils.distributed as dist
+from modelopt.torch.puzzletron import puzzletron
+from modelopt.torch.puzzletron.anymodel import convert_model
+from modelopt.torch.puzzletron.mip.sweep import (
+ get_teacher_memory_from_subblock_stats,
+ get_teacher_num_params_from_subblock_stats,
+)
+
+# The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search)
+# using a one-click command.
+#
+# Note: Bypass is disabled now in the test.
+#
+
+SEED = 1234
+
+
+@pytest.mark.parametrize(
+ ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"),
+ [
+ ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False),
+ ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False),
+ ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False),
+ ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True),
+ ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False),
+ ("openai/gpt-oss-20b", "gpt_oss", None, True),
+ ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False),
+ ("Qwen/Qwen3-8B", "qwen3", None, False),
+ ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True),
+ ],
+)
+def test_puzzletron(
+ project_root_path: Path,
+ tmp_path: Path,
+ hf_model_name: str,
+ converter: str,
+ hybrid_override_pattern: str,
+ has_moe_layers: bool,
+):
+ spawn_multiprocess_job(
+ size=torch.cuda.device_count(),
+ job=partial(
+ _test_puzzletron_multiprocess_job,
+ project_root_path,
+ tmp_path,
+ hf_model_name,
+ converter,
+ hybrid_override_pattern,
+ has_moe_layers,
+ ),
+ backend="nccl",
+ )
+
+
+def _test_puzzletron_multiprocess_job(
+ project_root_path: Path,
+ tmp_path: Path,
+ hf_model_name: str,
+ converter: str,
+ hybrid_override_pattern: str,
+ has_moe_layers: bool,
+ rank: int,
+ size: int,
+):
+ # Set seed BEFORE dist.setup() to ensure reproducibility across all processes
+ set_seed(SEED)
+ dist.setup(timeout=timedelta(10))
+
+ # Setup the test model and data.
+ puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data(
+ tmp_path, rank, hf_model_name, hybrid_override_pattern
+ )
+ hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs"
+ model_basename = hf_model_name.split("/")[1]
+ hydra_config_name = f"{hf_model_name}/{model_basename}"
+
+ # Convert the model using AnyModel converter.
+ if rank == 0:
+ convert_model(
+ input_dir=str(hf_checkpoint_path),
+ output_dir=str(puzzle_dir / "ckpts/teacher"),
+ converter=converter,
+ )
+ dist.barrier()
+
+ # Compress the model using a one-click approach
+ hydra_cfg = puzzletron.puzzletron(
+ str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path)
+ )
+
+ #
+ # Check assertions
+ #
+ if rank == 0:
+ if has_moe_layers:
+ # assertions for the score_pruning_activations step 1 (MoE models only)
+ rank_filepath = (
+ f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth"
+ )
+ assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist"
+
+ # assertions for the pruning_ckpts step 2
+ assert (puzzle_dir / "ckpts/num_experts_8").exists()
+
+ # assertions for the mip_and_realize_models step 6
+ # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*)
+ mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions"
+ solution_dirs = [
+ d
+ for d in mip_solutions_dir.iterdir()
+ if d.is_dir() and d.name.startswith("stats_num_local_experts_")
+ ]
+ assert len(solution_dirs) == 1, (
+ f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}"
+ )
+ solution_dir = solution_dirs[0]
+
+ solution_0_ckpt_config_path = (
+ solution_dir / "solutions--checkpoints/solution_0/config.json"
+ )
+ assert solution_0_ckpt_config_path.exists()
+ assert (solution_dir / "solutions.json").exists()
+
+ # Validate lm_loss
+ _assert_lm_loss(puzzle_dir, hf_model_name, tolerance=0.01)
+ else:
+ # assertions for the score_pruning_activations step 1 (FFN pruning)
+ _assert_score_pruning_activations(puzzle_dir, hf_model_name)
+
+ # assertions for the pruning_ckpts step 2
+ assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()
+
+ # assertions for the mip_and_realize_models step 6
+ _assert_mip_solutions(puzzle_dir, hf_model_name)
+
+ # assertions for the build_library_and_stats step 4
+ assert (puzzle_dir / "replacement_library.json").is_file()
+ _assert_subblock_stats_anymodel(hf_model_name, hydra_cfg)
+
+ # assertions for the scoring step 5
+ solution_0_filepath = (
+ puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
+ )
+ assert solution_0_filepath.exists()
+
+ dist.cleanup()
+
+
+def _assert_subblock_stats_anymodel(hf_model_name: str, hydra_cfg) -> None:
+ """Minimal subblock_stats checks and teacher memory / param regression values."""
+ assert (Path(hydra_cfg.puzzle_dir) / "subblock_stats.json").is_file()
+ teacher_mem_mib = get_teacher_memory_from_subblock_stats(hydra_cfg)
+ teacher_num_params = get_teacher_num_params_from_subblock_stats(hydra_cfg)
+
+ assert abs(teacher_mem_mib - EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]) < 1e-6, (
+ f"Teacher memory mismatch for {hf_model_name}: "
+ f"expected {EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]}, got {teacher_mem_mib}"
+ )
+ assert abs(teacher_num_params - EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]) < 1e-6, (
+ f"Teacher num_params mismatch for {hf_model_name}: "
+ f"expected {EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]}, got {teacher_num_params}"
+ )
+
+
+def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str):
+ """Assertions for the score_pruning_activations step 1."""
+ rank = dist.rank()
+ rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
+ assert (puzzle_dir / rank_filepath).is_file()
+
+ pruning_scores = torch.load(puzzle_dir / rank_filepath)
+
+ layer_names = list(pruning_scores.keys())
+ expected = EXPECTED_PRUNING_VALUES[hf_model_name]
+ size = dist.size()
+
+ if hf_model_name == "mistralai/Mistral-Small-24B-Instruct-2501" and size == 1:
+ warnings.warn("Mistral-Small score assertions only work for 2 GPUs")
+ return
+
+ if expected is not None:
+ # In multi-GPU: layers are distributed across ranks
+ # Each rank processes len(expected) // size layers
+ expected_layers_per_rank = len(expected) // size
+ assert len(layer_names) == expected_layers_per_rank, (
+ f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
+ )
+ # Check each layer's values
+ for i, layer_name in enumerate(layer_names):
+ layer_data = pruning_scores[layer_name]
+ # Calculate global layer index from rank and local index
+ global_idx = rank * expected_layers_per_rank + i
+ assert layer_data["score"][0].item() == expected[global_idx]["score"], (
+ layer_name,
+ layer_data["score"][0].item(),
+ expected[global_idx]["score"],
+ global_idx,
+ )
+ assert (
+ layer_data["channels_importance_ascending"][0].item()
+ == expected[global_idx]["channels"]
+ )
+ else:
+ print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===")
+ print(f'"{hf_model_name}": [')
+ for layer_name in layer_names:
+ layer_data = pruning_scores[layer_name]
+ score = layer_data["score"][0].item()
+ channels = layer_data["channels_importance_ascending"][0].item()
+ print(f' {{"score": {score}, "channels": {channels}}},')
+ print("],")
+ print("===")
+ pytest.fail(f"Expected pruning values not found for {hf_model_name}")
+
+
+def _assert_lm_loss(puzzle_dir: Path, hf_model_name: str, tolerance: float = 0.01):
+ """Validate lm_loss for a model solution."""
+ solution_0_path = (
+ puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
+ )
+ with open(solution_0_path) as f:
+ validation = json.load(f)
+
+ actual_lm_loss = validation["lm_loss"]["avg"]
+ expected_lm_loss = EXPECTED_LM_LOSS.get(hf_model_name)
+ if expected_lm_loss is not None:
+ assert abs(actual_lm_loss - expected_lm_loss) < tolerance, (
+ f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}"
+ )
+ else:
+ # Print value for new models - update EXPECTED_LM_LOSS with this
+ print(f"\n=== LM_LOSS for {hf_model_name} ===")
+ print(f'"{hf_model_name}": {actual_lm_loss},')
+ print("===")
+
+
+def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str):
+ """Assertions for the mip_and_realize_models step."""
+ mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB"
+
+ assert (mip_dir / "solutions.json").exists()
+ assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists()
+
+ # Validate lm_loss
+ _assert_lm_loss(puzzle_dir, hf_model_name)
+
+
+# Expected pruning activation values per model
+# Each model has a list of (score, channels) tuples for each FFN layer
+EXPECTED_PRUNING_VALUES = {
+ "meta-llama/Llama-3.1-8B-Instruct": [
+ {"score": 73, "channels": 95},
+ {"score": 440, "channels": 174},
+ ],
+ "meta-llama/Llama-3.2-3B-Instruct": [
+ {"score": 79, "channels": 95},
+ {"score": 428, "channels": 174},
+ ],
+ "mistralai/Mistral-Small-24B-Instruct-2501": [
+ {"score": 73, "channels": 95},
+ {"score": 431, "channels": 174},
+ ],
+ # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer)
+ "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [
+ {"score": 70, "channels": 509},
+ ],
+ # nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 uses MoE expert pruning, not FFN pruning
+ "Qwen/Qwen2.5-7B-Instruct": [
+ {"score": 96, "channels": 433},
+ {"score": 485, "channels": 105},
+ ],
+ "Qwen/Qwen3-8B": [
+ {"score": 208, "channels": 51},
+ {"score": 475, "channels": 266},
+ ],
+}
+
+
+# Expected lm_loss values per model
+EXPECTED_LM_LOSS = {
+ "meta-llama/Llama-3.1-8B-Instruct": 4.706878662109375,
+ "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469,
+ "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055,
+ # TODO: not reproducible in CI, skipping for now
+ # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.733944892883301,
+ "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166,
+ "openai/gpt-oss-20b": 4.689250946044922,
+ "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703,
+ "Qwen/Qwen3-8B": 4.733874320983887,
+ "Qwen/Qwen3-VL-30B-A3B-Instruct": 4.65625,
+}
+
+
+# Expected teacher memory from subblock_stats (MiB)
+EXPECTED_TEACHER_MEMORY_MIB = {
+ "meta-llama/Llama-3.1-8B-Instruct": 395.60205078125,
+ "meta-llama/Llama-3.2-3B-Instruct": 395.60205078125,
+ "mistralai/Mistral-Small-24B-Instruct-2501": 395.60205078125,
+ "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 202.10107421875,
+ "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 202.10107421875,
+ "openai/gpt-oss-20b": 437.302490234375,
+ "Qwen/Qwen2.5-7B-Instruct": 386.228515625,
+ "Qwen/Qwen3-8B": 395.60302734375,
+ "Qwen/Qwen3-VL-30B-A3B-Instruct": 406.11865234375,
+}
+
+
+# Expected total teacher params from subblock_stats
+EXPECTED_TEACHER_NUM_PARAMS = {
+ "meta-llama/Llama-3.1-8B-Instruct": 6082816.0,
+ "meta-llama/Llama-3.2-3B-Instruct": 6082816.0,
+ "mistralai/Mistral-Small-24B-Instruct-2501": 6082816.0,
+ "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 5295872.0,
+ "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 5295872.0,
+ "openai/gpt-oss-20b": 27945856.0,
+ "Qwen/Qwen2.5-7B-Instruct": 1168384.0,
+ "Qwen/Qwen3-8B": 6083328.0,
+ "Qwen/Qwen3-VL-30B-A3B-Instruct": 11596544.0,
+}
diff --git a/tests/unit/torch/puzzletron/test_convert_anymodel.py b/tests/unit/torch/puzzletron/test_convert_anymodel.py
new file mode 100644
index 0000000000..f27cb9c9b9
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_convert_anymodel.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.importorskip("transformers")
+
+from _test_utils.torch.transformers_models import create_tiny_qwen3_dir
+from transformers import AutoModelForCausalLM
+
+from modelopt.torch.puzzletron.anymodel import convert_model
+from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
+
+
+def test_convert_anymodel(tmp_path):
+ input_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True)
+ output_dir = tmp_path / "qwen3-0.6b-anymodel"
+ convert_model(input_dir, output_dir, converter="qwen3")
+
+ descriptor = ModelDescriptorFactory.get("qwen3")
+ with deci_x_patcher(descriptor):
+ _ = AutoModelForCausalLM.from_pretrained(output_dir)
diff --git a/tox.ini b/tox.ini
index 80299d814d..7948a19f5c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -66,6 +66,10 @@ commands_pre =
# Install cupy-cuda13x for INT4 ONNX quantization (default is cupy-cuda12x)
pip uninstall -y cupy-cuda12x
pip install cupy-cuda13x
+
+ # Install mamba and causal-conv1d for Nemotron tests
+ pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git
+ pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git
commands =
# Coverage fails with "Can't combine line data with arc data" error so not using "--cov"
python -m pytest tests/gpu