diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d1e02acfba..f9f98b73d8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -24,6 +24,7 @@ modelopt/torch/nas @NVIDIA/modelopt-torch-nas-prune-codeowners modelopt/torch/opt @NVIDIA/modelopt-torch-opt-codeowners modelopt/torch/peft @NVIDIA/modelopt-torch-peft-codeowners modelopt/torch/prune @NVIDIA/modelopt-torch-nas-prune-codeowners +modelopt/torch/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners modelopt/torch/quantization @NVIDIA/modelopt-torch-quantization-codeowners modelopt/torch/sparsity @NVIDIA/modelopt-torch-sparsity-codeowners modelopt/torch/speculative @NVIDIA/modelopt-torch-speculative-codeowners diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index f3f3908043..848b3d326d 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -125,14 +125,14 @@ jobs: strategy: &nemo_strategy fail-fast: false matrix: - example: [megatron_bridge] + example: [megatron_bridge, puzzletron] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: docker_image: "nvcr.io/nvidia/nemo:26.02" example: ${{ matrix.example }} timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" + pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-1 nemo-non-pr: @@ -144,7 +144,7 @@ jobs: docker_image: "nvcr.io/nvidia/nemo:26.02" example: ${{ matrix.example }} timeout_minutes: 30 - pip_install_extras: "[hf,dev-test]" + pip_install_extras: "[hf,puzzletron,dev-test]" runner: linux-amd64-gpu-rtxpro6000-latest-2 ##### ONNX/TensorRT Example Tests ##### diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 538e05e75f..542e948909 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -85,6 +85,8 @@ jobs: - name: Setup environment variables run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV 
+ - name: Install dependencies for mip + run: apt-get update && apt-get install -y libffi-dev - name: Run gpu tests run: pip install tox-current-env && tox -e cuda13-${{ matrix.example }} --current-env gpu-tests-non-pr: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a7821d42f..7810db7886 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,6 +79,7 @@ repos: modelopt/onnx/quantization/ort_patching.py| modelopt/torch/_deploy/utils/onnx_utils.py| modelopt/torch/export/transformer_engine.py| + modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_pruned_to_mxfp4.py| modelopt/torch/quantization/export_onnx.py| modelopt/torch/quantization/plugins/attention.py| modelopt/torch/speculative/eagle/utils.py| @@ -95,6 +96,7 @@ repos: examples/llm_eval/modeling.py| examples/llm_qat/main.py| examples/llm_sparsity/weight_sparsity/finetune.py| + examples/puzzletron/evaluation/lm_eval_anymodel.py| examples/specdec_bench/specdec_bench/models/specbench_medusa.py| examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 656d6315db..930e9c6d25 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -7,6 +7,7 @@ Pruning can involve removal (prune) of Linear and Conv layers; and Transformer a This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model: 1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. 
It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. +1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA using a Mixed Integer Programming (MIP) based NAS search algorithm. 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints. 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints. diff --git a/examples/puzzletron/GPTOSS.md b/examples/puzzletron/GPTOSS.md new file mode 100644 index 0000000000..7c160c8997 --- /dev/null +++ b/examples/puzzletron/GPTOSS.md @@ -0,0 +1,14 @@ + +## GptOss + +With this release, the Puzzle algorithm supports only expert removal for `Gpt-Oss`. + +This model comes as a quantized checkpoint, i.e. the MoE expert matrices are quantized in the _MXFP4_ format. +In the pruning steps, Puzzle uses the decompressed model (back to BF16) for statistics and score computation. +This means that during the conversion to the Puzzle format we decompress the model and store it as BF16. +Once the pruning is done, i.e. the experts to be removed are identified and the process is finished, the user may want to get back the _MXFP4_ format of the checkpoint. +To do so, there is an additional script that takes the original and the pruned checkpoints and outputs the pruned checkpoint in _MXFP4_ format. 
+ +```bash +python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_pruned_to_mxfp4 --student-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24 +``` diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md new file mode 100644 index 0000000000..a7e3aedfc1 --- /dev/null +++ b/examples/puzzletron/README.md @@ -0,0 +1,284 @@ +# Puzzletron Algorithm Tutorial + +This tutorial demonstrates how to compress large language models using the puzzletron algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146). +The goal of the algorithm is to find the optimal modifications to MLP and attention layers of the model, resulting in a heterogeneous model architecture. +The supported modifications are: + +- `ffn_intermediate_size`: different FFN intermediate sizes +- `attention op/noop`: complete removal of attention layers + +To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on a Mixed-Integer Programming (MIP) algorithm to find the optimal combination of layer modifications that satisfy the target requirements. + +In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). + +> **Note:** Other models are also supported. 
See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). + +## Environment + +- Install Model-Optimizer in editable mode with the corresponding dependencies (run from the repo root): + +```bash +pip install -e .[hf,puzzletron] +pip install -r examples/puzzletron/requirements.txt +``` + +> **Note:** NeMo containers may ship `nvidia-lm-eval` which may conflict with `lm-eval` that is used for evaluation. +> If so, run `pip uninstall nvidia-lm-eval -y` before installing requirements. + +- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can also use a single GPU. + +- To make use of [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), you need to accept the terms and conditions for the corresponding model and the dataset in the Huggingface Hub. Log in to the Huggingface Hub and enter your HF token. + +```bash +hf auth login +``` + +## Compress the Model + +1. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). + + dataset split: "code", "math", "stem", "chat", excluding reasoning samples (2.62GB) + + ```bash + python -m modelopt.torch.puzzletron.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2 + ``` + +2. 
Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file. + + - `puzzle_dir` indicates a new directory for saving the resulting model. + - `input_hf_model_path` indicates the local directory with the input model checkpoint. + - `dataset_path` indicates the directory with the dataset downloaded earlier. + + **_NOTE:_** + How to choose `intermediate_size_list`? + The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g. 15%, 20%, 30% etc of the original). Note that the values must be hardware-friendly (divisible by 256) to avoid issues with tensor operations in subsequent steps. + + Let's first shoot for 32% GPU memory reduction setting `target_memory = 78_000` MiB. This means that the algorithm will choose the candidates with the highest accuracy that also meet the specified requirements. + + We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model. + +3. Run the puzzletron pipeline. 
+ + ```bash + torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress" + ``` + + This will save the full output to `log.txt` and display the following progress on screen: + + ```bash + [2025-11-02 12:06:34][rank-0][main.py:71] Puzzletron Progress 1/8: starting puzzletron pipeline + [2025-11-02 12:06:45][rank-0][puzzletron_nas_plugin.py:123] Puzzletron Progress 2/8: converting model from HF to DeciLM (single-gpu) + [2025-11-02 12:07:07][rank-0][puzzletron_nas_plugin.py:132] Puzzletron Progress 3/8: scoring pruning activations (multi-gpu) + [2025-11-02 12:11:36][rank-0][puzzletron_nas_plugin.py:137] Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu) + [2025-11-02 12:12:20][rank-0][puzzletron_nas_plugin.py:217] Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu) + [2025-11-02 12:12:21][rank-0][puzzletron_nas_plugin.py:222] Puzzletron Progress 6/8: calculating one block scores (multi-gpu) + [2025-11-02 12:50:41][rank-0][puzzletron_nas_plugin.py:226] Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu) + [2025-11-02 12:52:34][rank-0][main.py:115] Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu) + ``` + + Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review: + + ```bash + ... 
+ block_0: attention gqa_4 ffn intermediate_14336 + block_1: attention gqa_4 ffn intermediate_14336 + block_2: attention gqa_4 ffn intermediate_14336 + block_3: attention gqa_4 ffn intermediate_14336 + block_4: attention gqa_4 ffn intermediate_14336 + block_5: attention gqa_4 ffn intermediate_14336 + block_6: attention gqa_4 ffn intermediate_14336 + block_7: attention gqa_4 ffn intermediate_14336 + block_8: attention gqa_4 ffn intermediate_14336 + block_9: attention gqa_4 ffn intermediate_14336 + block_10: attention gqa_4 ffn intermediate_14336 + block_11: attention gqa_4 ffn intermediate_14336 + block_12: attention gqa_4 ffn intermediate_14336 + block_13: attention gqa_4 ffn intermediate_14336 + block_14: attention gqa_4 ffn intermediate_14336 + block_15: attention gqa_4 ffn intermediate_14336 + block_16: attention gqa_4 ffn intermediate_14336 + block_17: attention no_op ffn intermediate_14336 + block_18: attention no_op ffn intermediate_14336 + block_19: attention no_op ffn intermediate_14336 + block_20: attention no_op ffn intermediate_14336 + block_21: attention no_op ffn intermediate_14336 + block_22: attention no_op ffn intermediate_14336 + block_23: attention no_op ffn intermediate_14336 + block_24: attention no_op ffn intermediate_14336 + block_25: attention no_op ffn intermediate_14336 + block_26: attention no_op ffn intermediate_14336 + block_27: attention no_op ffn intermediate_14336 + block_28: attention no_op ffn intermediate_14336 + block_29: attention gqa_4 ffn intermediate_14336 + block_30: attention gqa_4 ffn intermediate_14336 + block_31: attention gqa_4 ffn intermediate_14336 + + [2025-11-02 04:53:11,332]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 
'stats.has_attention': 20, 'stats.has_ffn': 32} + ... + ################################################################ + validate_model_and_extract_token_probs(model_name='teacher') + ################################################################ + ... + Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883} + ... + ################################################################ + validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) + ################################################################ + .... + Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227} + ``` + + 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). + +## Re-run MIP Search with different constraints + +If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. +This assumes pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed. + +For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`. 
+ +```bash +torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" +``` + +This will generate the following network architecture (see `log.txt`): + +```bash +block_0: attention gqa_4 ffn intermediate_14336 +block_1: attention gqa_4 ffn intermediate_14336 +block_2: attention gqa_4 ffn intermediate_14336 +block_3: attention gqa_4 ffn intermediate_14336 +block_4: attention gqa_4 ffn intermediate_14336 +block_5: attention gqa_4 ffn intermediate_14336 +block_6: attention gqa_4 ffn intermediate_14336 +block_7: attention gqa_4 ffn intermediate_14336 +block_8: attention gqa_4 ffn intermediate_14336 +block_9: attention gqa_4 ffn intermediate_14336 +block_10: attention gqa_4 ffn intermediate_14336 +block_11: attention gqa_4 ffn intermediate_14336 +block_12: attention gqa_4 ffn intermediate_14336 +block_13: attention gqa_4 ffn intermediate_14336 +block_14: attention gqa_4 ffn intermediate_14336 +block_15: attention gqa_4 ffn intermediate_14336 +block_16: attention gqa_4 ffn intermediate_14336 +block_17: attention gqa_4 ffn intermediate_14336 +block_18: attention no_op ffn intermediate_14336 +block_19: attention no_op ffn intermediate_14336 +block_20: attention no_op ffn intermediate_14336 +block_21: attention gqa_4 ffn intermediate_14336 +block_22: attention no_op ffn intermediate_14336 +block_23: attention no_op ffn intermediate_14336 +block_24: attention no_op ffn intermediate_14336 +block_25: attention gqa_4 ffn intermediate_14336 +block_26: attention gqa_4 ffn intermediate_14336 +block_27: attention gqa_4 ffn intermediate_14336 +block_28: attention gqa_4 ffn intermediate_14336 +block_29: attention gqa_4 ffn intermediate_14336 +block_30: attention gqa_4 ffn intermediate_14336 +block_31: attention gqa_4 ffn intermediate_14336 + +[2025-11-02 12:50:42,024]^[[92m[rank-0]^[[0m[run_puzzle.py:295] Total costs: 
{'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208} +... +################################################################ +validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True) +################################################################ +Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992} +``` + +On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details): + +```bash +block_5: attention no_op ffn intermediate_11520 +block_6: attention no_op ffn intermediate_14336 +block_7: attention no_op ffn intermediate_8704 +block_8: attention no_op ffn intermediate_14336 +block_9: attention no_op ffn intermediate_3072 +block_10: attention no_op ffn intermediate_11520 +block_11: attention no_op ffn intermediate_11520 +block_12: attention no_op ffn intermediate_11520 +block_13: attention no_op ffn intermediate_11520 +block_14: attention no_op ffn intermediate_3072 +``` + +### MIP Sweep Mode + +The **MIP sweep mode** lets you explore multiple memory compression rates in a single run and compare the accuracy-memory trade-offs. + +#### Quick Start + +1. Enable sweep in your config YAML (e.g., `llama-3_1-8B_pruneffn_memory.yaml`): + + ```yaml + mip: + sweep: + enabled: true + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + output_csv: ${puzzle_dir}/mip_sweep_results.csv + ``` + +2. 
Run the sweep: + + ```bash + torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Puzzletron Progress" + ``` + +3. View results: The CSV file contains compression rates, memory usage, and accuracy metrics for each configuration. + +#### Example Results + +MIP Sweep Results + +The plot shows how token accuracy changes with different compression rates. Higher compression (0.5 = 50% of original memory) reduces accuracy, while lower compression maintains accuracy closer to the teacher model. + +## Evaluation + +Evaluate AnyModel checkpoints using [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) directly. + +```bash +python examples/puzzletron/evaluation/lm_eval_anymodel.py \ + --model hf \ + --model_args pretrained=path/to/checkpoint,dtype=bfloat16,parallelize=True \ + --tasks mmlu \ + --num_fewshot 5 \ + --batch_size 4 +``` + +For a quick smoke test, add `--limit 10`. + +> **Alternative:** For server-based evaluation via an OpenAI-compatible endpoint, +> see [evaluation/nemo_evaluator_instructions.md](./evaluation/nemo_evaluator_instructions.md). + +## Inference Performance Benchmarking + +Now let's evaluate how much speedup we get with the compressed model in terms of throughput and latency. + +- Install [vLLM from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source). +- Rearrange the model safetensors to be used for vLLM. + +```bash +cd path/to/model +mv subblocks_safetensors/* . 
+sed -i 's+subblocks_safetensors/++g' model.safetensors.index.json +``` + +- Benchmark latency + +```bash +vllm bench latency --model path/to/model --load-format safetensors --trust-remote-code +``` + +- Benchmark throughput + +```bash +vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code +``` + +## Knowledge Distillation + +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. + +See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation. + +## Advanced Usage + +Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios. diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml new file mode 100644 index 0000000000..b48f1de78c --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: gpt_oss +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + 
mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml new file mode 100644 index 0000000000..8ed06e9568 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml @@ -0,0 +1,17 @@ +defaults: + - gptoss-20b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 16_000 # ~16 GiB diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..258e6c38a3 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +eval_samples: 2500 #10 +activations_log_dir: 
${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0eff799d7e --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 10_000 +micro_batch_size: 1 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: 
false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml new file mode 100644 index 0000000000..21903db162 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: 
${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 7_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml new file mode 100644 index 0000000000..ad16dbc5ea --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml @@ -0,0 +1,26 @@ +defaults: + - Llama-3_1-8B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: 
/workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for puzzletron outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + # Memory sweep configuration (optional) + sweep: + enabled: false + memory_compression_rates: [0.5, 0.6, 0.7, 0.8, 0.9] + output_csv: ${puzzle_dir}/mip_sweep_results.csv + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..da0b972070 --- /dev/null +++ 
b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,19 @@ +defaults: + - pruning_defaults + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml 
b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git 
a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml new file mode 100644 index 0000000000..7de281e788 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + 
mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml new file mode 100644 index 0000000000..b5303d318a --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/llama-3_2-3B_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - Llama-3_2-3B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.2-3B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 45_000 # 45 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 8192, so we use proportionally smaller values +pruning: + intermediate_size_list: [2048, 4096, 6144] diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..05de8bfdcc --- /dev/null +++ 
b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# Llama-3.2-3B has intermediate_size=8192, so we use proportionally smaller pruning sizes +intermediate_size_list: [2048, 4096, 6144] +mlp_init_mode: "PruneByActivationsLog" + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false 
+sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml new file mode 100644 index 0000000000..18213f9b7a --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: mistral_small +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: 
${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 24_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml new file mode 100644 index 0000000000..68a0652d6f --- /dev/null +++ 
b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/mistral-small-24b-instruct-2501_pruneffn_memory.yaml @@ -0,0 +1,21 @@ +defaults: + - Mistral-Small-24B + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/mistralai/Mistral-Small-24B-Instruct-2501 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 234_000 # 234 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +pruning: + intermediate_size_list: [8192, 16384, 24576] # teacher_intermediate_size is 32768 diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..cb24e1bc24 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,17 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# Mistral Small 24B: 32 query heads, 8 KV heads +# n_heads_in_group = num_query_heads / num_kv_heads +# num_kv_heads = num_query_heads / n_heads_in_group +# Base: n_heads_in_group = 4, num_kv_heads = 8 +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git 
a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..5fb7fcbdd2 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 32768 +intermediate_size_list: [8192, 16384, 24576] +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..7de32621e0 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +# 
Mistral Small 24B: hidden_size is 5120 +hidden_size_list: [3072, 4096] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..e05e775bee --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: 
torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml new file mode 100644 index 0000000000..62b6ecb4cb --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: nemotron_h_v2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 90_000 + num_params: 12_000_000_000 + + 
mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml new file mode 100644 index 0000000000..3b880b2c7d --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - nemotron_nano_12b_v2 + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/nvidia/Nemotron-Nano-12B-v2 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 90_000 # 90 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 20480 +pruning: + intermediate_size_list: [4352, 8448, 12544, 16384] diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ 
b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..1e2ecf07a0 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 20480 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml 
b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: 
false diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..3f7a248ee7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..18d7e234ac --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + 
method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 18944 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..af8af990b7 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml new file mode 100644 index 0000000000..aa11499a3c --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: qwen2 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 7_000_000_000 + + 
mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml new file mode 100644 index 0000000000..fb961033bc --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen2_5_7b_instruct + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen2.5-7B-Instruct + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 18944 +pruning: + intermediate_size_list: [4096, 7808, 11520, 15104] diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml new file mode 100644 index 
0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + 
layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..93590d13e5 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 12288 +mlp_init_mode: "PruneByActivationsLog" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension 
pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml new file mode 100644 index 0000000000..eec82a7d63 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml @@ -0,0 +1,109 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? 
+descriptor: qwen3 +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + runtime_stats: + backend: trt_torch + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - 
stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 78_000 + num_params: 8_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml new file mode 100644 index 0000000000..4ee81286dd --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b_pruneffn_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - qwen3_8b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/Qwen/Qwen3-8B + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 78_000 # 78 GiB + +# FFN intermediate sizes to search over (heterogeneous architecture) +# teacher_intermediate_size is 12288 +pruning: + intermediate_size_list: [2560, 5120, 7424, 9984] diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml new 
file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/evaluation/hf_deployable_anymodel.py b/examples/puzzletron/evaluation/hf_deployable_anymodel.py new file mode 100644 index 0000000000..3ca8dd7581 --- /dev/null +++ b/examples/puzzletron/evaluation/hf_deployable_anymodel.py @@ -0,0 +1,724 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import json +import logging +from typing import Any + +import numpy as np +import torch +from nemo_deploy import ITritonDeployable +from nemo_deploy.utils import broadcast_list, cast_output, str_ndarray2list +from nemo_export_deploy_common.import_utils import ( + MISSING_TRITON_MSG, + UnavailableError, + null_decorator, +) +from peft import PeftModel +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor + + HAVE_TRITON = True +except (ImportError, ModuleNotFoundError): + from unittest.mock import MagicMock + + HAVE_TRITON = False + batch = MagicMock() + Tensor = MagicMock() + batch = null_decorator + + +LOGGER = logging.getLogger("NeMo") + +SUPPORTED_TASKS = ["text-generation"] + + +class HuggingFaceLLMDeploy(ITritonDeployable): + """A Triton inference server compatible wrapper for HuggingFace models. + + This class provides a standardized interface for deploying HuggingFace models + in Triton inference server. It supports various NLP tasks and handles model + loading, inference, and deployment configurations. + + Args: + hf_model_id_path (Optional[str]): Path to the HuggingFace model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. 
+ hf_peft_model_id_path (Optional[str]): Path to the PEFT model or model identifier. + Can be a local path or a model ID from HuggingFace Hub. + tokenizer_id_path (Optional[str]): Path to the tokenizer or tokenizer identifier. + If None, will use the same path as hf_model_id_path. + model (Optional[AutoModel]): Pre-loaded HuggingFace model. + tokenizer (Optional[AutoTokenizer]): Pre-loaded HuggingFace tokenizer. + tokenizer_padding (bool): Whether to enable padding in tokenizer. Defaults to True. + tokenizer_truncation (bool): Whether to enable truncation in tokenizer. Defaults to True. + tokenizer_padding_side (str): Which side to pad on ('left' or 'right'). Defaults to 'left'. + task (str): HuggingFace task type (e.g., "text-generation"). Defaults to "text-generation". + **hf_kwargs: Additional keyword arguments to pass to HuggingFace model loading. + """ + + def __init__( + self, + hf_model_id_path: str | None = None, + hf_peft_model_id_path: str | None = None, + tokenizer_id_path: str | None = None, + model: AutoModel | None = None, + tokenizer: AutoTokenizer | None = None, + tokenizer_padding=True, + tokenizer_truncation=True, + tokenizer_padding_side="left", + task: str | None = "text-generation", + torch_dtype: torch.dtype | None = "auto", + device_map: str | None = "auto", + **hf_kwargs, + ): + if not HAVE_TRITON: + raise UnavailableError(MISSING_TRITON_MSG) + + if hf_model_id_path is None and model is None: + raise ValueError("hf_model_id_path or model parameters has to be passed.") + elif hf_model_id_path is not None and model is not None: + LOGGER.warning( + "hf_model_id_path will be ignored and the HuggingFace model set with model parameter will be used." 
+ ) + + assert task in SUPPORTED_TASKS, "Task {} is not a support task.".format(task) + + self.hf_model_id_path = hf_model_id_path + self.hf_peft_model_id_path = hf_peft_model_id_path + self.task = task + self.model = model + self.tokenizer = tokenizer + self.tokenizer_padding = tokenizer_padding + self.tokenizer_truncation = tokenizer_truncation + self.tokenizer_padding_side = tokenizer_padding_side + + if tokenizer_id_path is None: + self.tokenizer_id_path = hf_model_id_path + else: + self.tokenizer_id_path = tokenizer_id_path + + if model is None: + self._load(torch_dtype=torch_dtype, device_map=device_map, **hf_kwargs) + + def _load( + self, torch_dtype: torch.dtype | None = "auto", device_map: str | None = "auto", **hf_kwargs + ) -> None: + """Load the HuggingFace pipeline with the specified model and task. + + This method initializes the HuggingFace AutoModel classes using the provided model + configuration and task type. It handles the model and tokenizer loading + process. + + Args: + torch_dtype (torch.dtype): Data type for the model. Defaults to "auto". + device_map (str): Device map for the model. Defaults to "auto". + **hf_kwargs: Additional keyword arguments to pass to the HuggingFace model loading. + + Raises: + AssertionError: If task is not specified. + """ + assert self.task is not None, "A task has to be given for the generation task." + + if self.task == "text-generation": + # ========================================================================= + # BEGIN ANYMODEL PATCH + # Wraps model loading with deci_x_patcher for heterogeneous layer configs. 
+ # See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py + # ========================================================================= + + descriptor = resolve_descriptor_from_pretrained( + self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False) + ) + + with deci_x_patcher(model_descriptor=descriptor): + self.model = AutoModelForCausalLM.from_pretrained( + self.hf_model_id_path, + torch_dtype=torch_dtype, + device_map=device_map, + **hf_kwargs, + ) + # ========================================================================= + # END ANYMODEL PATCH + # ========================================================================= + + if self.hf_peft_model_id_path is not None: + self.model = PeftModel.from_pretrained(self.model, self.hf_peft_model_id_path) + else: + raise ValueError("Task {} is not supported.".format(self.task)) + num_gpus = torch.cuda.device_count() + # If there is only one GPU, move the model to GPU. If you are using device_map as "auto" or "balanced", + # the model will be moved to GPU automatically. + if device_map is None and num_gpus >= 1 and self.model.device.type != "cuda": + self.model.cuda() + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id_path, + trust_remote_code=hf_kwargs.pop("trust_remote_code", False), + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + padding_side=self.tokenizer_padding_side, + ) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def generate( + self, + **kwargs: Any, + ) -> list[str]: + """Generate text based on the provided input prompts. + + This method processes input prompts through the loaded pipeline and + generates text according to the specified parameters. 
+ + Args: + **kwargs: Generation parameters including: + - text_inputs: List of input prompts + - max_length: Maximum number of tokens to generate + - num_return_sequences: Number of sequences to generate per prompt + - temperature: Sampling temperature + - top_k: Number of highest probability tokens to consider + - top_p: Cumulative probability threshold for token sampling + - do_sample: Whether to use sampling, default is False for greedy decoding + - echo: Whether to return prompt + generated text (True) or just generated text (False) + - return_full_text: Whether to return full text or only generated part + + Returns: + If output logits and output scores are False: + List[str]: A list of generated texts, one for each input prompt. + If output logits and output scores are True: + Dict: A dictionary containing: + - sentences: List of generated texts + - logits: List of logits + - scores: List of scores + - input_lengths: List of input token lengths (for echo processing) + + Raises: + RuntimeError: If the pipeline is not initialized. 
+ """ + if not self.model: + raise RuntimeError("Model is not initialized") + + inputs = self.tokenizer( + kwargs["text_inputs"], + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + + # Store input lengths to extract only generated tokens later + input_lengths = [len(input_ids) for input_ids in inputs["input_ids"]] + + # Get echo parameter (default False - only return generated text) + echo = kwargs.pop("echo", False) + kwargs.pop("text_inputs") # Remove text_inputs as it's already been tokenized + + kwargs = {**inputs, **kwargs} + for key, val in kwargs.items(): + if torch.is_tensor(val): + kwargs[key] = val.cuda() + + with torch.no_grad(): + generated_ids = self.model.generate(**kwargs) + return_dict_in_generate = kwargs.get("return_dict_in_generate", False) + if return_dict_in_generate: + # Handle dict output (when logits/scores are requested) + sequences = generated_ids["sequences"] + output = {"sentences": [], "input_lengths": input_lengths, "sequences": sequences} + + if echo: + # Return full text (prompt + generated). + # HF model's generate returns the input/prompt tokens as well by default. + for i, seq in enumerate(sequences): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output["sentences"].append(full_text) + else: + # Extract only the generated tokens (skip input tokens). + # This is required as HF model's generate returns the input/prompt tokens + # as well by default. 
(return_full_text is specific to some models) + for i, seq in enumerate(sequences): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output["sentences"].append(generated_text) + + if kwargs.get("output_logits", False): + output["logits"] = generated_ids["logits"] + if kwargs.get("output_scores", False): + output["scores"] = generated_ids["scores"] + else: + # Handle list output (normal case) + output = [] + if echo: + # Return full text (prompt + generated), which is the default in case of HF model generate. + for i, seq in enumerate(generated_ids): + full_text = self.tokenizer.decode(seq, skip_special_tokens=True) + output.append(full_text) + else: + # Extract only the generated tokens (skip input tokens) as the default + # behavior returns the input/prompt tokens as well. + for i, seq in enumerate(generated_ids): + input_len = input_lengths[i] if i < len(input_lengths) else 0 + generated_tokens = seq[input_len:] # Skip input tokens + generated_text = self.tokenizer.decode( + generated_tokens, skip_special_tokens=True + ) + output.append(generated_text) + + return output + + def generate_other_ranks(self): + """Generate function for ranks other than the rank 0.""" + while True: + message = torch.empty(1, dtype=torch.long, device="cuda") + torch.distributed.broadcast(message, src=0) + if message == 0: + prompts = broadcast_list(data=[None], src=0) + ( + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ) = broadcast_list(data=[None], src=0) + + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + 
output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + ) + else: + return + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), # NOTE(review): "max_length" is already declared above in this tuple; confirm the duplicate is intended + Tensor(name="output_logits", shape=(-1,), dtype=np.bool_, optional=True), + Tensor(name="output_scores", shape=(-1,), dtype=np.bool_, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + return ( + Tensor(name="sentences", shape=(-1,), dtype=bytes), + Tensor(name="logits", shape=(-1,), dtype=np.single), + Tensor(name="scores", shape=(-1,), dtype=np.single), + ) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + output_infer = {} + + try: + prompts = str_ndarray2list(inputs.pop("prompts")) + temperature = inputs.pop("temperature")[0][0] if "temperature" in inputs else 1.0 + top_k = int(inputs.pop("top_k")[0][0] if "top_k" in inputs else 1) + top_p = inputs.pop("top_p")[0][0] if "top_p" in inputs else 0 + num_tokens_to_generate = ( + inputs.pop("max_length")[0][0] if "max_length" in inputs else 256 + ) + output_logits = ( + inputs.pop("output_logits")[0][0] if "output_logits" in inputs else False + ) + output_scores = ( + inputs.pop("output_scores")[0][0] if "output_scores" in inputs else False + ) + return_dict_in_generate = False + if output_logits or output_scores: + return_dict_in_generate = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + 
torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=False, + ) + + if isinstance(output, dict): + output_infer = {"sentences": cast_output(output["sentences"], np.bytes_)} + + if "scores" in output: + output_scores = [] + for r in output["scores"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_scores.append([0]) + else: + output_scores.append(lp) + output_infer["scores"] = np.array(output_scores).transpose(1, 0, 2) + + if "logits" in output: + output_logits = [] + for r in output["logits"]: + lp = torch.tensor(r).cpu().detach().numpy() + if len(lp) == 0: + output_logits.append([0]) + else: + output_logits.append(lp) + output_infer["logits"] = np.array(output_logits).transpose(1, 0, 2) + else: + output_infer = {"sentences": cast_output(output, np.bytes_)} + + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + output_infer["sentences"] = cast_output([err_msg], np.bytes_) + + return output_infer + + def _compute_logprobs( + self, + prompts: list[str], + output_infer: dict[str, Any], + compute_logprob: bool, + n_top_logprobs: int, + echo: bool, + ): + """Compute log probabilities and top log probabilities from model scores. + Used by ray_infer_fn to provide OAI API compatible output for evaluations. 
+ + This method processes the raw scores from model generation to compute: + - Log probabilities for chosen tokens + - Top-k log probabilities for each position (if requested) + - Handles both prompt tokens (when echo=True) and generated tokens + + Args: + prompts: List of input prompts + output_infer: Dictionary containing model outputs including scores, sequences, and input_lengths + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return (0 to disable) + echo: Whether to include prompt token log probabilities + + Returns: + Tuple[Optional[List], Optional[List]]: + - log_probs_list: List of log probabilities for each sample (None if not computed) + - top_logprobs_list: List of top-k log probabilities for each sample (None if not computed) + """ + # Tokenize the prompts to get prompt token IDs (needed for echo) + prompt_token_ids = None + prompt_inputs = None + if echo: + prompt_inputs = self.tokenizer( + prompts, + return_tensors="pt", + padding=self.tokenizer_padding, + truncation=self.tokenizer_truncation, + ) + prompt_token_ids = prompt_inputs["input_ids"] + # Move to same device as model + for key, val in prompt_inputs.items(): + if torch.is_tensor(val): + prompt_inputs[key] = val.cuda() + + # Process each sample + log_probs_list = [] + top_logprobs_list = [] + + for sample_idx in range(len(prompts)): + sample_log_probs = [] + sample_top_logprobs = [] + + # Get the generated sequence for this sample + sequences = output_infer["sequences"][sample_idx] + + # For echo, compute prompt token logprobs by running forward pass + if echo and prompt_token_ids is not None: + prompt_len = len(prompt_token_ids[sample_idx]) + + # Run forward pass on prompt to get logits for prompt tokens as scores in output_infer contains + # logits only for generated tokens. 
+ with torch.no_grad(): + # Create input for this specific sample + sample_prompt_input = { + key: val[sample_idx : sample_idx + 1] for key, val in prompt_inputs.items() + } + prompt_outputs = self.model(**sample_prompt_input) + prompt_logits = prompt_outputs.logits[0] # Shape: [seq_len, vocab_size] + + # Calculate log probs for each prompt token (except the first BOS token) + for token_pos in range(1, prompt_len): # Start from 1 to skip BOS + # The logit at position i-1 predicts token at position i + logit_for_current_token = prompt_logits[token_pos - 1] + current_token_id = prompt_token_ids[sample_idx][token_pos].item() + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(logit_for_current_token, dim=-1) + chosen_log_prob = log_probs[current_token_id].item() + sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + # Process the scores for generated tokens + for token_idx, score_tensor in enumerate(output_infer["scores"]): + # Get the chosen token ID from the sequence + # Scores start after the prompt, so we need to offset + input_len = ( + output_infer.get("input_lengths", [0])[sample_idx] + if "input_lengths" in output_infer + else 0 + ) + seq_idx = input_len + token_idx + + if seq_idx < len(sequences): + chosen_token_id = ( + sequences[seq_idx].item() + if hasattr(sequences[seq_idx], "item") + else sequences[seq_idx] + ) + + # Calculate log probabilities + log_probs = torch.nn.functional.log_softmax(score_tensor[sample_idx], dim=-1) + chosen_log_prob = log_probs[chosen_token_id].item() + 
sample_log_probs.append(chosen_log_prob) + + # Calculate top log probabilities if requested + if n_top_logprobs > 0: + top_log_probs_dict = {} + top_k_values, top_k_indices = torch.topk( + log_probs, min(n_top_logprobs, len(log_probs)) + ) + for k_idx in range(len(top_k_indices)): + token_id = top_k_indices[k_idx].item() + token_str = self.tokenizer.decode([token_id]) + top_log_probs_dict[token_str] = top_k_values[k_idx].item() + sample_top_logprobs.append(top_log_probs_dict) + + log_probs_list.append(sample_log_probs) + if n_top_logprobs > 0: + top_logprobs_list.append(sample_top_logprobs) + + # Return log probs and top logprobs + return_log_probs = log_probs_list if compute_logprob else None + return_top_logprobs = top_logprobs_list if n_top_logprobs > 0 else None + + return return_log_probs, return_top_logprobs + + def ray_infer_fn(self, inputs: dict[Any, Any]): + """Perform inference using Ray with dictionary inputs and outputs. + + Args: + inputs (Dict[Any, Any]): Dictionary containing input parameters: + - prompts: List of input prompts + - temperature: Sampling temperature (optional) + - top_k: Number of highest probability tokens to consider (optional) + - top_p: Cumulative probability threshold for token sampling (optional) + - max_tokens: Maximum number of tokens to generate (optional) + - compute_logprob: Whether to compute log probabilities (optional) + - n_top_logprobs: Number of top log probabilities to return (optional) + - echo: Whether to echo the prompt in output (optional) + + Returns: + Dict[str, Any]: Dictionary containing: + - sentences: List of generated texts + - log_probs: Optional list of log probabilities if compute_logprob is True + - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0 + """ + try: + prompts = inputs.pop("prompts") + temperature = inputs.pop("temperature", 1.0) + top_k = int(inputs.pop("top_k", 1)) + top_p = inputs.pop("top_p", 0.0) + num_tokens_to_generate = inputs.pop("max_tokens", 256) + 
output_logits = inputs.pop("output_logits", False) + output_scores = inputs.pop("output_scores", False) + compute_logprob = inputs.pop("compute_logprob", False) + n_top_logprobs = inputs.pop("n_top_logprobs", 0) + echo = inputs.pop("echo", False) + + output_infer = self._infer_fn_ray( + prompts=prompts, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_tokens_to_generate=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + # Code to get logprobs (required in OAI API format for eval) from the scores in output_infer. + if ( + (compute_logprob or n_top_logprobs > 0) + and "scores" in output_infer + and output_infer["scores"] + ): + log_probs_list, top_logprobs_list = self._compute_logprobs( + prompts=prompts, + output_infer=output_infer, + compute_logprob=compute_logprob, + n_top_logprobs=n_top_logprobs, + echo=echo, + ) + + # Add to output + if log_probs_list is not None: + output_infer["log_probs"] = log_probs_list + if top_logprobs_list is not None: + # Convert to JSON strings for compatibility + output_infer["top_logprobs"] = [ + json.dumps(top_logprobs) for top_logprobs in top_logprobs_list + ] + + # Remove raw outputs that are not needed in the final response + output_infer.pop("scores", None) + output_infer.pop("sequences", None) + output_infer.pop("input_lengths", None) + return output_infer + except Exception as error: + err_msg = "An error occurred: {}".format(str(error)) + return {"sentences": [err_msg]} + + def _infer_fn_ray( + self, + prompts, + temperature=1.0, + top_k=1, + top_p=0.0, + num_tokens_to_generate=256, + output_logits=False, + output_scores=False, + compute_logprob=False, + n_top_logprobs=0, + echo=False, + cast_output_func=None, + ): + """Common internal function for inference operations. 
+ + Args: + prompts: List of input prompts + temperature: Sampling temperature + top_k: Number of highest probability tokens to consider + top_p: Cumulative probability threshold for token sampling + num_tokens_to_generate: Maximum number of tokens to generate + output_logits: Whether to output logits + output_scores: Whether to output scores + compute_logprob: Whether to compute log probabilities + n_top_logprobs: Number of top log probabilities to return + echo: Whether to echo the prompt in output + cast_output_func: Optional function to cast output values + + Returns: + Dict containing inference results with raw outputs + """ + # Enable return_dict if we need scores for logprobs or if output_logits/scores are requested + return_dict_in_generate = ( + output_logits or output_scores or compute_logprob or n_top_logprobs > 0 + ) + # Enable output_scores if we need to compute logprobs. scores and logits from generate are both identical in + # case of greedy decoding. Hence setting output_scores to True when compute_logprob or n_top_logprobs > 0. 
+ if compute_logprob or n_top_logprobs > 0: + output_scores = True + + if torch.distributed.is_initialized(): + if torch.distributed.get_world_size() > 1: + torch.distributed.broadcast( + torch.tensor([0], dtype=torch.long, device="cuda"), src=0 + ) + broadcast_list(prompts, src=0) + broadcast_list( + data=[ + temperature, + top_k, + top_p, + num_tokens_to_generate, + output_logits, + output_scores, + ], + src=0, + ) + + output = self.generate( + text_inputs=prompts, + do_sample=False, # do_sample=False for greedy decoding + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_new_tokens=num_tokens_to_generate, + output_logits=output_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + echo=echo, + ) + + if isinstance(output, dict): + return output + + else: + return {"sentences": output} diff --git a/examples/puzzletron/evaluation/lm_eval_anymodel.py b/examples/puzzletron/evaluation/lm_eval_anymodel.py new file mode 100644 index 0000000000..7f9e07dd2b --- /dev/null +++ b/examples/puzzletron/evaluation/lm_eval_anymodel.py @@ -0,0 +1,115 @@ +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c + +# MIT License +# +# Copyright (c) 2020 EleutherAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +"""Run lm-eval directly on AnyModel (Puzzletron) checkpoints without a deployment server. + +Patches lm-eval's HFLM to wrap model loading with deci_x_patcher so AnyModel +Puzzletron checkpoints load correctly. Model descriptor is auto-detected from the +checkpoint's config.json model_type. 
+""" + +from lm_eval import utils +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.model import T +from lm_eval.models.huggingface import HFLM + +# Trigger factory registration for all model descriptors +import modelopt.torch.puzzletron.anymodel.models # noqa: F401 +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + +def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: + """Override HFLM.create_from_arg_obj to wrap model loading with deci_x_patcher.""" + + additional_config = {} if additional_config is None else additional_config + additional_config = {k: v for k, v in additional_config.items() if v is not None} + + pretrained = arg_dict.get("pretrained") + descriptor = resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=arg_dict.get("trust_remote_code", False) + ) + # The patcher must be active during HFLM.__init__ because that's where + # AutoModelForCausalLM.from_pretrained() is called internally. + with deci_x_patcher(model_descriptor=descriptor): + model_obj = cls(**arg_dict, **additional_config) + + return model_obj + + +def create_from_arg_string( + cls: type[T], arg_string: str, additional_config: dict | None = None +) -> T: + """Create an LM instance from a comma-separated argument string. + + Args: + arg_string: Arguments as ``"key1=value1,key2=value2"``. + additional_config: Extra configuration merged into the parsed args. + + Returns: + An instance of this LM subclass. 
+ """ + args = utils.simple_parse_args_string(arg_string) + additional_config = {} if additional_config is None else additional_config + args2 = {k: v for k, v in additional_config.items() if v is not None} + + pretrained = args.get("pretrained") + descriptor = resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=args.get("trust_remote_code", False) + ) + + # The patcher must be active during HFLM.__init__ because that's where + # AutoModelForCausalLM.from_pretrained() is called internally. + with deci_x_patcher(model_descriptor=descriptor): + model_obj = cls(**args, **args2) + + return model_obj + + +# Monkey-patch HFLM so lm-eval uses our patched model loading +HFLM.create_from_arg_obj = classmethod(create_from_arg_obj) +HFLM.create_from_arg_string = classmethod(create_from_arg_string) + + +if __name__ == "__main__": + cli_evaluate() diff --git a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md new file mode 100644 index 0000000000..f8b53889c6 --- /dev/null +++ b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md @@ -0,0 +1,70 @@ +# Evaluation with NeMo Evaluator (Alternative) + +> **Recommended approach:** Use lm-eval for direct evaluation without a +> deployment server. See the main [README](../README.md#evaluation) for details. + +Evaluate AnyModel checkpoints by deploying a local OpenAI-compatible completions endpoint and running benchmarks against it. + +This flow requires Ray for serving the model and NeMo Export-Deploy (included in NeMo containers): + +```bash +pip install -r examples/puzzletron/requirements.txt +``` + +**1. Deploy the model (2 GPUs example):** + +We need to patch the `hf_deployable.py` script from Export-Deploy. Best way is to do it as a mount in docker run: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! 
-d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -v ${MODELOPT_DIR}/examples/puzzletron/evaluation/hf_deployable_anymodel.py:/opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Alternatively, you can replace the file manually inside the container: + +```bash +# Install the AnyModel-patched deployable (first time only: backs up the original) +# /opt/Export-Deploy is the default path in NeMo containers — adjust if needed +cp /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py.bak +cp examples/puzzletron/evaluation/hf_deployable_anymodel.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py +``` + +Now start the Ray server and deploy the model: + +```bash +# Start the server (blocks while running — use a separate terminal) +ray start --head --num-gpus 2 --port 6379 --disable-usage-stats +python /opt/Export-Deploy/scripts/deploy/nlp/deploy_ray_hf.py \ + --model_path path/to/checkpoint \ + --model_id anymodel-hf \ + --num_gpus 2 --num_gpus_per_replica 2 --num_cpus_per_replica 16 \ + --trust_remote_code --port 8083 --device_map "auto" --cuda_visible_devices "0,1" +``` + +**2. Run MMLU:** + +```bash +eval-factory run_eval \ + --eval_type mmlu \ + --model_id anymodel-hf \ + --model_type completions \ + --model_url http://0.0.0.0:8083/v1/completions/ \ + --output_dir examples/puzzletron/evals/mmlu_anymodel +``` + +For a quick debug run, add `--overrides "config.params.limit_samples=5"`. 
diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py new file mode 100644 index 0000000000..5bb04818e5 --- /dev/null +++ b/examples/puzzletron/main.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146). + +This script provides three modes: +1. Default mode: Runs the full puzzletron pipeline +2. MIP-only mode: Runs only the MIP search and realize models phase +3. 
MIP sweep mode: Runs MIP for multiple memory compression rates (enabled via config) + +Usage: + # Full puzzletron pipeline + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml + + # Only MIP search and realize models phase + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only + + # MIP sweep mode (set mip.sweep.enabled: true in config) + torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only +""" + +import argparse +from datetime import timedelta +from pathlib import Path + +import modelopt.torch.nas as mtn +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.puzzletron.mip.sweep as sweep +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +from modelopt.torch.puzzletron.tools.hydra_utils import ( + initialize_hydra_config_for_dir, + register_hydra_resolvers, +) +from modelopt.torch.puzzletron.tools.logger import mprint + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Compress large language models using the Puzzletron algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)", + ) + parser.add_argument( + "--mip-only", + action="store_true", + help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)", + ) + + return parser.parse_args() + + +def run_full_puzzletron(hydra_config_path: str): + """Run the full puzzletron pipeline. 
+ + Args: + hydra_config_path: Path to the YAML configuration file + """ + mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") + dist.setup(timeout=timedelta(10)) # NOTE(review): timedelta(10) is 10 days (first positional arg is days); confirm intended + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Convert model (convert from HF to DeciLM, score pruning activations, + # prune the model and save pruned checkpoints) + input_model = PuzzletronModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(hydra_cfg.puzzle_dir), + "input_model_path": hydra_cfg.input_hf_model_path, + "hydra_config_dir": hydra_config_dir, + "hydra_config_name": hydra_config_name, + "dataset_path": str(hydra_cfg.dataset_path), + }, + ) + ], + ) + + # Run NAS search (build replacement library and compute stats, + # compute one block scores, run MIP and realize models) + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + dist.cleanup() + mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + + +def run_mip_only(hydra_config_path: str): + """Run only the MIP search and realize models phase. + + This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation + have already been completed. 
+ + Args: + hydra_config_path: Path to the YAML configuration file + """ + dist.setup(timeout=timedelta(10)) # NOTE(review): timedelta(10) is 10 days (first positional arg is days); confirm intended + + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + hydra_config_path = Path(hydra_config_path).resolve() + hydra_config_dir = str(hydra_config_path.parent) + hydra_config_name = hydra_config_path.stem + + # Load hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[], + ) + + # Check if sweep mode is enabled + if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): + mprint( + "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + ) + sweep.run_mip_sweep(hydra_cfg) + else: + # mip_and_realize_models (distributed processing) + # TODO: How to make it part of mtn.search() api, similarly to run_full_puzzletron() API + mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + + dist.cleanup() + mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + + +def main(): + args = parse_args() + + if args.mip_only: + run_mip_only(hydra_config_path=args.config) + else: + run_full_puzzletron(hydra_config_path=args.config) + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md new file mode 100644 index 0000000000..f7dda866e8 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -0,0 +1,152 @@ +# Knowledge Distillation with Megatron-Bridge + +This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge. + +## Overview + +1. Set up the environment with Megatron-Bridge +2. Prepare tokenized dataset +3. Run knowledge distillation training directly from HuggingFace checkpoints +4. 
Review MMLU evaluation results (before/after distillation) + +## Setup + +**Clone Model-Optimizer repo:** + +The NeMo container does not include Model-Optimizer examples, so you need to clone the Model-Optimizer repo: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer +git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +``` + +**Start Docker container:** + +Use the [NeMo 26.02.01 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02.01): + +```bash +# Recommended to mount a workspace directory for storing datasets and distilled models +docker run --gpus all -it --rm \ + -v /path/to/your/project:/workspace \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -w /opt/Model-Optimizer \ + nvcr.io/nvidia/nemo:26.02.01 \ + /bin/bash +``` + +## Dataset Preparation + +This section describes how to prepare datasets for knowledge distillation. We provide examples using WikiText-103, which is a small dataset that can still produce decent results (see the Qwen3-8B example below showing +10.11 percentage point improvement). For production use, larger datasets like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) are recommended. + +### Download and Tokenize Dataset + +Download and tokenize the dataset in a single step. 
This downloads the dataset from HuggingFace, tokenizes it, and saves it in the Megatron format (`.bin` and `.idx` files): + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset Salesforce/wikitext \ + --hf_name wikitext-103-v1 \ + --hf_split train \ + --output_dir path/to/hf_datasets/wikitext-103-v1 \ + --tokenizer meta-llama/Llama-3.1-8B-Instruct \ + --json_keys text \ + --workers 32 +``` + +This will create: + +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.bin` - Binary tokenized data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.idx` - Index file for the binary data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document/cache/` - Cache directory (created after running distillation) + +## Run Knowledge Distillation + +Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \ + --student_hf_path /path/to/student/huggingface/checkpoint \ + --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ + --data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \ + --output_dir /path/to/distilled/checkpoint \ + --hf-export-path /path/to/exported/hf/model \ + --hf-model meta-llama/Llama-3.1-8B-Instruct \ + --seq_length 4096 \ + --tp_size 8 \ + --pp_size 1 \ + --mbs 1 \ + --gbs 4 \ + --train_iters 100 \ + --lr 0.0001 \ + --min_lr 1e-05 \ + --lr_warmup_iters 10 \ + --eval_interval 10 \ + --eval_iters 10 \ + --log_interval 1 +``` + +**Notes:** + +- Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code. +- The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_XXXXXXX`, where `XXXXXXX` is the zero-padded iteration number (e.g., `iter_0000100`). +- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. 
When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). +- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices. + +## MMLU Evaluation Results + +This section presents MMLU evaluation results for knowledge distillation experiments compressing Qwen3-8B and Llama-3.1-8B-Instruct. + +### Successful Case: Qwen3-8B (80% of original) + +Distillation results for a memory-compressed Qwen3-8B checkpoint (80% of original size): + +| Model | MMLU | Humanities | Other | Social Sci | STEM | +|-------|------|------------|-------|------------|------| +| 80% pre-distillation | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 | +| 80% post-distillation | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 | +| Original Qwen3-8B | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 | + +**Key observations:** + +- MMLU accuracy improved from 59.10% to 69.21% (+10.11 percentage points) after distillation +- Achieved with just 100 iterations on WikiText-103, demonstrating efficient knowledge transfer +- Recovery of 64% of the gap to the teacher model (from 59.10% to 69.21%, closing 64% of the gap from 59.10% to 74.93%) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved significantly + +### Successful Case: Llama-3.1-8B-Instruct (50% of original, 56,810 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (50% of original size, 56,810 MiB memory constraint): + +| Model | MMLU | 
Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 | +| After distillation | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) improved from 23.16% to 29.60% (+6.44 percentage points) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved, demonstrating effective knowledge transfer from teacher to student + +### Regression Case: Llama-3.1-8B-Instruct (69% of original, 78,000 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (approximately 69% of original size, 78,000 MiB memory constraint) showing regression due to overfitting on the small WikiText-103 dataset (evaluated with limit 100): + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 | +| After distillation | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) decreased from 66.26% to 64.96% (-1.30 percentage points) after distillation +- The model overfitted to the small WikiText-103 dataset, causing performance regression +- This demonstrates the critical importance of using larger, more diverse datasets for knowledge distillation + +### Recommendations + +- **For production distillation:** Use larger production datasets like [nvidia/Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) for better results and to avoid overfitting (see regression case above) +- **Training duration:** Train for more iterations to ensure proper convergence +- **See 
the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices** diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py new file mode 100644 index 0000000000..ac703909c2 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_hf.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to `/checkpoints` in megatron distributed checkpoint format. + +See `README.md` in this directory for example usage and data preparation instructions. 
+""" + +import argparse +import os +import traceback + +import megatron.bridge.models.distillation_provider +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure +# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers +# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge +# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration. +# +# Note: Currently, bridges are also registered when distillation_provider is imported +# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider +# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron. 
def get_args():
    """Parse and validate command-line arguments for the distillation script.

    ``--hf_model`` is only used as the template for the optional HuggingFace export, so it is
    enforced only when ``--hf_export_path`` is given instead of on every run.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: If neither ``--data_paths`` nor ``--use_mock_data`` is supplied, or if
            ``--hf_export_path`` is supplied without ``--hf_model``.
    """
    parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
    # Model arguments (accepts HuggingFace input only at the moment)
    parser.add_argument(
        "--student_hf_path",
        type=str,
        required=True,
        help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
    )
    parser.add_argument(
        "--teacher_hf_path",
        type=str,
        required=True,
        help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
    )
    parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code")
    # Parallelism arguments
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    # Dataset arguments
    parser.add_argument(
        "--data_paths",
        nargs="+",
        help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)",
    )
    parser.add_argument(
        "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data"
    )
    parser.add_argument(
        "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices"
    )
    parser.add_argument(
        "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths"
    )
    # Training & Eval arguments
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=4096,
        help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.",
    )
    parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size")
    parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size")
    parser.add_argument(
        "--train_iters", type=int, required=True, help="Number of training iterations"
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate")
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate")
    parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps")
    parser.add_argument(
        "--eval_interval", type=int, default=100, help="Validate + checkpoint every N steps"
    )
    parser.add_argument(
        "--eval_iters", type=int, default=32, help="Number of batches per validation stage"
    )
    # Logging arguments
    parser.add_argument("--log_interval", type=int, default=10, help="Write to log every N steps")
    parser.add_argument(
        "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)"
    )
    parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
    parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
    # Export arguments
    parser.add_argument(
        "--hf_export_path",
        "--hf-export-path",
        type=str,
        default=None,
        help=(
            "Path where to save the HuggingFace export. "
            "If provided, exports checkpoint to HF format after distillation."
        ),
    )
    # Not required up front: the template model only matters when an export is requested,
    # which is checked explicitly after parsing.
    parser.add_argument(
        "--hf_model",
        "--hf-model",
        type=str,
        default=None,
        help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
        "Should match the base architecture of the student model. "
        "Required when --hf_export_path is set.",
    )
    args = parser.parse_args()

    # Sanity checks
    if not args.use_mock_data and not args.data_paths:
        raise ValueError("Must provide either --data_paths or set --use_mock_data.")
    if args.hf_export_path and not args.hf_model:
        raise ValueError("Must provide --hf_model when --hf_export_path is set.")

    print_rank_0("\n==================== Arguments ====================")
    for k, v in vars(args).items():
        print_rank_0(f"{k:<35} {v}")
    print_rank_0("===================================================\n")

    return args
def main(args: argparse.Namespace):
    """Distill the student from the teacher and optionally export the result to HuggingFace.

    Checkpoints go to ``<output_dir>/checkpoints`` and TensorBoard logs to
    ``<output_dir>/tb_logs``. When ``args.hf_export_path`` is set, the process group is torn
    down after training and rank 0 alone performs the export; an export failure is reported
    but not re-raised.

    Args:
        args: Parsed command-line options as produced by ``get_args()``.
    """
    checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
    tb_log_dir = os.path.join(args.output_dir, "tb_logs")

    def _load_provider(hf_path):
        """Create a Megatron provider from an HF checkpoint and apply the CLI overrides."""
        provider = AutoBridge.from_hf_pretrained(
            hf_path, trust_remote_code=args.trust_remote_code
        ).to_megatron_provider(load_weights=True)

        # Override parallelism / training settings
        provider.tensor_model_parallel_size = args.tp_size
        provider.pipeline_model_parallel_size = args.pp_size
        provider.context_parallel_size = 1
        provider.sequence_parallel = args.tp_size > 1
        provider.seq_length = args.seq_length
        provider.pipeline_dtype = torch.bfloat16
        return provider

    # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000)
    # Still requires an HF model name or path to build provider correctly
    student = _load_provider(args.student_hf_path)
    teacher = _load_provider(args.teacher_hf_path)

    # Student + teacher are combined into a single distillation provider
    distill_provider = convert_to_distillation_provider(student, teacher, ModelOptDistillConfig())

    # Optimizer and LR schedule
    optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=args.lr_warmup_iters,
        max_lr=args.lr,
        min_lr=args.min_lr,
        adam_beta2=0.98,
    )

    # Dataset settings shared by the mock and the real dataset configs
    shared_dataset_kwargs = dict(
        seq_length=args.seq_length,
        path_to_cache=args.data_path_to_cache,
        random_seed=SEED,
        reset_attention_mask=False,
        reset_position_ids=False,
        eod_mask_loss=False,
        num_dataset_builder_threads=1,
        data_sharding=True,
        dataloader_type="single",
        skip_getting_attention_mask_from_dataset=True,
    )
    if args.use_mock_data:
        dataset_config = MockGPTDatasetConfig(**shared_dataset_kwargs)
    else:
        # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format
        dataset_config = GPTDatasetConfig(
            blend=get_blend_from_list(args.data_paths),
            split=args.split,
            **shared_dataset_kwargs,
        )

    # TODO: Replace validation args in train with validation config in nemo:26.04
    # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters),
    train_config = TrainingConfig(
        train_iters=args.train_iters,
        eval_interval=args.eval_interval,
        eval_iters=args.eval_iters,
        global_batch_size=args.gbs,
        micro_batch_size=args.mbs,
        manual_gc=True,
        manual_gc_interval=100,
    )
    ddp_config = DistributedDataParallelConfig(
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
        overlap_grad_reduce=True,
        overlap_param_gather=True,
        average_in_collective=True,
        use_distributed_optimizer=True,
    )
    logger_config = LoggerConfig(
        log_interval=args.log_interval,
        tensorboard_dir=tb_log_dir,
        log_timers_to_tensorboard=True,
        # Weights & Biases logging
        wandb_project=args.wandb_project,
        wandb_entity=args.wandb_entity,  # optional
        wandb_exp_name=args.wandb_exp_name,
    )
    tokenizer_config = TokenizerConfig(
        tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size
    )
    checkpoint_config = CheckpointConfig(
        save_interval=args.eval_interval,
        save=checkpoint_dir,
        load=checkpoint_dir,  # Resume from this directory (if exists)
        most_recent_k=3,  # Keeps 3 most recent checkpoints (not metric-based)
        ckpt_format="torch_dist",
        async_save=True,
        fully_parallel_save=True,
    )

    # Assemble ConfigContainer and run distillation
    config = ConfigContainer(
        model=distill_provider,
        train=train_config,
        optimizer=optimizer_config,
        scheduler=scheduler_config,
        ddp=ddp_config,
        dataset=dataset_config,
        logger=logger_config,
        tokenizer=tokenizer_config,
        checkpoint=checkpoint_config,
        rng=RNGConfig(seed=SEED),
        mixed_precision="bf16_mixed",
    )

    print_rank_0("\nStarting distillation...")
    distill(config)
    print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n")

    # Nothing more to do unless an HF export was requested
    if not args.hf_export_path:
        return

    # Wait for all ranks to finish distillation before export
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Save rank before destroying process group (dist.rank() won't work after destruction)
    is_rank_0 = dist.rank() == 0

    # Destroy process group on all ranks - export_ckpt will create its own temporary one
    # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone)
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    # Only rank 0 exports
    if not is_rank_0:
        return
    try:
        export_to_hf_and_copy_config(
            student_hf_path=args.student_hf_path,
            checkpoint_dir=checkpoint_dir,
            train_iters=args.train_iters,
            hf_export_path=args.hf_export_path,
            hf_model=args.hf_model,
            trust_remote_code=args.trust_remote_code,
        )
    except Exception as e:
        # Best-effort export: report the failure without failing the finished training run
        print(f"⚠️ Export failed: {e}")
        traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: set up the distributed environment, run the distillation,
    # and always tear the environment down again.
    dist.setup()
    try:
        # Parse inside the try block so that an argparse exit (SystemExit) or a failed
        # sanity check still reaches dist.cleanup() in the finally clause.
        args = get_args()
        main(args)
    except Exception as e:
        print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}")
        print_rank_0(f"Traceback:\n{traceback.format_exc()}")
        raise
    finally:
        dist.cleanup()
b/modelopt/torch/prune/importance_hooks/__init__.py @@ -18,6 +18,7 @@ from .base_hooks import * from .base_hooks_analysis import * +from .expert_removal_hooks import * with import_plugin("megatron_hooks"): from .plugins.megatron_hooks import * diff --git a/modelopt/torch/prune/importance_hooks/base_hooks.py b/modelopt/torch/prune/importance_hooks/base_hooks.py index 248e6ec108..44eea3bdbe 100644 --- a/modelopt/torch/prune/importance_hooks/base_hooks.py +++ b/modelopt/torch/prune/importance_hooks/base_hooks.py @@ -149,7 +149,8 @@ def dump_activations_logs( torch.save(activations_log, activations_log_path) if rank == 0: - args.activation_hooks_kwargs.pop("model") + if args.activation_hooks_kwargs is not None: + args.activation_hooks_kwargs.pop("model", None) json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json") dist.barrier() @@ -565,9 +566,9 @@ def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): assert self.optimize_for in ["latency", "memory"] self.hidden_size = model_config.hidden_size - self.n_heads_in_group = block_config.attention.n_heads_in_group self.num_q_heads = model_config.num_attention_heads - self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.num_kv_heads = block_config.attention.num_key_value_heads + self.n_heads_in_group = self.num_q_heads // self.num_kv_heads self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) self.agg_kv_head_contributions = torch.zeros( @@ -734,7 +735,7 @@ def _save_channel_importance_results( all_scores = [] for activation_file in activation_files: print(f"Loading activations from {activation_file}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are generated by dump_activations_logs() in this module and contain # hook state dictionaries. 
The activations_log_dir should only contain trusted files # generated by the same codebase, not from untrusted sources. diff --git a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py index e692a518ae..dbb4f564d7 100644 --- a/modelopt/torch/prune/importance_hooks/compare_module_outputs.py +++ b/modelopt/torch/prune/importance_hooks/compare_module_outputs.py @@ -52,7 +52,8 @@ python compare_module_outputs.py \ --reference output_unpruned.pt \ --compare output_l2norm.pt \ - --output-json comparison_stats.json + --output-json comparison_stats.json \ + --trust-inputs The saved file format\: @@ -180,21 +181,26 @@ def main(): default=None, help="Path to save comparison statistics as JSON", ) + parser.add_argument( + "--trust-inputs", + action="store_true", + help="Trust input files for loading with weights_only=False in torch.load()", + ) args = parser.parse_args() # Load reference data print(f"\nLoading reference: {args.reference}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. - ref_data = torch.load(args.reference, map_location="cpu", weights_only=False) + ref_data = torch.load(args.reference, map_location="cpu", weights_only=not args.trust_inputs) # Load comparison data print(f"Loading compare: {args.compare}") - # SECURITY: weights_only=False is required because files contain dictionaries with tensors. + # Security NOTE: weights_only=False is required because files contain dictionaries with tensors. # These files are expected to be generated by save_multi_layer_outputs() in this module, # not from untrusted sources. Users should only load files they generated themselves. 
class RemoveExpertsIndependentHook(ForwardHook, ABC):
    """Base hook for measuring expert importance in Mixture-of-Experts models.

    For every expert, the hook re-runs the routed experts on the tokens that were routed to
    that expert with the expert disabled, and accumulates how far the output moves
    (MSE and negative cosine similarity), weighted by the fraction of tokens affected.
    Subclasses supply the model-specific routing in ``get_router_logits_and_routed_experts``.
    """

    def __init__(self, moe: nn.Module, activation_hooks_kwargs: dict):
        """Initialize the hook.

        Args:
            moe: The MoE module to analyze
            activation_hooks_kwargs: Configuration dict containing block_config
        """
        self.moe = moe
        block_config: BlockConfig = activation_hooks_kwargs["block_config"]
        self.num_local_experts = block_config.ffn.moe.num_local_experts
        self.num_experts_per_tok = block_config.ffn.moe.num_experts_per_tok
        # One float32 accumulator per metric, one slot per expert, kept on the device of
        # the MoE parameters so per-batch updates need no transfer.
        device = next(self.moe.parameters()).device
        self.diffs = {
            metric: torch.zeros(
                size=(self.num_local_experts,), dtype=torch.float32, device=device
            )
            for metric in ("mse", "cosine")
        }
        self.call_count = 0

    @abstractmethod
    def get_router_logits_and_routed_experts(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract router logits and expert outputs for measuring expert importance.

        This method is called twice per forward pass:
        1. First call (router_logits=None): Compute original routing and expert outputs
        2. Second call (router_logits provided): Re-run with modified logits (expert disabled)

        Args:
            hidden_states: Input tensor of shape (batch, seq_len, hidden_dim)
            router_logits: Optional pre-computed router logits. If None, compute from hidden_states.

        Returns:
            tuple of (router_logits, routed_experts):
                - router_logits: Shape (num_tokens, num_local_experts)
                - routed_experts: Shape (num_tokens, hidden_dim)
        """
        raise NotImplementedError

    def __call__(
        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
    ) -> None:
        """Forward hook that measures expert importance.

        Args:
            module: The hooked module (unused; the MoE passed to ``__init__`` is used).
            args: Positional inputs of the hooked forward; ``args[0]`` is the hidden states.
            output: Output of the hooked forward (unused).
        """
        hidden_states = args[0]
        router_logits, original_routed_out = self.get_router_logits_and_routed_experts(
            hidden_states
        )

        # Flatten to one row per token so boolean token masks can index both tensors
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        original_routed_out = original_routed_out.view(-1, original_routed_out.shape[-1])

        _, router_indices = torch.topk(router_logits, self.num_experts_per_tok, dim=-1)
        self.call_count += 1

        for i_expert in range(self.num_local_experts):
            expert_mask = router_indices == i_expert
            is_token_routed_to_this_expert = expert_mask.any(dim=-1)

            num_tokens_displaced = is_token_routed_to_this_expert.sum()
            if num_tokens_displaced == 0:
                # No token used this expert in this batch: removing it changes nothing
                continue
            num_total_tokens = is_token_routed_to_this_expert.numel()

            relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :]

            router_logits_without_i = router_logits.clone()
            router_logits_without_i[..., i_expert] = -float("inf")  # disable expert i
            router_logits_without_i = router_logits_without_i[is_token_routed_to_this_expert, :]
            _, routed_out_without_i = self.get_router_logits_and_routed_experts(
                relevant_hidden_states, router_logits_without_i
            )

            relevant_tokens_original_out = original_routed_out[is_token_routed_to_this_expert, :]
            # Scale each metric by the share of tokens that were actually displaced
            self.diffs["mse"][i_expert] += (
                nn.functional.mse_loss(
                    relevant_tokens_original_out, routed_out_without_i, reduction="mean"
                )
                * num_tokens_displaced
                / num_total_tokens
            )
            self.diffs["cosine"][i_expert] += (
                -nn.functional.cosine_similarity(
                    relevant_tokens_original_out, routed_out_without_i, dim=-1
                ).mean()
                * num_tokens_displaced
                / num_total_tokens
            )

    def to_dict(self) -> dict[str, torch.Tensor]:
        """Convert accumulated statistics to dict format.

        Returns:
            Dict with the ascending argsort of each accumulated metric
            (``expert_ranks_mse``, ``expert_ranks_cosine``) and the per-call averages
            (``mse_diffs``, ``cosine_diffs``), all on CPU.
        """
        expert_ranks_mse = torch.argsort(self.diffs["mse"])
        expert_ranks_cosine = torch.argsort(self.diffs["cosine"])
        # Guard the average: a hook that never fired would otherwise emit 0/0 = NaN scores
        num_calls = max(self.call_count, 1)
        return {
            "expert_ranks_mse": expert_ranks_mse.cpu(),
            "expert_ranks_cosine": expert_ranks_cosine.cpu(),
            "cosine_diffs": (self.diffs["cosine"] / num_calls).cpu(),
            "mse_diffs": (self.diffs["mse"] / num_calls).cpu(),
        }

    def accumulate(self) -> torch.Tensor:
        """Return the accumulated (un-averaged) per-expert MSE scores."""
        return self.diffs["mse"]

    def state_dict(self) -> dict:
        """Return the internal state for checkpointing."""
        return {
            "diffs_mse": self.diffs["mse"].cpu(),
            "diffs_cosine": self.diffs["cosine"].cpu(),
            "call_count": self.call_count,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the internal state from a checkpoint.

        Restored tensors are moved to the device of the current accumulators.
        """
        self.diffs["mse"] = state_dict["diffs_mse"].to(self.diffs["mse"].device)
        self.diffs["cosine"] = state_dict["diffs_cosine"].to(self.diffs["cosine"].device)
        self.call_count = state_dict["call_count"]
class NemotronHRemoveExpertsIndependentHook(RemoveExpertsIndependentHook):
    """Expert removal importance hook for NemotronH models."""

    def get_router_logits_and_routed_experts(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract router logits and expert outputs for NemotronH MoE.

        Based on NemotronHMOE forward, uses minimum ops to get router_logits and routed_experts.

        Args:
            hidden_states: Input activations; flattened to one row per token internally.
            router_logits: Scores to route with. When ``None`` they are computed from
                ``hidden_states`` with the gate weight (sigmoid, then the correction bias added).

        Returns:
            Tuple of the router scores used and the routed-expert output, the latter
            viewed back to the shape of ``hidden_states``.
        """
        gate = self.moe.gate
        input_shape = hidden_states.shape
        flat_states = hidden_states.view(-1, input_shape[-1])

        # NemotronHMOE.gate forward, copied to extract router_logits
        if router_logits is None:
            raw_logits = nn.functional.linear(
                flat_states.type(torch.float32), gate.weight.type(torch.float32)
            )
            router_logits = raw_logits.sigmoid()
            router_logits = router_logits + gate.e_score_correction_bias.unsqueeze(0)

        topk_indices = self._get_topk_indices_without_correction_bias(router_logits)
        topk_weights = router_logits.gather(1, topk_indices)
        if gate.norm_topk_prob:
            # Small epsilon keeps the normalization finite when the selected weights sum to zero
            weight_sum = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= weight_sum
        topk_weights = topk_weights * gate.routed_scaling_factor

        # Routed experts forward
        routed_out = self.moe.moe(flat_states, topk_indices, topk_weights).view(*input_shape)
        return router_logits, routed_out

    @torch.no_grad()
    def _get_topk_indices_without_correction_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Get topk indices without correction bias.

        Same as NemotronHMOE.gate.get_topk_indices but without adding e_score_correction_bias.

        Args:
            scores: Per-token expert scores with ``n_routed_experts`` columns.

        Returns:
            Indices of the ``top_k`` experts chosen for each token among the selected groups.
        """
        gate = self.moe.gate
        num_groups = gate.n_group
        experts_per_group = gate.n_routed_experts // num_groups

        # Score each group by the sum of its two best experts
        grouped_scores = scores.view(-1, num_groups, experts_per_group)
        group_scores = grouped_scores.topk(2, dim=-1)[0].sum(dim=-1)

        # Keep only the best `topk_group` groups per token
        selected_groups = torch.topk(group_scores, k=gate.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, selected_groups, 1)

        # Broadcast the group mask back to one flag per expert
        score_mask = group_mask.unsqueeze(-1)
        score_mask = score_mask.expand(-1, num_groups, experts_per_group)
        score_mask = score_mask.reshape(-1, gate.n_routed_experts)

        scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=gate.top_k, dim=-1, sorted=False)[1]
        return topk_indices
+ + Args: + module: The router module + args: Tuple with one tensor entry (B, T, I) + output: Router logits of shape (B, T, E) + """ + router_logits = output[0] if isinstance(output, tuple) else output + num_experts = router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to dict format using ranked choice voting.""" + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + print(f"#{rank}: router_argsort shape = {router_argsort.shape}") + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts_val in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts_val + + # Compute zero-shot expert ranks (double argsort converts counts to rank positions) + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + print("Done: Returning hook metadata.") + return { + "expert_ranks": 
expert_ranks, + "zero_shot_expert_ranks": zero_shot_expert_ranks, + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, + } + + def accumulate(self) -> torch.Tensor: + """Return accumulated expert ranks.""" + if not self.router_argsort: + return torch.tensor([]) + router_argsort = torch.concat(self.router_argsort, dim=0) + return router_argsort[:, 0].float() + + def state_dict(self) -> dict: + """Return the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RankedChoiceVotingHookNemotronH(RankedChoiceVotingHook): + """Ranked choice voting hook for NemotronH models. + + In NemotronH, router_logits is an internal temporary state that never leaves + the forward() function. We reconstruct router_logits from the input hidden_states. 
class Qwen3VLRemoveExpertsIndependentHook(RemoveExpertsIndependentHook):
    """Importance hook for removing experts from Qwen3-VL MoE layers."""

    def get_router_logits_and_routed_experts(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute router logits and the routed-expert output for Qwen3-VL MoE.

        Based on Qwen3VLMoeSparseMoe forward pass.

        Args:
            hidden_states: Input to the MoE block.
            router_logits: Optional precomputed logits; when ``None`` they are
                obtained from ``self.moe.gate``.

        Returns:
            ``(router_logits, routed_out)`` with ``routed_out`` reshaped to the
            shape of ``hidden_states``.
        """
        input_shape = hidden_states.shape
        hidden_size = self.moe.hidden_size

        # Work on a (num_tokens, hidden_size) view of the input.
        flat_states = hidden_states.reshape(-1, hidden_size)

        if router_logits is None:
            router_logits = self.moe.gate(flat_states)

        # Softmax in float32, keep the top-k experts and renormalise their weights.
        probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
        topk_probs, router_indices = probs.topk(self.moe.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        topk_probs = topk_probs.to(flat_states.dtype)

        # Dense (num_tokens, num_experts) weight matrix, zero outside the top-k.
        router_weights = torch.zeros_like(router_logits)
        router_weights.scatter_(1, router_indices, topk_probs)

        # moe.experts is fed a 3D tensor; non-3D inputs are treated as a single batch.
        if hidden_states.ndim == 3:
            batch_size = input_shape[0]
        else:
            batch_size = 1
        states_3d = flat_states.reshape(batch_size, -1, hidden_size)

        expert_out = self.moe.experts(states_3d, router_weights, router_indices)

        # Hand the result back in the caller's original layout.
        return router_logits, expert_out.reshape(*input_shape)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py new file mode 100644 index 0000000000..ccf73f7612 --- /dev/null +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Provides a function to register activation hooks for a model. 
def register_activation_hooks(
    model,
    activation_hooks_kwargs: dict,
    pruning_mixin,
    hook_class: Type[ActivationsHook],
) -> dict[str, ActivationsHook]:
    """Register activation hooks using the pruning mixin approach.

    Args:
        model: The model to register hooks on.
        activation_hooks_kwargs: Keyword arguments passed to hook constructors.
            The mapping is not modified; every hook receives a copy extended
            with ``"model"`` and ``"block_config"``.
        pruning_mixin: The pruning mixin that defines which modules to hook.
        hook_class: The hook class to instantiate for each module.

    Returns:
        Dictionary mapping module names to hook instances.

    Raises:
        ValueError: If ``hook_class`` is not supported by ``pruning_mixin``, or
            if no hook could be registered outside of an initialized
            distributed run.
    """
    supported_hooks = pruning_mixin.supported_hooks()
    if hook_class not in supported_hooks:
        raise ValueError(
            f"Hook class not supported for {pruning_mixin.__class__.__name__}, "
            f"must be in {supported_hooks}"
        )

    # Copy instead of writing into the caller's dict, so the caller's
    # configuration does not end up holding a reference to the model.
    base_hooks_kwargs = {**activation_hooks_kwargs, "model": model}

    module_names_to_hook = pruning_mixin.get_module_names_to_hook(model)
    activation_hooks = {}
    for block_idx, module_name in module_names_to_hook:
        try:
            module = model.get_submodule(module_name)
        except AttributeError:
            # Module doesn't exist on this rank's shard (e.g., in distributed setup)
            continue

        # Skip dummy modules - they don't have real activations to hook
        if isinstance(module, (DummyModule, DummyBlock)):
            continue

        block_config = None
        if block_idx is not None:
            block_config = model.config.block_configs[block_idx]
        curr_activation_hooks_kwargs = {
            **base_hooks_kwargs,
            "block_config": block_config,
        }

        hook = hook_class(module, curr_activation_hooks_kwargs)
        module.register_forward_hook(hook)
        activation_hooks[module_name] = hook

    if activation_hooks:
        aprint(f"Found the following hooks: {activation_hooks.keys()}")
        return activation_hooks

    # In distributed mode, it's okay for a rank to have 0 hooks if it doesn't own
    # the target modules (e.g., with hybrid patterns like "*-" where different
    # ranks own different layer types). However, we still want to catch real bugs
    # where no hooks are found at all.
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not is_distributed:
        raise ValueError("couldn't find any hooks")

    aprint(
        "No hooks registered on this rank. This is expected if this rank "
        "doesn't own any layers matching the hook pattern (e.g., in hybrid "
        "patterns with distributed model sharding)."
    )
    return activation_hooks
def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
    """Tell whether the configured activation hook method supports checkpointing.

    Args:
        activation_hooks_kwargs: Hook configuration; only its ``"method"`` entry is read.

    Returns:
        bool: True if ``"method"`` names a hook method with save_state/load_state
            implemented, False otherwise (including when the entry is absent).
    """
    # Hook methods whose save_state/load_state is implemented.
    checkpointable_methods = frozenset(
        (
            "iterative",  # IterativeChannelContributionHook
            "independent",  # IndependentChannelContributionHook
            "stats",  # RouterStatsHook
            "ranked_choice_voting",  # RankedChoiceVotingHook
        )
    )

    configured_method = activation_hooks_kwargs.get("method", "")
    if configured_method in checkpointable_methods:
        return True
    return False
def should_skip_scoring_completely(cfg: DictConfig) -> bool:
    """Decide whether activation scoring can be skipped because it already finished.

    Scoring is skipped only when it is 100% complete; partial progress must go
    through validate_model so it can resume properly.

    Args:
        cfg: Configuration object; its ``pruning`` section is inspected.

    Returns:
        bool: True if scoring is fully completed and can be skipped, False if it
            has to be run or resumed.
    """
    pruning_cfg = cfg.pruning

    # Without a log directory there is nothing to look for.
    log_dir = getattr(pruning_cfg, "activations_log_dir", None)
    if log_dir is None:
        mprint("No activations_log_dir specified, running scoring")
        return False

    # An explicit restart request overrides any existing artifacts.
    if getattr(pruning_cfg, "force_restart_scoring", False):
        mprint("Force restart flag set, will restart scoring regardless of existing artifacts")
        return False

    hooks_kwargs = getattr(pruning_cfg, "activation_hooks_kwargs", {})
    done = check_scoring_completion(log_dir, hooks_kwargs)

    # In distributed mode every process adopts the value held by rank 0.
    if dist.size() > 1:
        payload = [done]  # broadcast_object_list needs a mutable container
        torch.distributed.broadcast_object_list(payload, src=0)
        done = payload[0]

    if done:
        mprint("Scoring 100% completed, skipping...")

    return done
pipeline. + +## Convert model + +Convert a HuggingFace model to Puzzletron format. + +Step 1: Create Model Descriptor + +Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention). + +Key points: + +- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)) + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +Step 2: Create Converter + +Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config. + +Key points: + +- Import the correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models/<model_type>/configuration_<model_type>.py` + +See example: [llama_converter.py](models/llama/llama_converter.py) + +Step 3: Create `models/<model_name>/__init__.py` + +Export descriptor and converter classes: + +```python +from models.<model_name>.<model_name>_model_descriptor import MyModelDescriptor +from models.<model_name>.<model_name>_converter import MyConverter +``` + +Step 4: Register in `models/__init__.py` + +Add import to trigger factory registration: + +```python +from models.<model_name> import * +``` + +## Usage + +```python +from modelopt.torch.puzzletron.anymodel import convert_model + +convert_model( + input_dir="path/to/hf_checkpoint", + output_dir="path/to/puzzletron_checkpoint", + converter="model_name", +) +``` + +## Compress model + +Run pruning and compression on a Puzzletron model. 
+ +Step 1: Implement ModelDescriptor methods for compression + +Add to your `ModelDescriptor`: + +- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support +- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides)) +- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding)) +- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods)) +- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods)) +- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods)) +- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods)) +- `attn_no_op_post_init()` - replace attention sublayers with no-op modules +- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules + +Step 2: Create FFN Layer Descriptor + +Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`. 
+ +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor` + +Step 3: Configure YAML files + +Update the main model config YAML: + +- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")` +- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml) + +Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.): + +- Set `pruning_mixin._target_` to the appropriate mixin class +- Set `layer_descriptor._target_` to your layer descriptor class +- Set `hook_class` to the activation hook for scoring +- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment +- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/) + +## End-to-end example + +See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps. + +--- + +## Advanced Topics + +## Pruning Configuration + +### Pruning YAML Structure + +Each pruning type has a YAML config with these key fields: + +```yaml +pruning_mixin: + _target_: pruning.<pruning_type>_pruning_mixin.<MixinClass> + layer_descriptor: + _target_: models.<model_name>.<LayerDescriptorClass> + +hook_class: ${get_object:utils.activation_hooks.hooks.<HookClass>} +activation_hooks_kwargs: + method: <method_name> + target_layer: "<layer_name>" # e.g., "mlp.down_proj", "self_attn.o_proj" +``` + +| Field | Description | +|-------|-------------| +| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type | +| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks | +| `hook_class` | Activation hook class for importance scoring | +| `target_layer` | Layer name (relative to decoder block) where hooks attach | + +### Adding a New Hook Class + +1. 
**Implement the hook** under `modelopt/torch/prune/importance_hooks/` (e.g. `base_hooks.py` for generic hooks, `expert_removal_hooks.py` for MoE expert removal): + - Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook` in `expert_removal_hooks.py`) + - Implement required methods (e.g., `get_router_logits_and_routed_experts`) + +2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`: + + For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook] + ``` + + For expert removal (`pruning/expert_removal_pruning_mixin.py`): + + ```python + def supported_hooks(self) -> List[Type[ActivationsHook]]: + return [RankedChoiceVotingHook, ..., YourNewHook] + ``` + +3. **Reference in YAML**: + + ```yaml + hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook} + ``` + +### Pruning Types Reference + +| Type | Mixin | Example Hooks | +|------|-------|---------------| +| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../prune/importance_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../prune/importance_hooks/base_hooks.py) | +| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../prune/importance_hooks/expert_removal_hooks.py) | +| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../prune/importance_hooks/base_hooks.py) | + +## Implementing `block_config_to_layer_overrides` + +Maps Puzzletron's [`BlockConfig`](../decilm/deci_lm_hf_code/block_config.py) fields to HuggingFace config attribute names. 
Only override attributes that change during pruning: + +| BlockConfig Field | HuggingFace Attribute (check `config.json`) | +|-------------------|---------------------------------------------| +| `attention.num_key_value_heads` | `num_key_value_heads` | +| `ffn.intermediate_size` | `intermediate_size` | +| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) | +| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` | + +**Tip**: Check the model's `config.json` for exact attribute names - they vary between models. + +See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py) + +--- + +## Implementing path-based methods + +These methods return paths derived from the model's weight names: + +- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()` + +Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)). + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) + +--- + +## Implementing `init_rotary_embedding` + +Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype. 
+ +Look in `github.com/huggingface/transformers/tree/main/src/transformers/models//modeling_.py` for: + +- `class.*Rotary` — the rotary embedding class name and constructor arguments +- `self.rotary_emb` — the attribute path + +See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) diff --git a/modelopt/torch/puzzletron/anymodel/__init__.py b/modelopt/torch/puzzletron/anymodel/__init__.py new file mode 100644 index 0000000000..e1755a16d8 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""AnyModel: Architecture-agnostic model compression for HuggingFace models. + +This module provides a declarative approach to model compression that works with +any HuggingFace model without requiring custom modeling code. Instead of duplicating +HuggingFace modeling classes, AnyModel uses ModelDescriptors that define: + +1. Which decoder layer class(es) to patch for heterogeneous configs +2. How to map BlockConfig to layer-specific overrides +3. Weight name patterns for subblock checkpointing + +Example usage: + >>> from modelopt.torch.puzzletron.anymodel import convert_model + >>> convert_model( + ... input_dir="path/to/hf_checkpoint", + ... output_dir="path/to/anymodel_checkpoint", + ... converter="llama", + ... 
) + +Supported models: + - llama: Llama 2, Llama 3, Llama 3.1, Llama 3.2 + - (more to come: qwen2, mistral_small, etc.) +""" + +# Import models to trigger factory registration +from modelopt.torch.puzzletron.anymodel import models # noqa: F401 +from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory, convert_model +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import ( + MatchingZeros, + Same, + deci_x_patcher, + return_tuple_of_size, +) + +__all__ = [ + "Converter", + "ConverterFactory", + "ModelDescriptor", + "ModelDescriptorFactory", + "deci_x_patcher", + "MatchingZeros", + "Same", + "return_tuple_of_size", + "convert_model", +] diff --git a/modelopt/torch/puzzletron/anymodel/converter/__init__.py b/modelopt/torch/puzzletron/anymodel/converter/__init__.py new file mode 100644 index 0000000000..02903b817d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/converter/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def convert_model(
    input_dir: str,
    output_dir: str,
    converter: Converter | str,
):
    """Convert a HuggingFace checkpoint into the AnyModel format used for compression.

    Conversion steps:

    1. Copies non-weight files (config, tokenizer, etc.)
    2. Creates block_configs for each layer
    3. Reorganizes weights into subblock checkpoints

    Args:
        input_dir: Path to the input HuggingFace checkpoint directory.
        output_dir: Path to the output AnyModel checkpoint directory; created
            (with parents) if it does not exist.
        converter: Either a converter name (e.g., "llama") or a Converter class.

    Example:
        >>> convert_model(
        ...     input_dir="/path/to/Llama-3.1-8B-Instruct",
        ...     output_dir="/path/to/output/ckpts/teacher",
        ...     converter="llama",
        ... )
    """
    source_path = Path(input_dir)
    target_path = Path(output_dir)
    target_path.mkdir(parents=True, exist_ok=True)

    # The descriptor and the converter are looked up under the same key.
    model_descriptor = ModelDescriptorFactory.get(converter)
    resolved_converter = ConverterFactory.get(converter)

    resolved_converter.convert(
        descriptor=model_descriptor, input_dir=source_path, output_dir=target_path
    )
class Converter(ABC):
    """Base class for converting HuggingFace models to Puzzletron/AnyModel format.

    Subclasses must implement `create_block_configs_from_main_config` and may
    override `convert_weight_name`. A subclass that sets the class attribute
    ``quantized = "mxfp4"`` gets its ``*_blocks`` / ``*_scales`` tensor pairs
    unpacked into a single tensor during weight conversion.
    """

    @staticmethod
    def _get_weight_map(input_dir: Path) -> Dict[str, str]:
        """Load weight map from checkpoint directory (supports both sharded and single-file models).

        Returns:
            A dict mapping parameter names to their safetensors filenames.

        Raises:
            FileNotFoundError: If neither an index file nor a single-file checkpoint exists.
        """
        index_path = input_dir / "model.safetensors.index.json"
        single_file_path = input_dir / "model.safetensors"

        if index_path.exists():
            # Sharded model: the index already holds the name -> file mapping.
            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            return index["weight_map"]

        if single_file_path.exists():
            # Single file model - create a synthetic weight map. Only the tensor
            # names are needed, so read the header instead of loading every tensor.
            from safetensors import safe_open

            with safe_open(str(single_file_path), framework="pt") as f:
                return {name: "model.safetensors" for name in f.keys()}

        raise FileNotFoundError(
            f"Neither {index_path} nor {single_file_path} found. Cannot determine model format."
        )

    @classmethod
    def _extract_tensors(cls, data: Dict[str, object], param_names: List[str]) -> Dict[str, object]:
        """Pick `param_names` out of one loaded shard, renaming (and unpacking MXFP4) as needed.

        Args:
            data: Tensors of a single safetensors file, keyed by their original names.
            param_names: Original parameter names wanted from this file; names that
                are absent from `data` are skipped.

        Returns:
            Dict keyed by the converted (output) weight names.
        """
        is_mxfp4 = getattr(cls, "quantized", None) == "mxfp4"
        tensors = {}
        for name in param_names:
            if name not in data:
                continue

            converted_name = cls.convert_weight_name(name)
            if not is_mxfp4:
                tensors[converted_name] = data[name]
                continue

            # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b
            if name.endswith("_blocks"):
                # `data` is keyed by the ORIGINAL names, so derive the scales key
                # from `name`; only the output key uses the converted name.
                scales_name = name.removesuffix("_blocks") + "_scales"
                if scales_name not in data:
                    raise KeyError(
                        f"{scales_name} not found next to {name}; "
                        "MXFP4 blocks and scales are expected in the same safetensors file."
                    )
                tensors[converted_name.removesuffix("_blocks")] = convert_moe_packed_tensors(
                    data[name], data[scales_name]
                )
            elif name.endswith("_scales"):
                # Consumed together with the matching `_blocks` tensor.
                continue
            else:
                tensors[converted_name] = data[name]
        return tensors

    @classmethod
    def convert_model_weights(
        cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int
    ):
        """Convert model weights to subblock format.

        Writes one ``<group>.safetensors`` file per weight group under
        ``output_dir / "subblocks_safetensors"`` and a new
        ``model.safetensors.index.json`` in ``output_dir``.

        Args:
            input_dir: Directory of the source HuggingFace checkpoint.
            output_dir: Directory receiving the converted checkpoint.
            descriptor: Model descriptor whose ``get_weight_groups`` decides the grouping.
            num_hidden_layers: Number of decoder layers, forwarded to the descriptor.
        """
        # Go through `cls` so a subclass override of `_get_weight_map` is honored.
        param_to_file = cls._get_weight_map(input_dir)
        all_param_names = list(param_to_file.keys())

        # Determine subblocks needed
        subblocks = descriptor.get_weight_groups(
            all_param_names, num_hidden_layers=num_hidden_layers
        )

        # Output directory
        out_dir = output_dir / "subblocks_safetensors"
        os.makedirs(out_dir, exist_ok=True)

        # New weight index
        new_index = {"metadata": {"format": "pt"}, "weight_map": {}}

        for subblock, param_names in tqdm(subblocks.items(), desc="Processing subblocks"):
            # Sorted so tensor/index ordering does not depend on set iteration order.
            param_files = sorted({param_to_file[name] for name in param_names})
            tensors = {}

            # Load only needed files for this subblock, one at a time.
            for file in param_files:
                data = load_file(os.path.join(input_dir, file))
                names_in_file = [name for name in param_names if param_to_file[name] == file]
                tensors.update(cls._extract_tensors(data, names_in_file))

            # Save this subblock
            print(f"\n✅ Group: {subblock} ({len(tensors)} layers)")
            for layer in tensors.keys():
                print(f"  - {layer}")

            subblock_file = f"{subblock}.safetensors"
            save_file(tensors, os.path.join(out_dir, subblock_file))

            # Update index
            for new_name in tensors.keys():
                new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}"

        # Save new index file
        with (output_dir / "model.safetensors.index.json").open("w") as f:
            json.dump(new_index, f, indent=2)

        print(f"✅ Finished saving subblocks and index to {output_dir}")

    @classmethod
    def convert_configs_in_dirs(
        cls,
        input_dir: Path,
        output_dir: Path,
        trust_remote_code: bool = False,
    ):
        """Convert config and add block_configs.

        Args:
            input_dir: Directory holding the source model config.
            output_dir: Directory the augmented config is saved to.
            trust_remote_code: Forwarded to `load_model_config`.

        Returns:
            A deep copy of the loaded config with ``block_configs`` attached.
        """
        config = load_model_config(input_dir, trust_remote_code=trust_remote_code)

        block_configs = cls.create_block_configs_from_main_config(config)
        # Work on a copy so the loaded config object is left untouched.
        out_config = copy.deepcopy(config)
        out_config.block_configs = block_configs

        save_model_config(out_config, output_dir)
        return out_config

    @staticmethod
    def copy_checkpoint_files(input_dir: Path, output_dir: Path):
        """Copy checkpoint files except model weights (which will be converted)."""
        ignore_patterns = [
            "model-*.safetensors",
            "model.safetensors",
            "model.safetensors.index.json",
            "subblocks_safetensors",
        ]

        def ignore_func(_directory, files):
            # Called by `copytree` per directory; returns the entries to skip.
            ignored = set()
            for pattern in ignore_patterns:
                ignored.update(fnmatch.filter(files, pattern))
            return ignored

        shutil.copytree(str(input_dir), str(output_dir), ignore=ignore_func, dirs_exist_ok=True)

    @classmethod
    def convert(
        cls,
        descriptor: ModelDescriptor,
        input_dir: Path,
        output_dir: Path,
    ):
        """Convert a HuggingFace model to AnyModel format.

        Runs the three conversion steps in order: copy non-weight files, write
        the config with block_configs, then regroup the weights.

        Args:
            descriptor: Model descriptor for the model type.
            input_dir: Path to the input HuggingFace checkpoint.
            output_dir: Path to the output AnyModel checkpoint.
        """
        cls.copy_checkpoint_files(input_dir, output_dir)
        trust_remote_code = descriptor.requires_trust_remote_code()
        config = cls.convert_configs_in_dirs(
            input_dir, output_dir, trust_remote_code=trust_remote_code
        )
        cls.convert_model_weights(
            input_dir, output_dir, descriptor=descriptor, num_hidden_layers=config.num_hidden_layers
        )

    @staticmethod
    @abstractmethod
    def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]:
        """Create per-layer BlockConfig list from a HuggingFace model config.

        This method extracts layer-specific parameters (e.g., intermediate_size,
        num_key_value_heads) from the main model config and creates a BlockConfig
        for each layer. These BlockConfigs enable layer-specific pruning and
        modifications during the compression pipeline.

        Args:
            config: HuggingFace PretrainedConfig (e.g., LlamaConfig, Qwen2Config)

        Returns:
            List of BlockConfig, one per hidden layer. Each BlockConfig contains:
            - AttentionConfig: attention settings (no_op, num_key_value_heads)
            - FFNConfig: FFN settings (no_op, intermediate_size)

        Example:
            For a model with uniform layers (e.g., Llama):
                return [BlockConfig(...)] * config.num_hidden_layers

            For a model with heterogeneous layers (e.g., NemotronH with Mamba/Attention):
                return [BlockConfig(...) for layer_idx in range(num_layers)]
        """
        raise NotImplementedError

    @staticmethod
    def convert_weight_name(name: str) -> str:
        """Convert weight names during checkpoint conversion.

        This method can be overridden by subclasses to apply model-specific weight name
        transformations when converting checkpoints from HuggingFace format to Puzzletron format.

        Default implementation returns the name unchanged (identity function).

        Args:
            name: Original weight name from HuggingFace checkpoint

        Returns:
            Converted weight name for Puzzletron format

        Example:
            For Qwen2.5-VL, this converts:
            - visual.* → model.visual.*
            - model.* → model.language_model.*
        """
        return name
+ if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered converter by name or return the converter if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py new file mode 100644 index 0000000000..cc8e89e34b --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ModelDescriptor(ABC):
    """Static description of a model architecture used by the puzzletron pipeline.

    A descriptor is used as a class (all members are static or class methods):
    it names the model's layers, maps a `BlockConfig` to per-layer config
    overrides, and groups weight names into subblocks for checkpointing.
    """

    @staticmethod
    @abstractmethod
    def decoder_layer_cls() -> Type[nn.Module] | List[Type[nn.Module]]:
        """Decoder layer class types to patch for heterogeneous config support.

        In most cases this class will hold as attributes both FFN & attention layers.

        Returns:
            nn.Module class type or a list if several class types should be patched.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]:
        """Map between BlockConfig and layer config overrides.

        These overrides are consumed by a specific decoder layer and by the whole model.
        Usage can be seen in `deci_x_patcher` under the method `_patched_decoder_layer_init`.

        Example implementation to override the FFN intermediate size of a block:
            >>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]:
            >>>     return {"intermediate_size": block_config.ffn.intermediate_size}
        """
        raise NotImplementedError

    @staticmethod
    def requires_trust_remote_code() -> bool:
        """Whether this model descriptor requires trust_remote_code=True for loading.

        Models that use custom code (e.g., via auto_map in config) should override
        this to return True.

        Returns:
            True if trust_remote_code=True is required, False otherwise.
        """
        return False

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: nn.Module):
        """Post-init callback to alter a decoder layer so that FFN/mlp subblock performs as no-op.

        The base implementation raises `NotImplementedError`; `mlp_no_op_supported`
        reports whether a subclass has overridden it.

        It is recommended to use the utils modules from `no_op.py` to replace layers to dummy
        counterparts.

        Example for replacing a layernorm layer with identity:

        >>> decoder_layer.post_attention_layernorm = Same()

        Example for replacing an MLP layer with zeroes (zeroes since hidden_states are added to
        the residuals hidden_states so a no-op implementation will leave residual the same):

        >>> decoder_layer.mlp = MatchingZeros()

        In case the MLP layer to replace returns multiple outputs i.e `hidden_states, _ = self.mlp()`,
        use the util method `return_tuple_of_size` to return trailing None values:

        >>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)()
        """
        raise NotImplementedError

    @staticmethod
    def attn_no_op_post_init(decoder_layer: nn.Module):
        """Post-init callback to alter a decoder layer so that Attention subblock performs as no-op.

        The base implementation raises `NotImplementedError`; `attn_no_op_supported`
        reports whether a subclass has overridden it.

        It is recommended to use the utils modules from `no_op.py` to replace layers to dummy
        counterparts.

        Example for replacing a layernorm layer with identity:

        >>> decoder_layer.post_attention_layernorm = Same()

        Example for replacing an attention layer with zeroes:

        >>> decoder_layer.self_attn = MatchingZeros()

        In case the attention layer returns multiple outputs i.e `hidden_states, _ = self.self_attn()`,
        use the util method `return_tuple_of_size` to return trailing None values:

        >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def init_rotary_embedding(model, runtime):
        """Re-initiate the rotary embeddings based on an existing model.

        In puzzletron we initiate a sharded model by first creating a meta model then replacing
        to the actual device by loading the state_dict with the real weights.

        Rotary embeddings frequencies are tensor buffers that are created dynamically during init
        and are not part of the model state_dict, so cannot be restored after a meta device
        initialization.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def input_embedding_name():
        """Return the name of the input embedding layer."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def output_embedding_name():
        """Return the name of the output embedding layer."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def final_norm_name():
        """Return the name of the final normalization layer."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def layer_block_name(index: int):
        """Return the name of the decoder layer at the given index."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Return predicates for grouping model weights to support subblock checkpointing.

        For every group name return a regex predicate whether a layer name is part of the group.

        Returns:
            Dictionary of group name to regex pattern predicate.
        """
        raise NotImplementedError

    @staticmethod
    def uses_autocast() -> bool:
        """Whether this model supports torch.autocast.

        Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast.
        Override and return False for models that do not support autocast.
        """
        return True

    @staticmethod
    def get_language_model_config(config):
        """Get the language model config from a PretrainedConfig.

        For regular LM models, returns the config itself.
        For VL/multimodal models with nested configs, override to return the
        language model portion (e.g., config.text_config for Qwen-VL).
        """
        return config

    @classmethod
    def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
        """Create a dummy block to replace a layer for sharded model initialization.

        The base implementation ignores `original_layer` and returns a
        `DummyBlock` carrying only `block_index`.
        """
        return DummyBlock(block_index=block_index)

    @classmethod
    def mlp_no_op_supported(cls) -> bool:
        """Check whether `mlp_no_op_post_init` is overridden for mlp no-op support."""
        method_name = ModelDescriptor.mlp_no_op_post_init.__name__
        # Identity comparison: an inherited (non-overridden) method is the same object.
        return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name)

    @classmethod
    def attn_no_op_supported(cls) -> bool:
        """Check whether `attn_no_op_post_init` is overridden for attention no-op support."""
        method_name = ModelDescriptor.attn_no_op_post_init.__name__
        return getattr(cls, method_name) is not getattr(ModelDescriptor, method_name)

    @classmethod
    def get_weight_groups(
        cls, layer_names: Iterable[str], num_hidden_layers: int
    ) -> Dict[str, List[str]]:
        """Group model weights to support the puzzle subblock checkpointing format.

        This method uses the abstract method `layer_name_predicates` by default.
        Each name is assigned to the first group whose pattern matches it.

        Args:
            layer_names: state_dict layer names of the model.
            num_hidden_layers: number of decoder layers in the model.

        Returns:
            Dictionary of group names to list of layer names per group, e.g.:
            >>> {
            ...     "embedding": ["model.embed_tokens.weight"],
            ...     "lm_head": ["lm_head.weight", "model.norm.weight"],
            ...     "block_0_ffn": ["model.layers.0.mlp.down_proj", ...],
            ...     "block_0_attention": ["model.layers.0.self_attn.q_proj", ...],
            ... }

        Raises:
            ValueError: If a name matches none of the group patterns.
        """
        # The predicate table depends only on the layer count, so build it once
        # rather than once per weight name.
        predicates = cls.layer_name_predicates(num_hidden_layers)

        weight_groups = defaultdict(list)
        for name in layer_names:
            for group, pattern in predicates.items():
                if pattern.match(name):
                    weight_groups[group].append(name)
                    break
            else:
                raise ValueError(f"Couldn't find a match for {name}")
        return weight_groups
_MODEL_TYPE_TO_DESCRIPTOR = {
    "llama": "llama",
    "mistral": "mistral_small",
    "qwen2": "qwen2",
    "qwen3": "qwen3",
    "nemotron_h": "nemotron_h",
    "nemotron_h_v2": "nemotron_h_v2",
    # The GPT-OSS descriptor is registered under the name "gpt_oss".
    "gpt_oss": "gpt_oss",
    # Legacy key kept for backward compatibility; it resolves to the same descriptor.
    # NOTE(review): confirm no checkpoint actually reports model_type "gpt_oss_20b".
    "gpt_oss_20b": "gpt_oss",
}


def resolve_descriptor_from_pretrained(pretrained: str, trust_remote_code: bool = False):
    """Resolve the model descriptor by loading the checkpoint config and mapping model_type.

    Args:
        pretrained: Path to a pretrained model checkpoint or HuggingFace model identifier.
        trust_remote_code: If True, allows execution of custom code from the model repository.
            This is a security risk if the model source is untrusted. Only set to True if you
            trust the source of the model. Defaults to False for security.

    Returns:
        The resolved ModelDescriptor class for the detected model type.

    Raises:
        ValueError: If the model type cannot be auto-detected, or if the descriptor
            mapped to it has not been registered with ModelDescriptorFactory.
    """
    config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
    model_type = getattr(config, "model_type", None)

    if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR:
        detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type]
        descriptor = ModelDescriptorFactory.get(detected)
        # The factory returns an unknown name unchanged; surface that as an
        # error instead of handing a bare string back as a "descriptor".
        if isinstance(descriptor, str):
            raise ValueError(
                f"Descriptor '{detected}' for model_type='{model_type}' is not registered "
                "with ModelDescriptorFactory. Make sure the model package that registers "
                "it has been imported."
            )
        print(
            f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'"
        )
        return descriptor

    known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys())
    raise ValueError(
        f"Cannot auto-detect descriptor for model_type='{model_type}'. "
        f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported."
    )
+ """ + for cls_name, cls_type in entries.items(): + if cls_name in cls.CLASS_MAPPING: + ref = cls.CLASS_MAPPING[cls_name] + # If ref and cls_name point to the same class ignore and don't raise an exception. + if cls_type == ref: + continue + raise KeyError( + f"Could not register `{cls_name}`: {cls_type}, " + f"`{cls_name}` is already registered and points to " + f"`{inspect.getmodule(ref).__name__}.{ref.__name__}`" + ) + cls.CLASS_MAPPING[cls_name] = cls_type + + @classmethod + def register_decorator(cls, name: str | None) -> Callable: + """Set up a register decorator. + + Args: + name: If specified, the decorated object will be registered with this name. + + Returns: + Decorator that registers the callable. + """ + + def decorator(cls_type: Type) -> Callable: + """Register the decorated callable.""" + cls_name = name if name is not None else cls_type.__name__ + cls.register(**{cls_name: cls_type}) + return cls_type + + return decorator + + @classmethod + def get(cls, value: str | ModelDescriptor): + """Get a registered model descriptor by name or return the descriptor if already resolved.""" + if isinstance(value, str): + if value in cls.CLASS_MAPPING: + return cls.CLASS_MAPPING[value] + return value diff --git a/modelopt/torch/puzzletron/anymodel/models/__init__.py b/modelopt/torch/puzzletron/anymodel/models/__init__.py new file mode 100644 index 0000000000..4c68dbc823 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Import models to trigger factory registration +from modelopt.torch.puzzletron.anymodel.models.gpt_oss import * +from modelopt.torch.puzzletron.anymodel.models.llama import * +from modelopt.torch.puzzletron.anymodel.models.mistral_small import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h import * +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen2 import * +from modelopt.torch.puzzletron.anymodel.models.qwen3 import * +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl import * diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py new file mode 100644 index 0000000000..9f72b8dd78 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@ConverterFactory.register_decorator("gpt_oss")
class GptOssConverter(Converter):
    """Converter for GPT-OSS models to AnyModel format.

    GPT-OSS is a pure MoE model with 32/128 experts per layer and 4/16 active experts.
    All layers use MoE FFN (no standard dense FFN layers).
    """

    # Read by the base converter's weight conversion to select MXFP4 unpacking.
    quantized = "mxfp4"

    @staticmethod
    def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]:
        """Create block configs for GPT-OSS layers.

        Every layer receives the same settings, taken from the main config:
        - attention: `num_key_value_heads`
        - MoE FFN: `num_local_experts`, `experts_per_token` (active experts per
          token) and `intermediate_size` as the expert intermediate dim
        - no dense/standard FFN (`intermediate_size=None` on the FFN config)

        Returns:
            One entry per hidden layer, each produced by `BlockConfig.to_dict()`.
        """
        # Read the shared settings once, before any layer entry is built.
        layer_count = config.num_hidden_layers
        expert_count = config.num_local_experts
        active_experts = config.experts_per_token
        expert_dim = config.intermediate_size

        def build_layer_entry():
            attention = AttentionConfig(
                no_op=False, num_key_value_heads=config.num_key_value_heads
            )
            moe = MoEConfig(
                num_local_experts=expert_count,
                num_experts_per_tok=active_experts,
                expert_intermediate_dim=expert_dim,
            )
            ffn = FFNConfig(
                no_op=False,
                intermediate_size=None,  # MoE doesn't use this field
                moe=moe,
            )
            return BlockConfig(attention=attention, ffn=ffn).to_dict()

        # A fresh entry per layer, so entries never share state.
        return [build_layer_entry() for _ in range(layer_count)]
@ModelDescriptorFactory.register_decorator("gpt_oss")
class GptOssModelDescriptor(ModelDescriptor):
    """AnyModel descriptor for GPT-OSS, registered under the key ``"gpt_oss"``.

    Tells the AnyModel machinery which decoder-layer class to build, how a
    ``BlockConfig`` translates into per-layer constructor overrides, how to
    turn attention / MLP sub-blocks into no-ops, and how checkpoint weight
    names are grouped into sub-block files.
    """

    _DECODER_LAYER_CLS: Type[nn.Module] = None

    @classmethod
    def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module:
        """Return a ``DummyBlock`` standing in for ``original_layer``."""
        placeholder = DummyBlock(block_index=block_index)
        # Required by `GptOssModel.forward`.
        placeholder.attention_type = original_layer.attention_type
        return placeholder

    @staticmethod
    def decoder_layer_cls():
        """Return the decoder layer class used for GPT-OSS models.

        The class is imported directly from
        ``transformers.models.gpt_oss.modeling_gpt_oss``.
        """
        return GptOssDecoderLayer

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Translate a ``BlockConfig`` into layer constructor overrides.

        Only values that are set on the block config produce an entry: the
        KV-head count when it is not ``None``, and the three MoE settings
        when the FFN carries a MoE config.
        """
        overrides = {}

        kv_heads = block_config.attention.num_key_value_heads
        if kv_heads is not None:
            overrides["num_key_value_heads"] = kv_heads

        moe = block_config.ffn.moe
        if moe is None:
            return overrides

        # NOTE(review): confirm that the consumer of these overrides maps
        # "moe_intermediate_size" onto the field the GPT-OSS layer reads.
        overrides["moe_intermediate_size"] = moe.expert_intermediate_dim
        overrides["num_local_experts"] = moe.num_local_experts
        overrides["num_experts_per_tok"] = moe.num_experts_per_tok
        return overrides

    @staticmethod
    def attn_no_op_post_init(decoder_layer):
        """Swap the attention norm and attention module for no-op modules."""
        decoder_layer.input_layernorm = Same()
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.self_attn = zeros_pair_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer):
        """Swap the post-attention norm and MLP for no-op modules.

        GPT-OSS MoE layers return ``(hidden_states, router_scores)``, so the
        MLP replacement has to return a tuple of two values.
        """
        decoder_layer.post_attention_layernorm = Same()
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.mlp = zeros_pair_cls()

    @staticmethod
    def init_rotary_embedding(model, runtime):
        """Recreate ``model.model.rotary_emb`` on ``runtime.device``."""
        # GPT-OSS uses RoPE with YARN scaling
        rotary = GptOssRotaryEmbedding(config=model.config, device=runtime.device)
        model.model.rotary_emb = rotary

    @staticmethod
    def input_embedding_name():
        """Module path of the input token embedding."""
        return "model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        """Module path of the output projection."""
        return "lm_head"

    @staticmethod
    def final_norm_name():
        """Module path of the final normalization layer."""
        return "model.norm"

    @staticmethod
    def layer_block_name(index: int):
        """Module path of decoder layer ``index``."""
        return f"model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Build the regexes that group weight names into sub-blocks.

        Returns a dict whose keys are, in order: ``"embeddings"``,
        ``"lm_head"``, ``"block_{i}_ffn"`` for every layer, then
        ``"block_{i}_attention"`` for every layer.
        """
        # FFN is MoE in GPT-OSS; the optional suffix group also matches the
        # MXFP4 `_blocks` / `_scales` tensors and the `_bias` tensors.
        ffn_tail = (
            r"(post_attention_layernorm\.weight"
            r"|mlp\.router\.weight"
            r"|mlp\.router\.bias"
            r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$"
        )
        attention_tail = (
            r"(input_layernorm\.weight"
            r"|self_attn\.q_proj\.weight"
            r"|self_attn\.q_proj\.bias"
            r"|self_attn\.k_proj\.weight"
            r"|self_attn\.k_proj\.bias"
            r"|self_attn\.v_proj\.weight"
            r"|self_attn\.v_proj\.bias"
            r"|self_attn\.o_proj\.weight"
            r"|self_attn\.o_proj\.bias"
            r"|self_attn\.sinks)$"
        )

        patterns = {
            "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
        }
        # All FFN entries are inserted before any attention entry.
        for sub_block, tail in (("ffn", ffn_tail), ("attention", attention_tail)):
            for layer_idx in range(num_layers):
                patterns[f"block_{layer_idx}_{sub_block}"] = re.compile(
                    rf"^model\.layers\.{layer_idx}\." + tail
                )
        return patterns

    @staticmethod
    def pruning_mixins() -> Dict[str, PruningMixIn]:
        """Return the pruning mixins available for GPT-OSS.

        Note: Expert removal works for unquantized models (test models).
        Production models use MXFP4 quantization which is not yet supported.
        """
        layer_descriptor = GptOssExpertRemovalLayerDescriptor()
        return {"expert_removal": ExpertRemovalPruningMixIn(layer_descriptor)}
@dataclass
class GptOssExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
    """
    GPT-OSS MoE layer descriptor for expert removal.

    Note: This only works for unquantized models (e.g., test models).
    Production GPT-OSS models use MXFP4 quantization with fused experts
    (_blocks, _scales, _bias), which requires a different approach.

    Layout described by the fields below:
    - Router: ``router.weight`` and ``router.bias`` under the MoE prefix
    - Experts: fused tensors ``experts.gate_up_proj`` / ``experts.down_proj``
      and their ``_bias`` counterparts (one tensor holding all experts)
    """

    target_name: str = "mlp"
    moe_prefix_name: str = "model.layers.{layer_idx}.mlp"
    expert_prefix_name: str = "experts"

    # Router has both weight and bias
    router_weights: List[str] = field(default_factory=lambda: ["router.weight"])
    router_biases: List[str] = field(default_factory=lambda: ["router.bias"])

    # Fused format: experts stored as single tensors
    is_fused_experts: bool = True

    # Fused format: single tensors containing all experts (test models)
    fused_expert_weights: List[str] = field(
        default_factory=lambda: [
            "experts.gate_up_proj",
            "experts.gate_up_proj_bias",
            "experts.down_proj",
            "experts.down_proj_bias",
        ]
    )

    # Not used for fused format, but kept for compatibility
    expert_weights: List[str] = field(default_factory=lambda: ["gate_up_proj", "down_proj"])
    expert_biases: List[str] = field(
        default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"]
    )

    def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
        """Return ``(block_index, module_name)`` for every module to hook.

        A module qualifies when its name ends with ``self.target_name`` and
        its class is named ``GptOssTopKRouter``.
        """
        # NOTE(review): with the default target_name "mlp" the name filter
        # selects "...mlp" modules while the class filter selects routers;
        # confirm target_name is overridden (e.g. to the router path) by the
        # caller, otherwise this presumably matches nothing.
        router_cls_name = "GptOssTopKRouter"
        return [
            (self.block_idx_from_module_name(name), name)
            for name, module in model.named_modules()
            if name.endswith(self.target_name) and module.__class__.__name__ == router_cls_name
        ]
def deduce_experts_for_layer(
    layer: int,
    original_path: str,
    original_index: Dict,
    student_path: str,
    quick_compare_size: int = 8,
    top_k_candidates: int = 10,
) -> Tuple[List[int], int, int]:
    """
    Deduce which original experts match the student experts by comparing weights.

    Compares dequantized MXFP4 weights from the original model against the student
    model's weights using RMSE. Matching is greedy and 1-to-1: student experts are
    processed in order and each one claims the best still-unused original expert.

    For speed, every unused original expert is first scored on only the first
    ``quick_compare_size`` values of each projection; the full-tensor RMSE is then
    computed only for the ``top_k_candidates`` best quick scores.

    Args:
        layer: Layer index
        original_path: Path to original model
        original_index: Original model's safetensors index
        student_path: Path to student model
        quick_compare_size: Number of leading values used by the quick pre-filter
        top_k_candidates: Number of pre-filtered candidates that get a full comparison

    Returns:
        Tuple of (expert_indices, num_student_experts, num_original_experts).
        When the student has no FFN file for this layer the result is
        ``([], 0, num_original_experts)``.

    Raises:
        ValueError: If the student FFN file lacks the fused expert tensors, or if
            no original expert can be matched to a student expert.
    """
    prefix = f"model.layers.{layer}.mlp"

    # Load original tensors
    orig_tensors = load_layer_tensors(original_path, layer, original_index)
    mlp1_blocks = orig_tensors[f"{prefix}.experts.gate_up_proj_blocks"]
    mlp1_scales = orig_tensors[f"{prefix}.experts.gate_up_proj_scales"]
    mlp2_blocks = orig_tensors[f"{prefix}.experts.down_proj_blocks"]
    mlp2_scales = orig_tensors[f"{prefix}.experts.down_proj_scales"]

    num_original_experts = mlp1_blocks.shape[0]

    # Load student tensors
    student_ffn = os.path.join(
        student_path, "subblocks_safetensors", f"block_{layer}_ffn.safetensors"
    )
    if not os.path.exists(student_ffn):
        print(f"FFN file not found at {student_ffn} - fallback to no_op")
        return [], 0, num_original_experts

    up_key = f"{prefix}.experts.gate_up_proj"
    down_key = f"{prefix}.experts.down_proj"
    with safe_open(student_ffn, framework="pt") as f:
        # Validate up front so a malformed file fails with a clear message
        # instead of a KeyError/TypeError deep inside the matching loop.
        available_keys = set(f.keys())
        missing = [key for key in (up_key, down_key) if key not in available_keys]
        if missing:
            raise ValueError(
                f"Missing student expert weights for layer {layer}: {', '.join(missing)}"
            )
        # Only the two fused expert tensors are needed for matching.
        student_up_all = f.get_tensor(up_key)
        student_down_all = f.get_tensor(down_key)

    # Auto-detect number of student experts
    num_student_experts = student_up_all.size(0)
    print(
        f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts"
    )

    # Pre-dequantize all original experts once (optimization)
    print(f" Pre-dequantizing {num_original_experts} original experts...")
    deq_up = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu()
    deq_down = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu()

    # The quick-filter prefixes of the originals do not depend on the student
    # expert, so slice them once instead of once per (student, original) pair.
    orig_up_heads = [
        deq_up[idx].flatten()[:quick_compare_size] for idx in range(num_original_experts)
    ]
    orig_down_heads = [
        deq_down[idx].flatten()[:quick_compare_size] for idx in range(num_original_experts)
    ]

    def rmse(a, b):
        return (a - b).pow(2).mean().sqrt()

    max_candidates = min(top_k_candidates, num_original_experts)
    experts_to_keep = []
    used_original_indices = set()

    for student_idx in range(num_student_experts):
        # Convert one expert at a time to keep the float32 copy small.
        student_up = student_up_all[student_idx].float()
        student_down = student_down_all[student_idx].float()
        student_up_head = student_up.flatten()[:quick_compare_size]
        student_down_head = student_down.flatten()[:quick_compare_size]

        # Step 1: Quick filtering using first N values
        candidate_scores = []
        for orig_idx in range(num_original_experts):
            if orig_idx in used_original_indices:
                continue
            quick_score = (
                rmse(orig_up_heads[orig_idx], student_up_head)
                + rmse(orig_down_heads[orig_idx], student_down_head)
            ) / 2.0
            candidate_scores.append((orig_idx, quick_score.item()))

        # Step 2: Get top-k candidates based on quick comparison
        candidate_scores.sort(key=lambda item: item[1])

        # Step 3: Full comparison only on top candidates
        best_match_idx = None
        best_match_score = float("inf")
        for orig_idx, _ in candidate_scores[:max_candidates]:
            score = (
                rmse(deq_up[orig_idx], student_up) + rmse(deq_down[orig_idx], student_down)
            ) / 2.0
            score = score.item()
            if score < best_match_score:
                best_match_score = score
                best_match_idx = orig_idx

        if best_match_idx is None:
            raise ValueError(
                f"Could not find match for student expert {student_idx} in layer {layer}"
            )

        experts_to_keep.append(best_match_idx)
        used_original_indices.add(best_match_idx)
        print(
            f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})"
        )

    return experts_to_keep, num_student_experts, num_original_experts
def load_original_index(path: str) -> Dict[str, Any]:
    """Load the original model's safetensors index.

    Args:
        path: Path to the ``model.safetensors.index.json`` file.

    Returns:
        The parsed JSON document.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]:
    """Load all MoE-related tensors for a layer, potentially from multiple files.

    Args:
        original_path: Directory holding the original model's safetensors shards.
        layer: Layer index.
        index: Parsed safetensors index; its ``"weight_map"`` maps tensor name
            to shard filename.

    Returns:
        Mapping of tensor name to tensor. Names that are absent from the
        index's weight map are skipped, so the result may hold fewer than
        the eight requested entries.
    """
    prefix = f"model.layers.{layer}.mlp"
    keys_to_load = [
        f"{prefix}.experts.gate_up_proj_blocks",
        f"{prefix}.experts.gate_up_proj_scales",
        f"{prefix}.experts.gate_up_proj_bias",
        f"{prefix}.experts.down_proj_blocks",
        f"{prefix}.experts.down_proj_scales",
        f"{prefix}.experts.down_proj_bias",
        f"{prefix}.router.weight",  # Router weight
        f"{prefix}.router.bias",  # Router bias
    ]

    # Group by file so each shard is opened once.
    weight_map = index["weight_map"]
    file_to_keys: Dict[str, List[str]] = {}
    for key in keys_to_load:
        if key in weight_map:
            file_to_keys.setdefault(weight_map[key], []).append(key)

    # Load from each file
    tensors = {}
    for filename, keys in file_to_keys.items():
        filepath = os.path.join(original_path, filename)
        with safe_open(filepath, framework="pt") as f:
            for key in keys:
                tensors[key] = f.get_tensor(key)

    return tensors
def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]:
    """
    Copy non-MoE weights from student model.

    Copies the embeddings file, the lm_head file and one attention file per
    layer from the student's ``subblocks_safetensors`` directory into the
    output's ``subblocks_safetensors`` directory (created if needed).

    Returns weight_map for the new index: tensor name -> relative file path.
    """
    subdir = "subblocks_safetensors"
    dst_dir = os.path.join(output_path, subdir)
    os.makedirs(dst_dir, exist_ok=True)
    src_dir = os.path.join(student_path, subdir)

    # Same order as the files are indexed: embeddings, lm_head, attention blocks.
    # NOTE(review): a missing attention file raises here, whereas a missing
    # FFN file is treated as no_op elsewhere in this script - confirm intended.
    filenames = ["embeddings.safetensors", "lm_head.safetensors"]
    filenames.extend(f"block_{layer}_attention.safetensors" for layer in range(num_layers))

    weight_map = {}
    for filename in filenames:
        src = os.path.join(src_dir, filename)
        shutil.copy2(src, os.path.join(dst_dir, filename))
        relative_path = f"subblocks_safetensors/{filename}"
        with safe_open(src, framework="pt") as f:
            for key in f.keys():
                weight_map[key] = relative_path

    return weight_map
def process_single_layer(
    layer: int,
    original_path: str,
    original_index: Dict,
    student_path: str,
    output_path: str,
    experts_to_keep: List[int],
) -> Tuple[Dict[str, str], List[str]]:
    """
    Process a single layer - loads tensors from potentially multiple files.

    Writes ``block_{layer}_ffn.safetensors`` into the output's
    ``subblocks_safetensors`` directory. The file contains the student's
    non-expert, non-router FFN tensors plus the original model's router and
    packed MXFP4 expert tensors sliced to ``experts_to_keep``.

    Args:
        layer: Layer index.
        original_path: Path to original model.
        original_index: Original model's safetensors index.
        student_path: Path to student model.
        output_path: Output checkpoint directory.
        experts_to_keep: Original expert indices to keep, in student order.

    Returns (weight_map, verification_errors). ``verification_errors`` is
    currently always empty; it is kept for interface compatibility.
    """
    prefix = f"model.layers.{layer}.mlp"
    subblocks_dir = os.path.join(output_path, "subblocks_safetensors")
    student_ffn = os.path.join(
        student_path, "subblocks_safetensors", f"block_{layer}_ffn.safetensors"
    )

    # Load all tensors for this layer (may come from multiple files)
    orig_tensors = load_layer_tensors(original_path, layer, original_index)

    tensors_to_save = {}
    with safe_open(student_ffn, framework="pt") as f:
        for key in f.keys():
            # Copy norm weights only; expert/router tensors come from the
            # original model, so do not materialize the student's copies.
            if "experts" not in key and "router" not in key:
                tensors_to_save[key] = f.get_tensor(key)

    # Router and expert tensors are all indexed by expert along dim 0, so the
    # same index tensor slices every one of them down to the kept experts.
    kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long)
    moe_suffixes = (
        "router.weight",
        "router.bias",
        "experts.gate_up_proj_blocks",
        "experts.gate_up_proj_scales",
        "experts.gate_up_proj_bias",
        "experts.down_proj_blocks",
        "experts.down_proj_scales",
        "experts.down_proj_bias",
    )
    for suffix in moe_suffixes:
        name = f"{prefix}.{suffix}"
        tensors_to_save[name] = orig_tensors[name][kept_indices_tensor]

    # Save the FFN file
    output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors")
    save_file(tensors_to_save, output_file)

    # Build weight map
    relative_path = f"subblocks_safetensors/block_{layer}_ffn.safetensors"
    weight_map = {key: relative_path for key in tensors_to_save}
    verification_errors: List[str] = []

    return weight_map, verification_errors
def copy_config_files(student_path: str, output_path: str):
    """Copy configuration files from student model and update config.json.

    Tokenizer/template files and any ``transformers_*`` entries found in
    ``student_path`` are copied on a best-effort basis: a file that cannot be
    copied produces a warning, not an error. ``config.json`` is mandatory; it
    is rewritten with the ``DeciGptOssForCausalLM`` architecture and an MXFP4
    ``quantization_config``.

    Args:
        student_path: Directory of the student checkpoint.
        output_path: Existing directory of the checkpoint being created.

    Raises:
        FileNotFoundError: If ``student_path`` has no ``config.json``.
    """
    files_to_copy = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "chat_template.jinja",
    ]

    # Also copy transformers compatibility files
    if os.path.exists(student_path):
        files_to_copy.extend(
            name for name in os.listdir(student_path) if name.startswith("transformers_")
        )

    for filename in files_to_copy:
        src = os.path.join(student_path, filename)
        dst = os.path.join(output_path, filename)

        if not os.path.exists(src):
            # Optional file: only worth a warning when the output has no copy.
            if not os.path.exists(dst):
                print(f" Warning: Could not copy {filename}")
            continue

        try:
            shutil.copy2(src, dst)
        except PermissionError as err:
            # Best effort, but say so: a stale destination may be left behind.
            print(f" Warning: Could not copy {filename}: {err}")

    # Update config.json for DeciGptOssForCausalLM with MXFP4
    src_config = os.path.join(student_path, "config.json")
    if not os.path.exists(src_config):
        raise FileNotFoundError(f"config.json not found at {src_config}")

    with open(src_config, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Set architecture to DeciGptOssForCausalLM for MXFP4 support
    config["architectures"] = ["DeciGptOssForCausalLM"]

    # Add quantization_config so vllm calls _load_weights_mxfp4
    config["quantization_config"] = {
        "quant_method": "mxfp4",
        "modules_to_not_convert": [
            "model.layers.*.self_attn",
            "model.layers.*.mlp.router",
            "model.embed_tokens",
            "lm_head",
        ],
    }

    dst_config = os.path.join(output_path, "config.json")
    with open(dst_config, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
def main():
    """Command-line entry point: build an MXFP4 checkpoint from a student model.

    Steps: deduce, per layer, which original experts the student kept; copy
    config files and non-MoE weights from the student; write one FFN file per
    layer with the original packed MXFP4 expert tensors for the kept experts;
    finally write ``experts_to_keep.json`` and ``model.safetensors.index.json``.
    """
    parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model")
    parser.add_argument(
        "--student-path", type=str, required=True, help="Path to student model checkpoint"
    )
    parser.add_argument(
        "--original-path",
        type=str,
        required=True,
        help="Path to original gpt-oss-120b model with MXFP4 weights",
    )
    parser.add_argument(
        "--output-path", type=str, required=True, help="Output path for the new checkpoint"
    )
    parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers")
    args = parser.parse_args()

    print("Creating MXFP4 checkpoint...")
    print(f" Student model: {args.student_path}")
    print(f" Original model: {args.original_path}")
    print(f" Output: {args.output_path}")

    # Load original model index
    original_index = load_original_index(
        os.path.join(args.original_path, "model.safetensors.index.json")
    )

    print("\nDeducing expert mappings by comparing weights...")
    experts_to_keep = []
    layer_statistics = []  # Store (num_student, num_original) for each layer

    for layer in range(args.num_layers):
        layer_experts, num_student, num_original = deduce_experts_for_layer(
            layer,
            args.original_path,
            original_index,
            args.student_path,
        )
        experts_to_keep.append(layer_experts)
        layer_statistics.append((num_student, num_original))

    # Print statistics
    print(f"\n{'=' * 70}")
    print("EXPERT DEDUCTION STATISTICS")
    print(f"{'=' * 70}")
    print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}")
    print(f"{'-' * 70}")

    total_student = 0
    total_original = 0
    for layer, (num_student, num_original) in enumerate(layer_statistics):
        percentage = (num_student / num_original * 100) if num_original > 0 else 0
        print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}")
        total_student += num_student
        total_original += num_original

    print(f"{'-' * 70}")
    avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0
    print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}")
    print(f"{'=' * 70}")
    print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers")

    # Create output directory
    os.makedirs(args.output_path, exist_ok=True)
    os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True)

    # Copy config files
    print("Copying configuration files...")
    copy_config_files(args.student_path, args.output_path)

    # Save experts_to_keep.json
    experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json")
    with open(experts_to_keep_output, "w", encoding="utf-8") as f:
        json.dump(experts_to_keep, f, indent=2)
    print(f" Saved experts_to_keep mapping to {experts_to_keep_output}")

    # Copy non-MoE weights (embeddings, attention, lm_head)
    print("Copying non-MoE weights...")
    weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers)

    # Load weights per layer (handles multi-file loading)
    print(f"Processing {args.num_layers} layers...")

    all_verification_errors = []

    # Process each layer
    for layer in tqdm(range(args.num_layers), desc="Processing layers"):
        if len(experts_to_keep[layer]) == 0:
            # No FFN file is written for this layer, so nothing enters the index.
            print(f"Layer {layer} has no experts to keep - ffn->no_op")
            continue
        layer_weight_map, layer_errors = process_single_layer(
            layer,
            args.original_path,
            original_index,
            args.student_path,
            args.output_path,
            experts_to_keep[layer],
        )
        weight_map.update(layer_weight_map)
        all_verification_errors.extend(layer_errors)

    # Calculate total size (sum of the on-disk sizes of the sub-block files)
    subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors")
    total_size = sum(
        os.path.getsize(os.path.join(subblocks_dir, filename))
        for filename in os.listdir(subblocks_dir)
    )

    # Create model.safetensors.index.json
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}

    index_path = os.path.join(args.output_path, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)

    # Surface anything the per-layer step reported instead of dropping it.
    if all_verification_errors:
        print(f"\n{len(all_verification_errors)} verification error(s):")
        for error in all_verification_errors:
            print(f" {error}")

    print(f"\nCheckpoint created successfully at: {args.output_path}")
    print(f"Total size: {total_size / 1e9:.2f} GB")


if __name__ == "__main__":
    main()
@ConverterFactory.register_decorator("llama")
class LlamaConverter(Converter):
    """Converter registered under ``"llama"`` that maps Llama configs to AnyModel block configs."""

    @staticmethod
    def create_block_configs_from_main_config(config: LlamaConfig) -> List[BlockConfig]:
        """Build one identical block config per hidden layer.

        Every layer gets the KV-head count and FFN intermediate size taken
        from ``config``, with neither sub-block marked as a no-op. Each entry
        of the returned list is the ``to_dict()`` form of a ``BlockConfig``.
        """
        block_configs = []
        for _ in range(config.num_hidden_layers):
            attention_config = AttentionConfig(
                no_op=False, num_key_value_heads=config.num_key_value_heads
            )
            ffn_config = FFNConfig(no_op=False, intermediate_size=config.intermediate_size)
            layer_config = BlockConfig(attention=attention_config, ffn=ffn_config)
            block_configs.append(layer_config.to_dict())
        return block_configs
@ModelDescriptorFactory.register_decorator("llama")
class LlamaModelDescriptor(ModelDescriptor):
    """Model descriptor for Llama models (Llama 2, Llama 3, Llama 3.1, Llama 3.2)."""

    @staticmethod
    def decoder_layer_cls():
        """Return the decoder layer class used for Llama models."""
        return LlamaDecoderLayer

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Map a ``BlockConfig`` to the two per-layer constructor overrides."""
        overrides = {}
        overrides["intermediate_size"] = block_config.ffn.intermediate_size
        overrides["num_key_value_heads"] = block_config.attention.num_key_value_heads
        return overrides

    @staticmethod
    def attn_no_op_post_init(decoder_layer: LlamaDecoderLayer):
        """Swap the attention norm and attention module for no-op modules."""
        decoder_layer.input_layernorm = Same()
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.self_attn = zeros_pair_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: LlamaDecoderLayer):
        """Swap the post-attention norm and MLP for no-op modules."""
        decoder_layer.post_attention_layernorm = Same()
        decoder_layer.mlp = MatchingZeros()

    @staticmethod
    def init_rotary_embedding(model: LlamaForCausalLM, runtime):
        """Recreate ``model.model.rotary_emb`` on ``runtime.device``."""
        rotary = LlamaRotaryEmbedding(model.config, runtime.device)
        model.model.rotary_emb = rotary

    @staticmethod
    def input_embedding_name():
        """Module path of the input token embedding."""
        return "model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        """Module path of the output projection."""
        return "lm_head"

    @staticmethod
    def final_norm_name():
        """Module path of the final normalization layer."""
        return "model.norm"

    @staticmethod
    def layer_block_name(index: int):
        """Module path of decoder layer ``index``."""
        return f"model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Build the regexes that group weight names into sub-blocks.

        Keys are, in order: ``"embeddings"``, ``"lm_head"``,
        ``"block_{i}_ffn"`` for every layer, then ``"block_{i}_attention"``
        for every layer.
        """
        ffn_tail = (
            r"(post_attention_layernorm\.weight"
            r"|mlp\.up_proj\.weight"
            r"|mlp\.gate_proj\.weight"
            r"|mlp\.down_proj\.weight)$"
        )
        attention_tail = (
            r"(input_layernorm\.weight"
            r"|self_attn\.q_proj\.weight"
            r"|self_attn\.k_proj\.weight"
            r"|self_attn\.v_proj\.weight"
            r"|self_attn\.o_proj\.weight)$"
        )

        patterns = {
            "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
        }
        # All FFN entries are inserted before any attention entry.
        for sub_block, tail in (("ffn", ffn_tail), ("attention", attention_tail)):
            for layer_idx in range(num_layers):
                patterns[f"block_{layer_idx}_{sub_block}"] = re.compile(
                    rf"^model\.layers\.{layer_idx}\." + tail
                )
        return patterns
def final_norm_name(): + return "model.norm" + + @staticmethod + def layer_block_name(index: int): + return f"model.layers.{index}" + + @staticmethod + def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: + layer_name_patterns = { + "embeddings": re.compile(r"^model\.embed_tokens\.weight$"), + "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"), + } + + def build_ffn_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_ffn": re.compile( + rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight" + r"|mlp\.up_proj\.weight" + r"|mlp\.gate_proj\.weight" + r"|mlp\.down_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + def build_attention_predicates() -> Dict[str, re.Pattern]: + return { + f"block_{layer_idx}_attention": re.compile( + rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight" + r"|self_attn\.q_proj\.weight" + r"|self_attn\.k_proj\.weight" + r"|self_attn\.v_proj\.weight" + r"|self_attn\.o_proj\.weight)$" + ) + for layer_idx in range(num_layers) + } + + layer_name_patterns.update(**build_ffn_predicates(), **build_attention_predicates()) + return layer_name_patterns + + +@dataclass +class LlamaFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): + """Layer descriptor for Llama FFN intermediate pruning.""" + + down_proj_name: str = "mlp.down_proj" + ffn_prefix_name: str = "model.layers.{layer_idx}.mlp" + linear_weight_names: List[str] = field( + default_factory=lambda: ["down_proj", "gate_proj", "up_proj"] + ) + + +@dataclass +class LlamaKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py new file mode 100644 index 
0000000000..821be47e9d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_converter import ( + MistralSmallConverter, +) +from modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor import ( + MistralSmallModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py new file mode 100644 index 0000000000..ddc8151dc9 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_converter.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

from typing import List

from transformers import MistralConfig

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
)


@ConverterFactory.register_decorator("mistral_small")
class MistralSmallConverter(Converter):
    """Converter registered under the ``"mistral_small"`` key."""

    @staticmethod
    def create_block_configs_from_main_config(config: MistralConfig) -> List[BlockConfig]:
        """Create one block config per hidden layer from the model config.

        Every layer gets the same settings: attention and FFN enabled, with
        ``num_key_value_heads`` and ``intermediate_size`` taken from ``config``.

        Args:
            config: Model config providing ``num_hidden_layers``,
                ``num_key_value_heads`` and ``intermediate_size``.

        Returns:
            A list of ``config.num_hidden_layers`` entries, each the
            ``to_dict()`` form of a ``BlockConfig``. Entries are independent
            objects, so editing one layer does not affect the others.
        """

        def build_block_config():
            return BlockConfig(
                attention=AttentionConfig(
                    no_op=False, num_key_value_heads=config.num_key_value_heads
                ),
                ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
            ).to_dict()

        # Build a fresh dict per layer. `[block_config] * n` would repeat one
        # dict n times, so a later per-layer edit would leak into every layer.
        return [build_block_config() for _ in range(config.num_hidden_layers)]
# diff --git a/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py
# new file mode 100644
# index 0000000000..1ac2bd7072
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/mistral_small/mistral_small_model_descriptor.py
# @@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

import re
from dataclasses import dataclass, field
from typing import Dict, List

from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer,
    MistralForCausalLM,
    MistralRotaryEmbedding,
)

from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
    MatchingZeros,
    Same,
    return_tuple_of_size,
)
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
    FFNIntermediateLayerDescriptor,
)
from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor


@ModelDescriptorFactory.register_decorator("mistral_small")
class MistralSmallModelDescriptor(ModelDescriptor):
    """Model descriptor registered under the ``"mistral_small"`` key."""

    @staticmethod
    def decoder_layer_cls():
        """Return the decoder layer class of the Mistral architecture."""
        return MistralDecoderLayer

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Translate a block config into the config overrides for one layer.

        Returns:
            A dict with the FFN ``intermediate_size`` and the attention
            ``num_key_value_heads`` read from ``block_config``.
        """
        ffn_width = block_config.ffn.intermediate_size
        kv_heads = block_config.attention.num_key_value_heads
        return {"intermediate_size": ffn_width, "num_key_value_heads": kv_heads}

    @staticmethod
    def attn_no_op_post_init(decoder_layer: MistralDecoderLayer):
        """Swap the attention sub-block of ``decoder_layer`` for no-op modules."""
        decoder_layer.input_layernorm = Same()
        # NOTE(review): presumably wraps MatchingZeros so that it returns a
        # 2-tuple like the real attention module -- confirm in puzzformer.no_op.
        zero_attention_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.self_attn = zero_attention_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: MistralDecoderLayer):
        """Swap the MLP sub-block of ``decoder_layer`` for no-op modules."""
        decoder_layer.post_attention_layernorm = Same()
        decoder_layer.mlp = MatchingZeros()

    @staticmethod
    def init_rotary_embedding(model: MistralForCausalLM, runtime):
        """Create a fresh rotary embedding for ``model`` on ``runtime.device``."""
        rotary = MistralRotaryEmbedding(model.config, runtime.device)
        model.model.rotary_emb = rotary

    @staticmethod
    def input_embedding_name():
        """Return the module path of the input token embedding."""
        return "model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        """Return the module path of the output projection."""
        return "lm_head"

    @staticmethod
    def final_norm_name():
        """Return the module path of the final normalization layer."""
        return "model.norm"

    @staticmethod
    def layer_block_name(index: int):
        """Return the module path of decoder layer number ``index``."""
        return f"model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Build the regexes that assign weight names to weight groups.

        Args:
            num_layers: Number of decoder layers to generate groups for.

        Returns:
            A dict ordered as ``embeddings``, ``lm_head``, every
            ``block_{i}_ffn`` and finally every ``block_{i}_attention``.
        """
        predicates: Dict[str, re.Pattern] = {
            "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
        }

        # All FFN groups are inserted before any attention group so that the
        # resulting key order is: ffn for every layer, then attention for every layer.
        for layer_idx in range(num_layers):
            predicates[f"block_{layer_idx}_ffn"] = re.compile(
                rf"^model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
                r"|mlp\.up_proj\.weight"
                r"|mlp\.gate_proj\.weight"
                r"|mlp\.down_proj\.weight)$"
            )

        for layer_idx in range(num_layers):
            predicates[f"block_{layer_idx}_attention"] = re.compile(
                rf"^model\.layers\.{layer_idx}\.(input_layernorm\.weight"
                r"|self_attn\.q_proj\.weight"
                r"|self_attn\.k_proj\.weight"
                r"|self_attn\.v_proj\.weight"
                r"|self_attn\.o_proj\.weight)$"
            )

        return predicates


@dataclass
class MistralFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
    """Layer descriptor for Mistral FFN intermediate pruning."""

    # Down projection, relative to the decoder layer.
    down_proj_name: str = "mlp.down_proj"
    # Template for the FFN module path; ``{layer_idx}`` is left unformatted here.
    ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
    # Names of the FFN linear layers.
    linear_weight_names: List[str] = field(
        default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
    )


@dataclass
class MistralKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """Layer descriptor for Mistral KV-heads pruning."""

    # Output projection, relative to the decoder layer.
    o_proj_name: str = "self_attn.o_proj"
    # Template for the attention module path; ``{layer_idx}`` is left unformatted here.
    attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
    # Names of the attention projection layers.
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
# diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py
# new file mode 100644
# index 0000000000..a2140f118e
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/__init__.py
# @@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_converter import (
    NemotronHConverter,
)
from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor import (
    NemotronHModelDescriptor,
)
# diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py
# new file mode 100644
# index 0000000000..16d9e3c73d
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_converter.py
# @@ -0,0 +1,84 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
    MambaConfig,
    MoEConfig,
)


@ConverterFactory.register_decorator("nemotron_h")
class NemotronHConverter(Converter):
    """Converter registered under the ``"nemotron_h"`` key."""

    @staticmethod
    def create_block_configs_from_main_config(config) -> List[BlockConfig]:
        """Create one block config per character of ``config.hybrid_override_pattern``.

        The pattern looks like
        ``"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"``; each
        character selects the kind of block for that position:

        * ``"M"``: Mamba settings on the attention side, FFN is a no-op.
        * ``"-"``: attention is a no-op, FFN uses ``config.intermediate_size``.
        * ``"*"``: attention uses ``config.num_key_value_heads``, FFN is a no-op.
        * ``"E"``: attention is a no-op, FFN is a mixture of experts.

        Raises:
            ValueError: If the pattern contains any other character.
        """

        def mamba_block() -> BlockConfig:
            # Those parameters are currently used only for calc_block_stats.
            mamba_settings = MambaConfig(
                state_dim=config.ssm_state_size,
                num_heads=config.mamba_num_heads,
                head_dim=config.mamba_head_dim,
                num_groups=config.n_groups,
            )
            return BlockConfig(
                attention=AttentionConfig(mamba=mamba_settings),
                ffn=FFNConfig(no_op=True),
            )

        def mlp_block() -> BlockConfig:
            return BlockConfig(
                attention=AttentionConfig(no_op=True),
                ffn=FFNConfig(intermediate_size=config.intermediate_size),
            )

        def attention_block() -> BlockConfig:
            return BlockConfig(
                attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads),
                ffn=FFNConfig(no_op=True),
            )

        def moe_block() -> BlockConfig:
            moe_settings = MoEConfig(
                num_local_experts=config.n_routed_experts,
                expert_intermediate_dim=config.moe_intermediate_size,
                num_experts_per_tok=config.num_experts_per_tok,
            )
            return BlockConfig(
                attention=AttentionConfig(no_op=True),
                ffn=FFNConfig(moe=moe_settings),
            )

        # Builders are called lazily, so a config attribute is only read when
        # its block kind actually occurs in the pattern.
        builders = {
            "M": mamba_block,
            "-": mlp_block,
            "*": attention_block,
            "E": moe_block,
        }

        pattern = config.hybrid_override_pattern
        print(f"Parsing hybrid pattern: {pattern}")

        block_configs = []
        for i, char in enumerate(pattern):
            build = builders.get(char)
            if build is None:
                raise ValueError(
                    f"Unknown character '{char}' in hybrid_override_pattern at position {i}"
                )
            block_configs.append(build())

        print(f"Created {len(block_configs)} block configs from pattern")
        return block_configs
# diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
# new file mode 100644
# index 0000000000..7687d57c83
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
# @@ -0,0 +1,255 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

import importlib
import inspect
import pkgutil
import re
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Tuple, Type

import torch.nn as nn

from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import (
    ExpertRemovalLayerDescriptor,
    ExpertRemovalPruningMixIn,
)
from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
    """Collect every class named ``module_cls_str`` under ``transformers_modules``.

    Walks all packages and modules below the ``transformers_modules`` package,
    imports each one, and gathers the classes whose ``__name__`` equals
    ``module_cls_str``.

    Args:
        module_cls_str: Exact class name to look for.

    Returns:
        The matching class objects in discovery order; empty when none match.
    """
    # Imported lazily, inside the function, as in the original code.
    # NOTE(review): presumably because the package only exists once remote
    # code has been cached -- confirm.
    import transformers_modules

    matches = []
    for _finder, modname, _ispkg in pkgutil.walk_packages(
        transformers_modules.__path__, transformers_modules.__name__ + "."
    ):
        module = importlib.import_module(modname)
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if obj.__name__ == module_cls_str:
                matches.append(obj)

    return matches


@dataclass
class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
    """Layer descriptor for NemotronH expert-removal pruning."""

    # Module-name suffix selecting the modules to hook.
    target_name: str = "mixer.gate"
    # Template for the MoE module path; ``{layer_idx}`` is left unformatted here.
    moe_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
    # Template for one expert; ``{expert_idx}`` is left unformatted here.
    expert_prefix_name: str = "experts.{expert_idx}"
    router_weights: List[str] = field(default_factory=lambda: ["gate.weight"])
    router_biases: List[str] = field(default_factory=lambda: ["gate.e_score_correction_bias"])
    expert_weights: List[str] = field(
        default_factory=lambda: ["up_proj.weight", "down_proj.weight"]
    )

    def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
        """Return ``(block_index, module_name)`` pairs for the modules to hook.

        For any ``target_name`` other than ``"mixer"`` the parent
        implementation is used. For ``"mixer"``, only modules whose name ends
        with ``"mixer"`` and whose class is named ``NemotronHMOE`` are kept.
        """
        if self.target_name != "mixer":
            return super().get_modules_names_to_hook(model)

        # NemotronH models use auto-map, so the MoE class is matched by its
        # name rather than by its type.
        target_class_name = "NemotronHMOE"

        return [
            (self.block_idx_from_module_name(module_name), module_name)
            for module_name, module in model.named_modules()
            # restrict to attributes called "mixer" and with the desired class name
            if module_name.endswith(self.target_name)
            and module.__class__.__name__ == target_class_name
        ]


@ModelDescriptorFactory.register_decorator("nemotron_h")
class NemotronHModelDescriptor(ModelDescriptor):
    """Model descriptor registered under the ``"nemotron_h"`` key."""

    _DECODER_LAYER_CLS: Type[nn.Module] = None

    @staticmethod
    def decoder_layer_cls():
        """Return all cached classes named ``NemotronHBlock``.

        Raises:
            AssertionError: If no such class is found under
                ``transformers_modules``.
        """
        decoder_cls_list = get_dynamic_modules("NemotronHBlock")
        if not decoder_cls_list:
            raise AssertionError(
                "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`"
            )
        return decoder_cls_list

    @staticmethod
    def requires_trust_remote_code() -> bool:
        """Return ``True``: this descriptor declares that remote code is required."""
        return True

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Translate a block config into the config overrides for one layer.

        Only values that are set (not ``None``) in ``block_config`` produce an
        override, so the returned dict may be empty.
        """
        override_kwargs = {}
        if block_config.ffn.intermediate_size is not None:
            override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size

        if block_config.attention.num_key_value_heads is not None:
            override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads

        if block_config.ffn.moe is not None:
            override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
            override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts

        return override_kwargs

    @staticmethod
    def _block_no_op_post_init(decoder_layer):
        """Turn ``decoder_layer`` into a no-op when both sub-blocks are no-op.

        Because of the sub-block structure of NemotronH, one of the two
        sub-blocks (attention or FFN) is always marked no-op. The block is a
        real no-op only when both attention and FFN are flagged.
        """
        block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx]
        if block_config.ffn.no_op and block_config.attention.no_op:
            decoder_layer.norm = Same()
            decoder_layer.mixer = MatchingZeros()

    @staticmethod
    def attn_no_op_post_init(decoder_layer):
        """Delegate to the shared whole-block no-op handling."""
        NemotronHModelDescriptor._block_no_op_post_init(decoder_layer)

    @staticmethod
    def mlp_no_op_post_init(decoder_layer):
        """Delegate to the shared whole-block no-op handling."""
        NemotronHModelDescriptor._block_no_op_post_init(decoder_layer)

    @classmethod
    def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
        """Create the parent's dummy block and copy NemotronH-specific attributes."""
        dummy_block = super().create_dummy_block(original_layer, block_index)
        # Required by `NemotronHModel.forward`.
        dummy_block.block_type = original_layer.block_type
        # Preserve layer_idx if it exists (used by _block_no_op_post_init)
        if hasattr(original_layer, "layer_idx"):
            dummy_block.layer_idx = original_layer.layer_idx
        # Preserve config if it exists (used by _block_no_op_post_init to access block_configs)
        if hasattr(original_layer, "config"):
            dummy_block.config = original_layer.config
        return dummy_block

    @staticmethod
    def init_rotary_embedding(model, runtime):
        """
        NemotronH has no positional embeddings
        """

    @staticmethod
    def input_embedding_name():
        """Return the module path of the input token embedding."""
        return "backbone.embeddings"

    @staticmethod
    def output_embedding_name():
        """Return the module path of the output projection."""
        return "lm_head"

    @staticmethod
    def final_norm_name():
        """Return the module path of the final normalization layer."""
        return "backbone.norm_f"

    @staticmethod
    def layer_block_name(index: int):
        """Return the module path of layer number ``index``."""
        return f"backbone.layers.{index}"

    @classmethod
    def get_weight_groups(
        cls, layer_names: Iterable[str], num_hidden_layers: int
    ) -> Dict[str, List[str]]:
        """Group weight names by the predicate(s) they match.

        In NemotronH a layer's ``norm.weight`` matches both ``block_{i}_ffn``
        and ``block_{i}_attention``. A group whose only member is such a
        per-layer ``norm.weight`` is a duplicate and is dropped.

        Args:
            layer_names: Weight names to classify.
            num_hidden_layers: Number of layers to build predicates for.

        Returns:
            Mapping from group name to the weight names in that group.

        Raises:
            ValueError: If a name matches no predicate.
        """
        # The predicates depend only on the layer count, so build (and compile)
        # them once instead of once per weight name.
        predicates = cls.layer_name_predicates(num_hidden_layers)

        weight_groups = defaultdict(list)
        for name in layer_names:
            is_matched = False
            # A name can match several groups, so every predicate is checked.
            for group, pattern in predicates.items():
                if pattern.match(name):
                    weight_groups[group].append(name)
                    is_matched = True
            if not is_matched:
                raise ValueError(f"Couldn't find a match for {name}")

        valid_weight_groups = {}
        for group, names in weight_groups.items():
            if len(names) == 1:
                only_name = names[0]
                if only_name.endswith("norm.weight") and "layers" in only_name:
                    # Skip and don't append this group to valid_weight_groups
                    continue
            valid_weight_groups[group] = names

        return valid_weight_groups

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Build the regexes that assign weight names to weight groups.

        Args:
            num_layers: Number of layers to generate groups for.

        Returns:
            A dict ordered as ``embeddings``, ``lm_head``, every
            ``block_{i}_ffn`` and finally every ``block_{i}_attention``.
            The per-layer ``norm.weight`` is matched by both the FFN and the
            attention pattern of its layer.
        """
        layer_name_patterns = {
            "embeddings": re.compile(
                r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$"
            ),
            "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"),
        }

        def build_ffn_predicates() -> Dict[str, re.Pattern]:
            return {
                f"block_{layer_idx}_ffn": re.compile(
                    rf"^backbone\.layers\.{layer_idx}\."
                    r"(norm\.weight|"  # ← INCLUDED IN FFN
                    r"mixer\.(gate\.e_score_correction_bias"
                    r"|gate\.weight"
                    r"|experts\.\d+\.up_proj\.weight"
                    r"|experts\.\d+\.down_proj\.weight"
                    r"|shared_experts\.up_proj\.weight"
                    r"|shared_experts\.down_proj\.weight))$"
                )
                for layer_idx in range(num_layers)
            }

        def build_attention_predicates() -> Dict[str, re.Pattern]:
            return {
                f"block_{layer_idx}_attention": re.compile(
                    rf"^backbone\.layers\.{layer_idx}\."
                    r"(norm\.weight|"  # ← INCLUDED IN ATTENTION
                    r"mixer\.(norm\.weight"
                    r"|A_log"
                    r"|D"
                    r"|conv1d\.weight"
                    r"|conv1d\.bias"
                    r"|dt_bias"
                    r"|in_proj\.weight"
                    r"|out_proj\.weight"
                    r"|q_proj\.weight"
                    r"|k_proj\.weight"
                    r"|v_proj\.weight"
                    r"|o_proj\.weight))$"
                )
                for layer_idx in range(num_layers)
            }

        layer_name_patterns.update(
            **build_ffn_predicates(),
            **build_attention_predicates(),
        )

        return layer_name_patterns

    @staticmethod
    def pruning_mixins() -> Dict[str, PruningMixIn]:
        """Return the pruning mixins available for this model, keyed by name."""
        return {
            "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()),
        }
# diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py
# new file mode 100644
# index 0000000000..4b17785ace
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/__init__.py
# @@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_converter import ( + NemotronHV2Converter, +) +from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor import ( + NemotronHV2ModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py new file mode 100644 index 0000000000..2c54388325 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_converter.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from typing import List

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
    MambaConfig,
    MoEConfig,
)


@ConverterFactory.register_decorator("nemotron_h_v2")
class NemotronHV2Converter(Converter):
    """Converter registered under the ``"nemotron_h_v2"`` key."""

    @staticmethod
    def create_block_configs_from_main_config(config) -> List[BlockConfig]:
        """Create one block config per character of ``config.hybrid_override_pattern``.

        The pattern looks like
        ``"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"``; each
        character selects the kind of block for that position:

        * ``"M"``: Mamba settings on the attention side, FFN is a no-op.
        * ``"-"``: attention is a no-op, FFN uses ``config.intermediate_size``.
        * ``"*"``: attention uses ``config.num_key_value_heads``, FFN is a no-op.
        * ``"E"``: attention is a no-op, FFN is a mixture of experts.

        Raises:
            ValueError: If the pattern contains any other character.
        """

        def make_mamba() -> BlockConfig:
            # Those parameters are currently used only for calc_block_stats.
            mamba_settings = MambaConfig(
                state_dim=config.ssm_state_size,
                num_heads=config.mamba_num_heads,
                head_dim=config.mamba_head_dim,
                num_groups=config.n_groups,
            )
            return BlockConfig(
                attention=AttentionConfig(mamba=mamba_settings),
                ffn=FFNConfig(no_op=True),
            )

        def make_mlp() -> BlockConfig:
            return BlockConfig(
                attention=AttentionConfig(no_op=True),
                ffn=FFNConfig(intermediate_size=config.intermediate_size),
            )

        def make_attention() -> BlockConfig:
            return BlockConfig(
                attention=AttentionConfig(num_key_value_heads=config.num_key_value_heads),
                ffn=FFNConfig(no_op=True),
            )

        def make_moe() -> BlockConfig:
            moe_settings = MoEConfig(
                num_local_experts=config.n_routed_experts,
                expert_intermediate_dim=config.moe_intermediate_size,
                num_experts_per_tok=config.num_experts_per_tok,
            )
            return BlockConfig(
                attention=AttentionConfig(no_op=True),
                ffn=FFNConfig(moe=moe_settings),
            )

        # Factories run lazily, so a config attribute is only read when its
        # block kind actually occurs in the pattern.
        factories = {
            "M": make_mamba,
            "-": make_mlp,
            "*": make_attention,
            "E": make_moe,
        }

        pattern = config.hybrid_override_pattern
        print(f"Parsing hybrid pattern: {pattern}")

        block_configs = []
        for i, char in enumerate(pattern):
            factory = factories.get(char)
            if factory is None:
                raise ValueError(
                    f"Unknown character '{char}' in hybrid_override_pattern at position {i}"
                )
            block_configs.append(factory())

        print(f"Created {len(block_configs)} block configs from pattern")
        return block_configs
# diff --git
a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py new file mode 100644 index 0000000000..c8c89658bf --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import importlib
import inspect
import pkgutil
import re
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Type

import torch.nn as nn

from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import MatchingZeros, Same
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
    FFNIntermediateLayerDescriptor,
    FFNIntermediatePruningMixIn,
)
from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
    """Collect every class named ``module_cls_str`` under ``transformers_modules``.

    Each module found by walking the ``transformers_modules`` package is imported
    (so importing has whatever side effects those modules have) and scanned for
    classes whose ``__name__`` equals ``module_cls_str``.

    Args:
        module_cls_str: Exact class name to look for.

    Returns:
        All matching class objects; empty if nothing matched.
    """
    # Imported lazily: this function is the only user of the package in this file.
    import transformers_modules

    matches = []
    for _finder, modname, _ispkg in pkgutil.walk_packages(
        transformers_modules.__path__, transformers_modules.__name__ + "."
    ):
        module = importlib.import_module(modname)
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if obj.__name__ == module_cls_str:
                matches.append(obj)

    return matches


@dataclass
class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
    """Names locating the FFN (``mixer``) weights of a NemotronH block."""

    down_proj_name: str = "mixer.down_proj"
    ffn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
    linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"])


@ModelDescriptorFactory.register_decorator("nemotron_h_v2")
class NemotronHV2ModelDescriptor(ModelDescriptor):
    """Model descriptor registered under the ``"nemotron_h_v2"`` key."""

    _DECODER_LAYER_CLS: Type[nn.Module] = None

    @staticmethod
    def decoder_layer_cls():
        """Return the list of dynamically loaded ``NemotronHBlock`` classes.

        Raises:
            AssertionError: If no ``NemotronHBlock`` class was found among the
                cached dynamic modules.
        """
        decoder_cls_list = get_dynamic_modules("NemotronHBlock")
        if not decoder_cls_list:
            raise AssertionError(
                "NemotronH contains dynamic modules that should be cached beforehand, make sure to load your config using `load_model_config` or manually call `force_cache_dynamic_modules(config, checkpoint_dir)`"
            )
        return decoder_cls_list

    @staticmethod
    def requires_trust_remote_code() -> bool:
        return True

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Translate a ``BlockConfig`` into per-layer config override kwargs.

        Only values that are present (not ``None``) on the block config are
        emitted, so the returned dict may be empty.
        """
        override_kwargs = {}
        if block_config.ffn is not None and block_config.ffn.intermediate_size is not None:
            override_kwargs["intermediate_size"] = block_config.ffn.intermediate_size

        if (
            block_config.attention is not None
            and block_config.attention.num_key_value_heads is not None
        ):
            override_kwargs["num_key_value_heads"] = block_config.attention.num_key_value_heads

        if block_config.ffn is not None and block_config.ffn.moe is not None:
            override_kwargs["moe_intermediate_size"] = block_config.ffn.moe.expert_intermediate_dim
            override_kwargs["n_routed_experts"] = block_config.ffn.moe.num_local_experts

        return override_kwargs

    @staticmethod
    def _block_no_op_post_init(decoder_layer):
        """Turn the layer into a no-op only when both sub-blocks are no-op.

        Due to the sub-block structure of NemotronH, one of the sub-blocks is
        always set to no-op; for a real no-op, both the attention and the FFN
        no-op flags must be True.
        """
        block_config = decoder_layer.config.block_configs[decoder_layer.layer_idx]
        ffn_no_op = block_config.ffn is not None and block_config.ffn.no_op
        attn_no_op = block_config.attention is not None and block_config.attention.no_op
        if ffn_no_op and attn_no_op:
            decoder_layer.norm = Same()
            decoder_layer.mixer = MatchingZeros()

    @staticmethod
    def attn_no_op_post_init(decoder_layer):
        NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer)

    @staticmethod
    def mlp_no_op_post_init(decoder_layer):
        NemotronHV2ModelDescriptor._block_no_op_post_init(decoder_layer)

    @classmethod
    def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
        """Create the base dummy block and copy over NemotronH-specific attributes."""
        dummy_block = super().create_dummy_block(original_layer, block_index)
        # Required by `NemotronHModel.forward`.
        dummy_block.block_type = original_layer.block_type
        # Preserve layer_idx if it exists (used by _block_no_op_post_init)
        if hasattr(original_layer, "layer_idx"):
            dummy_block.layer_idx = original_layer.layer_idx
        # Preserve config if it exists (used by _block_no_op_post_init to access block_configs)
        if hasattr(original_layer, "config"):
            dummy_block.config = original_layer.config
        return dummy_block

    @staticmethod
    def init_rotary_embedding(model, runtime):
        """
        NemotronH has no positional embeddings
        """

    @staticmethod
    def input_embedding_name():
        return "backbone.embeddings"

    @staticmethod
    def output_embedding_name():
        return "lm_head"

    @staticmethod
    def final_norm_name():
        return "backbone.norm_f"

    @staticmethod
    def layer_block_name(index: int):
        return f"backbone.layers.{index}"

    @classmethod
    def get_weight_groups(
        cls, layer_names: Iterable[str], num_hidden_layers: int
    ) -> Dict[str, List[str]]:
        """Group weight names by the predicates of ``layer_name_predicates``.

        The problem with NemotronH is that ``norm.weight`` matches both
        ``block_{i}_ffn`` and ``block_{i}_attention``. A group that ends up
        containing only that layer-level ``norm.weight`` is a duplicate and is
        dropped from the result.

        Args:
            layer_names: Weight names to classify.
            num_hidden_layers: Number of layers to build predicates for.

        Returns:
            Mapping of group name to the weight names it matched.

        Raises:
            ValueError: If a name matches none of the predicates.
        """
        # The predicates do not depend on the name being classified; build (and
        # compile) them once instead of once per weight name.
        predicates = cls.layer_name_predicates(num_hidden_layers)

        weight_groups = defaultdict(list)
        for name in layer_names:
            is_matched = False
            for group, pattern in predicates.items():
                if pattern.match(name):
                    weight_groups[group].append(name)
                    is_matched = True
            if not is_matched:
                raise ValueError(f"Couldn't find a match for {name}")

        valid_weight_groups = {}
        for group, names in weight_groups.items():
            if len(names) == 1:
                only_name = names[0]
                if only_name.endswith("norm.weight") and "layers" in only_name:
                    # Skip and don't append this group to valid_weight_groups
                    continue
            valid_weight_groups[group] = names

        return valid_weight_groups

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Return compiled regexes mapping weight names to weight groups.

        Key order is: ``embeddings``, ``lm_head``, every ``block_{i}_ffn``, then
        every ``block_{i}_attention``.
        """
        layer_name_patterns = {
            "embeddings": re.compile(
                r"^(model\.embed_tokens\.weight|backbone\.embeddings\.weight)$"
            ),
            "lm_head": re.compile(r"^(lm_head\.weight|backbone\.norm_f\.weight)$"),
        }

        def build_ffn_predicates() -> Dict[str, re.Pattern]:
            return {
                f"block_{layer_idx}_ffn": re.compile(
                    rf"^backbone\.layers\.{layer_idx}\."
                    r"(norm\.weight|"  # layer-level norm is included in the FFN group
                    r"mixer\.(gate\.e_score_correction_bias"
                    r"|gate\.weight"
                    r"|experts\.\d+\.up_proj\.weight"
                    r"|experts\.\d+\.down_proj\.weight"
                    r"|shared_experts\.up_proj\.weight"
                    r"|shared_experts\.down_proj\.weight"
                    r"|up_proj\.weight"  # Simple MLP (non-MoE)
                    r"|down_proj\.weight))$"  # Simple MLP (non-MoE)
                )
                for layer_idx in range(num_layers)
            }

        def build_attention_predicates() -> Dict[str, re.Pattern]:
            return {
                f"block_{layer_idx}_attention": re.compile(
                    rf"^backbone\.layers\.{layer_idx}\."
                    r"(norm\.weight|"  # layer-level norm is included in the attention group too
                    r"mixer\.(norm\.weight"
                    r"|A_log"
                    r"|D"
                    r"|conv1d\.weight"
                    r"|conv1d\.bias"
                    r"|dt_bias"
                    r"|in_proj\.weight"
                    r"|out_proj\.weight"
                    r"|q_proj\.weight"
                    r"|k_proj\.weight"
                    r"|v_proj\.weight"
                    r"|o_proj\.weight))$"
                )
                for layer_idx in range(num_layers)
            }

        layer_name_patterns.update(
            **build_ffn_predicates(),
            **build_attention_predicates(),
        )

        return layer_name_patterns

    @staticmethod
    def pruning_mixins() -> Dict[str, PruningMixIn]:
        return {
            "ffn_intermediate": FFNIntermediatePruningMixIn(
                NemotronHV2FFNIntermediateLayerDescriptor()
            ),
            # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated
        }


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py
# new file mode 100644
# index 0000000000..c193fc0d6d
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/__init__.py
# @@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_converter import Qwen2Converter
from modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor import (
    Qwen2ModelDescriptor,
)

# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py
# new file mode 100644
# index 0000000000..878cfd64dc
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_converter.py
# @@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

"""Qwen2 converter for AnyModel compression."""

from typing import List

from transformers import Qwen2Config

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
)


@ConverterFactory.register_decorator("qwen2")
class Qwen2Converter(Converter):
    """Converter for Qwen2 models to AnyModel format."""

    @staticmethod
    def create_block_configs_from_main_config(config: Qwen2Config) -> List[BlockConfig]:
        """Create uniform block configs for all Qwen2 layers.

        Qwen2 models have uniform architecture across all layers, so every layer
        gets a block config with the same values.

        Args:
            config: Main model config; ``num_hidden_layers``,
                ``num_key_value_heads`` and ``intermediate_size`` are read.

        Returns:
            One ``BlockConfig(...).to_dict()`` result per hidden layer. Each entry
            is built separately so that editing one layer's config in place
            cannot silently change the others.
        """
        # NOTE: do not use `[block_config] * n` here; that repeats a reference to
        # a single dict, so all layers would alias the same object.
        return [
            BlockConfig(
                attention=AttentionConfig(
                    no_op=False, num_key_value_heads=config.num_key_value_heads
                ),
                ffn=FFNConfig(no_op=False, intermediate_size=config.intermediate_size),
            ).to_dict()
            for _ in range(config.num_hidden_layers)
        ]


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py
# new file mode 100644
# index 0000000000..c2bbeed7a9
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen2/qwen2_model_descriptor.py
# @@ -0,0 +1,146 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

"""Qwen2 model descriptor for AnyModel compression."""

import re
from dataclasses import dataclass
from typing import Dict

from torch import nn
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2ForCausalLM,
    Qwen2RotaryEmbedding,
)

from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
    LlamaFFNIntermediateLayerDescriptor,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
    MatchingZeros,
    Same,
    return_tuple_of_size,
)
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock


@ModelDescriptorFactory.register_decorator("qwen2")
class Qwen2ModelDescriptor(ModelDescriptor):
    """Descriptor registered under the ``"qwen2"`` key."""

    @staticmethod
    def decoder_layer_cls():
        return Qwen2DecoderLayer

    @classmethod
    def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
        """Build a ``DummyBlock`` that carries over ``attention_type`` when present.

        Qwen2's forward pass accesses decoder_layer.attention_type for attention
        mask selection, so the attribute is copied from the layer being replaced.
        """
        placeholder = DummyBlock(block_index=block_index)
        if hasattr(original_layer, "attention_type"):
            placeholder.attention_type = original_layer.attention_type
        return placeholder

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Return the FFN width and KV-head count of ``block_config`` as overrides."""
        ffn_width = block_config.ffn.intermediate_size
        kv_heads = block_config.attention.num_key_value_heads
        return {"intermediate_size": ffn_width, "num_key_value_heads": kv_heads}

    @staticmethod
    def attn_no_op_post_init(decoder_layer: Qwen2DecoderLayer):
        """Replace the attention sub-block and its pre-norm with no-op modules."""
        decoder_layer.input_layernorm = Same()
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.self_attn = zeros_pair_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: Qwen2DecoderLayer):
        """Replace the MLP sub-block and its pre-norm with no-op modules."""
        decoder_layer.post_attention_layernorm = Same()
        decoder_layer.mlp = MatchingZeros()

    @staticmethod
    def init_rotary_embedding(model: Qwen2ForCausalLM, runtime):
        """Re-create ``model.model.rotary_emb`` on ``runtime.device``."""
        rotary = Qwen2RotaryEmbedding(config=model.config, device=runtime.device)
        model.model.rotary_emb = rotary

    @staticmethod
    def input_embedding_name():
        return "model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        return "lm_head"

    @staticmethod
    def final_norm_name():
        return "model.norm"

    @staticmethod
    def layer_block_name(index: int):
        return f"model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Return compiled regexes mapping weight names to weight groups.

        Key order is: ``embeddings``, ``lm_head``, every ``block_{i}_ffn``, then
        every ``block_{i}_attention``.
        """
        ffn_alternatives = "|".join(
            (
                r"post_attention_layernorm\.weight",
                r"mlp\.up_proj\.weight",
                r"mlp\.gate_proj\.weight",
                r"mlp\.down_proj\.weight",
            )
        )
        # Qwen2 has biases on attention projections
        attention_alternatives = "|".join(
            (
                r"input_layernorm\.weight",
                r"self_attn\.q_proj\.weight",
                r"self_attn\.q_proj\.bias",
                r"self_attn\.k_proj\.weight",
                r"self_attn\.k_proj\.bias",
                r"self_attn\.v_proj\.weight",
                r"self_attn\.v_proj\.bias",
                r"self_attn\.o_proj\.weight",
            )
        )

        predicates = {
            "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
        }
        # All FFN groups are inserted before any attention group.
        for idx in range(num_layers):
            predicates[f"block_{idx}_ffn"] = re.compile(
                rf"^model\.layers\.{idx}\.({ffn_alternatives})$"
            )
        for idx in range(num_layers):
            predicates[f"block_{idx}_attention"] = re.compile(
                rf"^model\.layers\.{idx}\.({attention_alternatives})$"
            )
        return predicates


@dataclass
class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor):
    """Layer descriptor for Qwen2 FFN intermediate pruning.

    Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj).
    """


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py
# new file mode 100644
# index 0000000000..cf28475718
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/__init__.py
# @@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_converter import Qwen3Converter +from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor import ( + Qwen3ModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py new file mode 100644 index 0000000000..bad9bb47d6 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_converter.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# mypy: ignore-errors

from typing import List

from transformers import Qwen3Config

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
)


@ConverterFactory.register_decorator("qwen3")
class Qwen3Converter(Converter):
    """Converter registered under the ``"qwen3"`` key."""

    @staticmethod
    def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]:
        """Build one block-config dict per hidden layer from the main config.

        Every layer receives the same ``num_key_value_heads`` and
        ``intermediate_size``; each entry is a separately created
        ``BlockConfig(...).to_dict()`` result.
        """
        per_layer = []
        for _ in range(config.num_hidden_layers):
            attention_cfg = AttentionConfig(
                no_op=False, num_key_value_heads=config.num_key_value_heads
            )
            ffn_cfg = FFNConfig(no_op=False, intermediate_size=config.intermediate_size)
            per_layer.append(BlockConfig(attention=attention_cfg, ffn=ffn_cfg).to_dict())
        return per_layer


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py
# new file mode 100644
# index 0000000000..ae70d96617
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3/qwen3_model_descriptor.py
# @@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

import re
from dataclasses import dataclass, field
from typing import Dict, List

from torch import nn
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3DecoderLayer,
    Qwen3ForCausalLM,
    Qwen3RotaryEmbedding,
)

from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import (
    MatchingZeros,
    Same,
    return_tuple_of_size,
)
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig
from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
    FFNIntermediateLayerDescriptor,
)
from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock


@ModelDescriptorFactory.register_decorator("qwen3")
class Qwen3ModelDescriptor(ModelDescriptor):
    """Descriptor registered under the ``"qwen3"`` key."""

    @staticmethod
    def decoder_layer_cls():
        return Qwen3DecoderLayer

    @classmethod
    def create_dummy_block(cls, original_layer: nn.Module, block_index: int) -> nn.Module:
        """Build a ``DummyBlock`` that carries over ``attention_type`` when present.

        Qwen3's forward pass accesses decoder_layer.attention_type for attention
        mask selection, so the attribute is copied from the layer being replaced.
        """
        placeholder = DummyBlock(block_index=block_index)
        if hasattr(original_layer, "attention_type"):
            placeholder.attention_type = original_layer.attention_type
        return placeholder

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Return the FFN width and KV-head count of ``block_config`` as overrides."""
        ffn_width = block_config.ffn.intermediate_size
        kv_heads = block_config.attention.num_key_value_heads
        return {"intermediate_size": ffn_width, "num_key_value_heads": kv_heads}

    @staticmethod
    def attn_no_op_post_init(decoder_layer: Qwen3DecoderLayer):
        """Replace the attention sub-block and its pre-norm with no-op modules."""
        decoder_layer.input_layernorm = Same()
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.self_attn = zeros_pair_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: Qwen3DecoderLayer):
        """Replace the MLP sub-block and its pre-norm with no-op modules."""
        decoder_layer.post_attention_layernorm = Same()
        decoder_layer.mlp = MatchingZeros()

    @staticmethod
    def init_rotary_embedding(model: Qwen3ForCausalLM, runtime):
        """Re-create ``model.model.rotary_emb`` from the model config and runtime device."""
        rotary = Qwen3RotaryEmbedding(model.config, runtime.device)
        model.model.rotary_emb = rotary

    @staticmethod
    def input_embedding_name():
        return "model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        return "lm_head"

    @staticmethod
    def final_norm_name():
        return "model.norm"

    @staticmethod
    def layer_block_name(index: int):
        return f"model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Return compiled regexes mapping weight names to weight groups.

        Key order is: ``embeddings``, ``lm_head``, every ``block_{i}_ffn``, then
        every ``block_{i}_attention``.
        """
        ffn_alternatives = "|".join(
            (
                r"post_attention_layernorm\.weight",
                r"mlp\.up_proj\.weight",
                r"mlp\.gate_proj\.weight",
                r"mlp\.down_proj\.weight",
            )
        )
        attention_alternatives = "|".join(
            (
                r"input_layernorm\.weight",
                r"self_attn\.q_proj\.weight",
                r"self_attn\.k_proj\.weight",
                r"self_attn\.v_proj\.weight",
                r"self_attn\.o_proj\.weight",
                r"self_attn\.q_norm\.weight",
                r"self_attn\.k_norm\.weight",
            )
        )

        predicates = {
            "embeddings": re.compile(r"^model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.norm\.weight|lm_head\.weight)$"),
        }
        # All FFN groups are inserted before any attention group.
        for idx in range(num_layers):
            predicates[f"block_{idx}_ffn"] = re.compile(
                rf"^model\.layers\.{idx}\.({ffn_alternatives})$"
            )
        for idx in range(num_layers):
            predicates[f"block_{idx}_attention"] = re.compile(
                rf"^model\.layers\.{idx}\.({attention_alternatives})$"
            )
        return predicates


@dataclass
class Qwen3FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
    """Names locating the ``mlp`` weights of a Qwen3 decoder layer."""

    down_proj_name: str = "mlp.down_proj"
    ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
    linear_weight_names: List[str] = field(
        default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
    )


@dataclass
class Qwen3KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """Names locating the ``self_attn`` projection weights of a Qwen3 decoder layer."""

    o_proj_name: str = "self_attn.o_proj"
    attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py
# new file mode 100644
# index 0000000000..48dbd2de8a
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/__init__.py
# @@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_converter import Qwen3VLConverter +from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor import ( + Qwen3VLModelDescriptor, +) diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py new file mode 100644 index 0000000000..82e51b7b80 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_converter.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# mypy: ignore-errors

from typing import List

from transformers import Qwen3VLMoeConfig

from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
    MoEConfig,
)


@ConverterFactory.register_decorator("qwen3_vl")
class Qwen3VLConverter(Converter):
    """Converter registered under the ``"qwen3_vl"`` key."""

    @staticmethod
    def create_block_configs_from_main_config(config: Qwen3VLMoeConfig) -> List[BlockConfig]:
        """Build one ``BlockConfig`` per text-decoder layer.

        A layer is described as MoE when it is not listed in
        ``mlp_only_layers``, the config has at least one expert, and
        ``(layer_idx + 1) % decoder_sparse_step == 0``. This is intended to
        mirror the rule the Hugging Face decoder layer uses to choose between
        its sparse-MoE and dense MLP sub-modules; all other layers are
        described as dense.

        Args:
            config: Main model config. When it has a ``text_config`` attribute
                the language-model parameters are read from there, otherwise
                from ``config`` itself.

        Returns:
            ``BlockConfig`` objects, one per hidden layer, in layer order.
        """
        # Qwen3-VL MoE has nested text_config
        text_config = config.text_config if hasattr(config, "text_config") else config

        num_hidden_layers = text_config.num_hidden_layers
        decoder_sparse_step = getattr(text_config, "decoder_sparse_step", 1)
        # `or []` also covers an explicit `mlp_only_layers=None`, which would
        # otherwise make the membership test below raise TypeError.
        mlp_only_layers = getattr(text_config, "mlp_only_layers", None) or []
        num_experts = getattr(text_config, "num_experts", 0)

        block_configs = []
        for layer_idx in range(num_hidden_layers):
            # Sparse layers are counted 1-based: with decoder_sparse_step == 2 the
            # MoE layers are indices 1, 3, 5, ... (not 0, 2, 4, ...).
            is_moe_layer = (
                layer_idx not in mlp_only_layers
                and num_experts > 0
                and (layer_idx + 1) % decoder_sparse_step == 0
            )

            attention = AttentionConfig(
                no_op=False, num_key_value_heads=text_config.num_key_value_heads
            )
            if is_moe_layer:
                # MoE layer
                ffn = FFNConfig(
                    moe=MoEConfig(
                        num_local_experts=num_experts,
                        expert_intermediate_dim=text_config.moe_intermediate_size,
                        num_experts_per_tok=text_config.num_experts_per_tok,
                    )
                )
            else:
                # Dense layer
                ffn = FFNConfig(no_op=False, intermediate_size=text_config.intermediate_size)

            block_configs.append(BlockConfig(attention=attention, ffn=ffn))

        print(
            f"Created {len(block_configs)} block configs for Qwen3-VL MoE (decoder_sparse_step={decoder_sparse_step})"
        )
        return block_configs


# diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
@ModelDescriptorFactory.register_decorator("qwen3_vl")
class Qwen3VLModelDescriptor(ModelDescriptor):
    """AnyModel descriptor for Qwen3-VL MoE checkpoints.

    All names and patterns below address the text decoder under
    ``model.language_model`` and the vision tower under ``model.visual``.
    """

    @staticmethod
    def uses_autocast() -> bool:
        """Return ``False``: run in native bfloat16 rather than under autocast.

        Qwen3-VL MoE has a dtype bug in HuggingFace transformers under torch.autocast:
        scatter() in MoE routing fails with dtype mismatch.
        See: https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct (recommended approach)
        """
        return False

    @staticmethod
    def get_language_model_config(config):
        """Return the nested ``text_config`` when present, otherwise ``config`` itself."""
        return getattr(config, "text_config", config)

    @staticmethod
    def decoder_layer_cls():
        """Return the decoder layer class this descriptor targets."""
        return Qwen3VLMoeTextDecoderLayer

    @staticmethod
    def block_config_to_layer_overrides(block_config: BlockConfig):
        """Translate one ``BlockConfig`` into per-layer config override kwargs."""
        overrides = {"num_key_value_heads": block_config.attention.num_key_value_heads}

        moe_config = block_config.ffn.moe
        if not moe_config:
            # Dense layer: only the MLP width is overridden.
            overrides["intermediate_size"] = block_config.ffn.intermediate_size
            return overrides

        overrides["moe_intermediate_size"] = moe_config.expert_intermediate_dim
        overrides["num_experts"] = moe_config.num_local_experts
        return overrides

    @staticmethod
    def attn_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer):
        """Turn the attention sub-block of ``decoder_layer`` into a no-op."""
        # The attention replacement returns a 2-tuple whose first element is zeros.
        zeros_pair_cls = return_tuple_of_size(MatchingZeros, size=2)
        decoder_layer.input_layernorm = Same()
        decoder_layer.self_attn = zeros_pair_cls()

    @staticmethod
    def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer):
        """Turn the MLP sub-block of ``decoder_layer`` into a no-op."""
        decoder_layer.post_attention_layernorm = Same()
        decoder_layer.mlp = MatchingZeros()

    @staticmethod
    def init_rotary_embedding(model, runtime):
        """Rebuild text and vision rotary embeddings on ``runtime.device`` / ``runtime.dtype``."""
        placement = {"device": runtime.device, "dtype": runtime.dtype}

        # Text rotary embedding is built from the (possibly nested) language model config.
        language_config = Qwen3VLModelDescriptor.get_language_model_config(model.config)
        text_rotary = Qwen3VLMoeTextRotaryEmbedding(config=language_config)
        model.model.language_model.rotary_emb = text_rotary.to(**placement)

        # Vision rotary embedding is only rebuilt when the config carries a vision_config.
        vision_config = getattr(model.config, "vision_config", None)
        if vision_config is None:
            return
        head_dim = vision_config.hidden_size // vision_config.num_heads
        vision_rotary = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2)
        model.model.visual.rotary_pos_emb = vision_rotary.to(**placement)

    @staticmethod
    def input_embedding_name():
        """Return the module path of the input token embedding."""
        return "model.language_model.embed_tokens"

    @staticmethod
    def output_embedding_name():
        """Return the module path of the output projection."""
        return "lm_head"

    @staticmethod
    def final_norm_name():
        """Return the module path of the final normalization layer."""
        return "model.language_model.norm"

    @staticmethod
    def layer_block_name(index: int):
        """Return the module path of decoder layer ``index``."""
        return f"model.language_model.layers.{index}"

    @staticmethod
    def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
        """Map group names to regexes over weight names.

        Keys are inserted in this order: the three global groups, then
        ``block_{i}_ffn`` for every layer, then ``block_{i}_attention`` for every layer.
        """

        def ffn_pattern(layer_idx: int) -> re.Pattern:
            return re.compile(
                rf"^model\.language_model\.layers\.{layer_idx}\.(post_attention_layernorm\.weight"
                # MoE router
                r"|mlp\.gate\.weight"
                # MoE experts - fused format (gate_up_proj, down_proj without .weight suffix)
                r"|mlp\.experts\.gate_up_proj"
                r"|mlp\.experts\.down_proj"
                # Shared expert (if present)
                r"|mlp\.shared_expert\.up_proj\.weight"
                r"|mlp\.shared_expert\.gate_proj\.weight"
                r"|mlp\.shared_expert\.down_proj\.weight"
                r"|mlp\.shared_expert_gate\.weight"
                # Dense MLP fallback (for non-MoE layers)
                r"|mlp\.up_proj\.weight"
                r"|mlp\.gate_proj\.weight"
                r"|mlp\.down_proj\.weight)$"
            )

        def attention_pattern(layer_idx: int) -> re.Pattern:
            return re.compile(
                rf"^model\.language_model\.layers\.{layer_idx}\.(input_layernorm\.weight"
                r"|self_attn\.q_proj\.weight"
                r"|self_attn\.k_proj\.weight"
                r"|self_attn\.v_proj\.weight"
                r"|self_attn\.o_proj\.weight"
                r"|self_attn\.q_norm\.weight"
                r"|self_attn\.k_norm\.weight)$"
            )

        # Qwen3-VL keeps the text model under the model.language_model.* prefix.
        predicates: Dict[str, re.Pattern] = {
            "embeddings": re.compile(r"^model\.language_model\.embed_tokens\.weight$"),
            "lm_head": re.compile(r"^(model\.language_model\.norm\.weight|lm_head\.weight)$"),
            # Vision encoder (includes merger under model.visual.deepstack_merger_list.*)
            "vision_encoding": re.compile(r"^model\.visual\..*"),
        }
        for layer_idx in range(num_layers):
            predicates[f"block_{layer_idx}_ffn"] = ffn_pattern(layer_idx)
        for layer_idx in range(num_layers):
            predicates[f"block_{layer_idx}_attention"] = attention_pattern(layer_idx)
        return predicates


@dataclass
class Qwen3VLFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
    """Names used for FFN intermediate-size pruning in the Qwen3-VL text decoder."""

    down_proj_name: str = "mlp.down_proj"
    # ``{layer_idx}`` is a placeholder left for the consumer to format.
    ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
    linear_weight_names: List[str] = field(
        default_factory=lambda: ["down_proj", "gate_proj", "up_proj"]
    )


@dataclass
class Qwen3VLKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """Names used for KV-heads pruning in the Qwen3-VL text decoder."""

    o_proj_name: str = "self_attn.o_proj"
    # ``{layer_idx}`` is a placeholder left for the consumer to format.
    attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn"
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )


@dataclass
class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
    """Names used for expert-removal pruning of Qwen3-VL MoE layers.

    Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
    - Qwen3VLMoeTextSparseMoeBlock: MoE block with .gate (router) and .experts
    - Qwen3VLMoeTextTopKRouter: Router with .weight (no bias)
    - Qwen3VLMoeTextExperts: Fused experts with .gate_up_proj and .down_proj tensors
    """

    target_name: str = "mlp"
    moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
    # The router (Qwen3VLMoeTextTopKRouter) exposes self.weight and has no bias.
    router_weights: List[str] = field(default_factory=lambda: ["gate.weight"])
    router_biases: List[str] = field(default_factory=list)
    # Qwen3VLMoeTextExperts keeps all experts in single tensors shaped
    # [num_experts, ...] rather than one tensor per expert.
    is_fused_experts: bool = True
    fused_expert_weights: List[str] = field(
        default_factory=lambda: ["experts.gate_up_proj", "experts.down_proj"]
    )
@cache
def return_tuple_of_size(cls: type[nn.Module], size: int) -> type[nn.Module]:
    """Create a wrapper class that returns a tuple of the given size.

    Useful for replacing modules that return multiple outputs (e.g., attention layers
    that return (hidden_states, attn_weights)). The wrapped ``forward`` puts the base
    module's result (or its first element, if the result is not a tensor) in slot 0
    and fills the remaining slots with ``None``.

    Results are memoized, so the same ``(cls, size)`` arguments yield the same class.

    Args:
        cls: The base module class to wrap.
        size: The size of the tuple to return. Must be at least 1.

    Returns:
        A new class that wraps the base class and returns a tuple of the given size.

    Raises:
        ValueError: If ``size`` is smaller than 1 (there would be no slot for the output).

    Example:
        >>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
    """
    # Fail at class-creation time instead of with an IndexError inside forward().
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")

    class Wrapped(cls):
        def forward(self, *args, **kwargs):
            result = super().forward(*args, **kwargs)
            first = result if isinstance(result, torch.Tensor) else result[0]
            return (first,) + (None,) * (size - 1)

        def extra_repr(self) -> str:
            return f"[{cls.__name__}]"

    return Wrapped


class MatchingZeros(nn.Module):
    """Module that returns zeros matching the input shape.

    Used to replace MLP or attention layers with no-ops. Returns zeros because
    the hidden_states are added to the residuals, so a no-op implementation
    should leave the residual unchanged.
    """

    def forward(self, hidden_states, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return torch.zeros_like(hidden_states)


class Same(nn.Module):
    """Module that returns the input unchanged.

    Used to replace normalization layers with identity operations.
    """

    def forward(self, hidden_states, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        return hidden_states

    @property
    def weight(self):
        """Return an empty tensor so that ``.weight.dtype`` style access keeps working.

        Supports NemotronH with scoring_activations, where the code reads
        ``self.lm_head.weight.dtype``.
        """
        return torch.empty(0)
def _get_variable_from_stack(names: list[str]) -> Any:
    """Search the call stack for a variable with one of the given names.

    Frames are visited from the immediate caller outwards; within a frame the
    names are tried in the order given, and the first match is returned.

    Raises:
        RuntimeError: If no frame on the stack has a local with any of the names.
    """
    frame = inspect.currentframe().f_back
    try:
        while frame:
            frame_locals = frame.f_locals
            for name in names:
                if name in frame_locals:
                    return frame_locals[name]
            frame = frame.f_back
    finally:
        # Drop the frame reference so it cannot keep a reference cycle alive.
        del frame
    raise RuntimeError(f"{names} not found in caller stack")


@contextmanager
def deci_x_patcher(
    model_descriptor: ModelDescriptor,
    block_configs: List[BlockConfig | dict] | None = None,
):
    """Context manager that patches decoder layer __init__ for heterogeneous per-layer configs.

    This is the core mechanism that enables AnyModel to work with any HuggingFace model.
    It patches the decoder layer class(es) to read per-layer block_configs and apply
    layer-specific overrides (e.g., different intermediate_size per layer). The layer
    index is discovered by looking for a ``layer_idx`` or ``idx`` local on the call stack.
    The original ``__init__`` of every patched class is restored on exit.

    Args:
        model_descriptor: The model descriptor that defines which classes to patch
            and how to map block_configs to layer overrides.
        block_configs: Optional list of BlockConfig (one per layer). If not provided,
            will try to read from config.block_configs during model initialization.

    Raises:
        NotImplementedError: During layer construction, if a block requests an
            attention/MLP no-op that the descriptor does not support.

    Example:
        >>> with deci_x_patcher(LlamaModelDescriptor, block_configs):
        ...     model = AutoModelForCausalLM.from_config(config)
    """
    decoder_layer_classes = model_descriptor.decoder_layer_cls()  # A class or a list of classes
    if not isinstance(decoder_layer_classes, list):
        decoder_layer_classes = [decoder_layer_classes]

    # Capture every original __init__ before patching anything, so a listed class that
    # inherits __init__ from another listed class still records the unpatched function.
    orig_inits = [cls.__init__ for cls in decoder_layer_classes]

    block_configs = maybe_cast_block_configs(block_configs)

    # Descriptors are typically passed as classes (see the example above); report the
    # class's own name in that case rather than the name of its metaclass.
    descriptor_name = getattr(model_descriptor, "__name__", type(model_descriptor).__name__)

    def _make_patched_init(orig_init):
        """Build a patched __init__ bound to one class's original __init__."""

        # Closing over ``orig_init`` (instead of looking up ``self.__class__`` in the
        # class list) keeps this working for subclasses of a patched decoder layer.
        @wraps(orig_init)
        def _patched_decoder_layer_init(self, config, *args, **kwargs):
            _block_configs = block_configs or getattr(config, "block_configs", None)
            if _block_configs is None:
                return orig_init(self, config, *args, **kwargs)

            _block_configs = maybe_cast_block_configs(_block_configs)
            layer_idx = _get_variable_from_stack(["layer_idx", "idx"])
            _block_config = _block_configs[layer_idx]
            override_block_config = model_descriptor.block_config_to_layer_overrides(
                _block_config
            )
            _config = override_config_with_block_configs(config, override_block_config)
            orig_init(self, _config, *args, **kwargs)

            # Apply no-op post-init
            if _block_config.attention.no_op:
                if not model_descriptor.attn_no_op_supported():
                    raise NotImplementedError(
                        f"attn no-op not supported for `{descriptor_name}`, "
                        "please implement the method: `attn_no_op_post_init()`"
                    )
                model_descriptor.attn_no_op_post_init(decoder_layer=self)

            if _block_config.ffn.no_op:
                if not model_descriptor.mlp_no_op_supported():
                    raise NotImplementedError(
                        f"mlp no-op not supported for `{descriptor_name}`, "
                        "please implement the method: `mlp_no_op_post_init()`"
                    )
                model_descriptor.mlp_no_op_post_init(decoder_layer=self)

        return _patched_decoder_layer_init

    with ExitStack() as stack:
        # Patch every decoder layer class
        for orig_init, cls in zip(orig_inits, decoder_layer_classes):
            stack.callback(setattr, cls, "__init__", orig_init)  # Restore on exit
            cls.__init__ = _make_patched_init(orig_init)
        yield


def override_config_with_block_configs(
    config: PretrainedConfig, block_configs: Dict[str, Any]
) -> PretrainedConfig:
    """Create a copy of config with block_config overrides applied.

    Args:
        config: The base config; it is deep-copied and never modified.
        block_configs: Mapping of config attribute name to override value.

    Returns:
        The copied config with every non-``None`` override applied.
    """
    _config = copy.deepcopy(config)
    # Model initialization fails on None values, which no-op blocks produce, so skip them.
    _config_overrides = {k: v for k, v in block_configs.items() if v is not None}
    _config.update(_config_overrides)
    return _config
def launch_build_library_and_stats(cfg: DictConfig) -> None:
    """Run build_replacement_library and then calc_subblock_stats with one config.

    Each step is announced, executed, and reported; a failing step is logged and the
    exception is re-raised, so the second step never runs after a failed first step.

    Args:
        cfg: Hydra configuration containing settings for both commands
    """
    wide_rule = "=" * 80
    narrow_rule = "=" * 50

    mprint(wide_rule)
    mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW")
    mprint(wide_rule)

    # (heading, step function, success message, failure message prefix)
    steps = (
        (
            "STEP 1: Building Replacement Library",
            launch_build_replacement_library,
            "✅ Replacement library built successfully!",
            "❌ Failed to build replacement library",
        ),
        (
            "STEP 2: Calculating Subblock Statistics",
            launch_calc_subblock_stats,
            "✅ Subblock statistics calculated successfully!",
            "❌ Failed to calculate subblock statistics",
        ),
    )
    for heading, run_step, success_message, failure_prefix in steps:
        mprint(narrow_rule)
        mprint(heading)
        mprint(narrow_rule)
        try:
            run_step(cfg)
            mprint(success_message)
        except Exception as err:
            # Log for the operator, then let the original exception propagate.
            mprint(f"{failure_prefix}: {err}")
            raise

    mprint(wide_rule)
    mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 🎉")
    mprint(wide_rule)

    mprint("Generated files:")
    for library_file in (
        "block_library.json",
        "subblock_library.json",
        "replacement_library.json",
        "single_sequence_replacement_solutions.json",
    ):
        mprint(f" - {cfg.puzzle_dir}/{library_file}")
    mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}")
    # The MoE stats file is only listed when the config names one.
    if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"):
        mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}")
def process_and_save_dataset(
    dataset_name: str,
    output_dir: str,
    split: tuple = ("code", "math", "stem", "chat"),
    overwrite: bool = False,
):
    """Download a dataset, keep the non-reasoning samples, and save a train/valid split.

    The requested splits are loaded and concatenated, samples whose ``reasoning``
    field is not ``"off"`` are dropped, and the remainder is split 95/5 with a fixed
    seed before being written to ``output_dir`` as a ``DatasetDict`` with the keys
    ``"train"`` and ``"valid"``.

    Args:
        dataset_name: Name or path passed to ``datasets.load_dataset``.
        output_dir: Directory the processed ``DatasetDict`` is saved to.
        split: Split names to load and concatenate. A single name given as a plain
            string (e.g. ``--split code`` on the command line) is also accepted.
        overwrite: When ``False`` and ``output_dir`` already holds a saved dataset,
            nothing is done.
    """
    # A saved DatasetDict is recognised by its dataset_dict.json marker file.
    dataset_dict_path = os.path.join(output_dir, "dataset_dict.json")
    if os.path.exists(dataset_dict_path) and not overwrite:
        mprint(
            f"Output directory '{output_dir}' already contains a dataset. "
            "Use '--overwrite True' to overwrite existing data."
        )
        return

    # Always hand load_dataset a list so that it returns a list of datasets; a bare
    # string would yield a single Dataset, which concatenate_datasets cannot consume.
    split_names = [split] if isinstance(split, str) else list(split)
    ds = datasets.load_dataset(dataset_name, split=split_names)
    ds = datasets.concatenate_datasets(ds)
    # Keep only the samples with reasoning turned off
    ds = ds.filter(lambda x: x["reasoning"] == "off")
    # Hard-coded seed so the dynamically created train/val split is deterministic
    seed = 408
    generator = np.random.RandomState(seed=seed)
    ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator)
    # Rename the splits to follow previous conventions
    ds_dict = datasets.DatasetDict(
        {
            "train": ds_split["train"],
            "valid": ds_split["test"],
        }
    )
    # Save locally
    os.makedirs(output_dir, exist_ok=True)
    ds_dict.save_to_disk(output_dir)

    mprint(f"Dataset splits:\n{ds_dict}")
    mprint(f"Saved processed datasets to {output_dir}")


if __name__ == "__main__":
    fire.Fire(process_and_save_dataset)
+ diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py new file mode 100644 index 0000000000..fb630335c6 --- /dev/null +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/block_config.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +import dataclasses +import inspect +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, List, Optional, Type, Union, get_args, get_origin + + +@dataclass(frozen=True, kw_only=True) +class BaseDataclass: + """ + A dataclass base class with several utilities: + 1. Comparison via string representation. + 2. Initialization of dataclasses fields from dicts. + 3. Setting attributes even though it's frozen (but only inside __post_init__!) + """ + + def __eq__(self, other: "BaseDataclass") -> bool: + return str(self) == str(other) + + def __hash__(self) -> int: + return hash(str(self)) + + def __lt__(self, other: "BaseDataclass") -> bool: + return str(self) < str(other) + + def _force_setattr(self, name: str, value: Any) -> None: + """ + Set an attribute even in frozen dataclasses. + Use only inside __post_init__! 
+ """ + assert _is_called_from_post_init(), ( + "_force_setattr should only be called from __post_init__, " + "if you need to change an attribute use dataclasses.replace " + "or create a new instance :)" + ) + object.__setattr__(self, name, value) + + def __post_init__(self): + """ + Init dataclass fields from dicts + """ + for field in dataclasses.fields(self): + field_dict = getattr(self, field.name) + if isinstance(field_dict, dict) and _is_dataclass_type(field.type): + dataclass_cls = _get_dataclass_type(field.type) + sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)] + unsupported_fields = [ + field_name for field_name in field_dict.keys() if field_name not in sub_fields + ] + if len(unsupported_fields) > 0: + warnings.warn( + f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}" + ) + + field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields} + self._force_setattr(field.name, dataclass_cls(**field_dict)) + + +def _is_called_from_post_init() -> bool: + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "__post_init__": + return True + frame = frame.f_back + return False + + +def _is_dataclass_type(tp: Type) -> bool: + """ + Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type + """ + try: + _get_dataclass_type(tp) + return True + except: + return False + + +def _get_dataclass_type(tp: Type) -> dataclass: + """ + If the given type is a dataclass, the function returns it. + If it is a Union[] or Optional[], the function extracts the first dataclass type. + If no dataclass type is found, the function raises a ValueError. 
+ """ + origin = get_origin(tp) + if origin is Union: + for type_in_union in get_args(tp): + if dataclasses.is_dataclass(type_in_union): + return type_in_union + if dataclasses.is_dataclass(tp): + return tp + raise ValueError("Not a dataclass") + + +@dataclass(frozen=True, kw_only=True) +class SubblockConfig(BaseDataclass): + no_op: bool = False + replace_with_linear: bool = False + sparsify: Optional[list[str]] = None + weights_precision: Optional[str] = "bf16" + + def __post_init__(self): + super().__post_init__() + assert not (self.no_op and self.replace_with_linear) + if self.no_op: + self._force_setattr("sparsify", None) + + @abstractmethod + def to_blockconfig(self) -> "BlockConfig": + """ " + Convert to a block including this subblock only. + """ + ... + + +@dataclass(frozen=True, kw_only=True) +class MoEConfig(BaseDataclass): + """ + Configuration class for Mixture of Experts parameters. + """ + + num_local_experts: int = 8 + num_experts_per_tok: int = 1 + expert_intermediate_dim: int = 8192 + shared_expert_intermediate_dim: int = 8192 + # router_aux_loss_coef: float = 0.01 + # router_z_loss_coef: float = 0.0 # Optional z-loss coefficient + + def __post_init__(self): + # Validate the configuration + if self.num_local_experts <= 0: + raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}") + if self.num_experts_per_tok <= 0: + raise ValueError(f"top_k must be positive, got {self.top_k}") + if self.num_experts_per_tok > self.num_local_experts: + raise ValueError( + f"top_k ({self.top_k}) cannot be greater than num_local_experts ({self.num_local_experts})" + ) + # if self.router_aux_loss_coef < 0: + # raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}") + + +@dataclass(frozen=True, kw_only=True) +class MambaConfig(BaseDataclass): + state_dim: int + num_heads: int + head_dim: int + num_groups: int + + +@dataclass(frozen=True, kw_only=True) +class Llama4AttentionConfig(BaseDataclass): + 
attention_chunk_size: Optional[int] = None + use_rope: Optional[bool] = None + use_qk_norm: Optional[bool] = None + attn_scale: Optional[float] = None + floor_scale: Optional[float] = None + attn_temperature_tuning: Optional[bool] = None + attention_dropout: Optional[float] = None + + +@dataclass(frozen=True, kw_only=True) +class AttentionConfig(SubblockConfig): + num_key_value_heads: Optional[int] = None + llama4: Optional[Llama4AttentionConfig] = None + mamba: Optional[MambaConfig] = None + + def __post_init__(self): + super().__post_init__() + + if self.no_op: + assert not self.is_mamba + assert not self.is_llama4 + + if self.no_op or self.is_mamba: + for irrelevant_att in [ + "num_key_value_heads", + ]: + self._force_setattr(irrelevant_att, None) + else: + assert self.num_key_value_heads is not None + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) + + @property + def is_llama4(self) -> bool: + return self.llama4 is not None + + @property + def is_mamba(self) -> bool: + return self.mamba is not None + + +@dataclass(frozen=True, kw_only=True) +class FFNConfig(SubblockConfig): + moe: Optional[MoEConfig] = None + intermediate_size: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.no_op: + self._force_setattr("moe", None) + self._force_setattr("intermediate_size", None) + elif self.is_moe: + self._force_setattr("intermediate_size", None) + else: + assert self.intermediate_size is not None, ( + "Intermediate size must be provided for an FFN block" + ) + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) + + @property + def is_moe(self) -> bool: + return self.moe is not None + + +SUBBLOCK_CLS_DICT = { + "attention": AttentionConfig, + "ffn": FFNConfig, +} + + +@dataclass(frozen=True, kw_only=True) +class BlockConfig(BaseDataclass): + attention: Optional[AttentionConfig] = None + ffn: Optional[FFNConfig] = 
None + parallel_blocks: Optional[list["BlockConfig"]] = None + + def __post_init__(self): + super().__post_init__() + if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict): + initialized_block_configs = [ + BlockConfig(**block_config) for block_config in self.parallel_blocks + ] + self._force_setattr("parallel_blocks", initialized_block_configs) + + def to_dict(self) -> dict: + """Convert BlockConfig to a dictionary.""" + return dataclasses.asdict(self) + + +def maybe_cast_block_configs( + block_configs: List[BlockConfig | dict] | None, +) -> List[BlockConfig] | None: + """Cast a list of dicts to BlockConfig objects if needed. + + Args: + block_configs: List of BlockConfig or dict objects, or None. + + Returns: + List of BlockConfig objects, or None if input is None/empty. + """ + if not block_configs: + return block_configs + if isinstance(block_configs[0], dict): + return [BlockConfig(**conf) for conf in block_configs] + return block_configs diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py new file mode 100644 index 0000000000..0a0f8ab1ef --- /dev/null +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class LMHead(nn.Linear):
    """``nn.Linear`` subclass used as a distinct type for the language-model head.

    The class adds no parameters or behavior of its own. It exists so that the
    head can be targeted explicitly (e.g. for FSDP wrapping) without affecting
    the other ``nn.Linear`` layers in the model.
    """
+""" + +# Import to register bridges (side effect) +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin +from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 + PuzzletronLlamaAnyModelBridge, +) +from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 + PuzzletronQwen3AnyModelBridge, +) + +__all__ = [ + "HeterogeneousBridgeMixin", + "PuzzletronLlamaAnyModelBridge", + "PuzzletronQwen3AnyModelBridge", +] diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py new file mode 100644 index 0000000000..13ea6612af --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. 
class HeterogeneousBridgeMixin:
    """Mixin adding heterogeneous-layer (``block_configs``) support to a bridge.

    Intended to be listed before a model-specific bridge in the bases, so that
    ``super().provider_bridge()`` resolves to that bridge.
    Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge)
    """

    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
        """Build a ``GenericHeterogeneousProvider`` from an HF AnyModel checkpoint.

        Steps:
        1. Ask the next bridge in the MRO for its provider, which carries the
           model-specific settings.
        2. Flatten that provider with ``dataclasses.asdict`` and keep only the
           entries whose names are dataclass fields of
           ``GenericHeterogeneousProvider``.
        3. Add the JSON-encoded heterogeneous layer configuration derived from
           ``hf_pretrained.config`` and construct the provider.

        Note: parent providers that carry fields unknown to
        ``GenericHeterogeneousProvider`` (model-specific subclasses of
        ``GPTModelProvider``) lose those fields in step 2. Bridges relying on
        such fields need a dedicated heterogeneous provider deriving from the
        model-specific provider instead.
        """
        base_provider = super().provider_bridge(hf_pretrained)  # type: ignore[misc]

        base_kwargs = dataclasses.asdict(base_provider)
        accepted_names = {field.name for field in fields(GenericHeterogeneousProvider)}

        provider_kwargs = {}
        for name, value in base_kwargs.items():
            if name in accepted_names:
                provider_kwargs[name] = value

        encoded_layers = self._build_heterogeneous_config_json(hf_pretrained.config)
        provider_kwargs["heterogeneous_layers_config_encoded_json"] = encoded_layers
        return GenericHeterogeneousProvider(**provider_kwargs)

    def _build_heterogeneous_config_json(self, hf_config) -> str:
        """Serialize the HF config's ``block_configs`` as an MCore-style JSON string."""
        config_as_dict = json.loads(hf_config.to_json_string())

        converted_blocks = []
        for hf_block in config_as_dict["block_configs"]:
            converted_blocks.append(self._convert_block_config(hf_block))

        payload = {"block_configs": converted_blocks}
        return json.dumps(payload, ensure_ascii=False)

    def _convert_block_config(self, block: dict) -> dict:
        """Convert one block entry (its ``attention`` and ``ffn`` parts) to MCore naming."""
        attention_part = self._convert_attention_config(block["attention"])
        ffn_part = self._convert_ffn_config(block["ffn"])
        return {"attention": attention_part, "ffn": ffn_part}

    def _convert_attention_config(self, attention_config: dict) -> dict:
        """Return a copy with ``num_key_value_heads`` renamed to ``num_query_groups``.

        The input dict is left unmodified. Raises ``KeyError`` when the source
        key is absent.
        """
        renamed = {
            key: value for key, value in attention_config.items() if key != "num_key_value_heads"
        }
        renamed["num_query_groups"] = attention_config["num_key_value_heads"]
        return renamed

    def _convert_ffn_config(self, ffn_config: dict) -> dict:
        """Return a copy with ``intermediate_size`` renamed to ``ffn_hidden_size``.

        The input dict is left unmodified. Raises ``KeyError`` when the source
        key is absent.
        """
        renamed = {key: value for key, value in ffn_config.items() if key != "intermediate_size"}
        renamed["ffn_hidden_size"] = ffn_config["intermediate_size"]
        return renamed
@dataclass
class DistillationProvider(TransformerConfig):
    """Provider for Megatron Core GPT models in distillation mode.

    Instances are never constructed directly (``__init__`` raises). Use
    `convert_to_distillation_provider()`, which turns an existing student
    provider into an instance of this class and attaches the teacher.
    """

    # Provider used to build the teacher model; must be set before __post_init__ runs.
    teacher: Optional[GPTModelProvider | MambaModelProvider] = None
    # Distillation settings forwarded to `setup_distillation_config()` in `provide()`.
    kd_config: Optional["ModelOptDistillConfig"] = None

    def __init__(self, *args, **kwargs):
        """Always raise: direct construction is not supported."""
        raise NotImplementedError(
            "Use `convert_to_distillation_provider()` to create an instance of this class."
        )

    def __post_init__(self):
        """Validate student/teacher compatibility and record the original provider class.

        Raises:
            AssertionError: If no teacher is attached.
            ValueError: If student and teacher disagree on any of the shared
                parallelism / sequence attributes listed below.
        """
        assert getattr(self, "teacher", None) is not None, "Teacher model must be provided."

        # Attributes that must have identical values on student and teacher.
        shared_attrs = [
            "tensor_model_parallel_size",
            "pipeline_model_parallel_size",
            "context_parallel_size",
            "seq_length",
            "pipeline_dtype",
        ]
        for attr in shared_attrs:
            if getattr(self, attr) != getattr(self.teacher, attr):
                raise ValueError(f"Student and teacher providers must have the same {attr}.")

        # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version.
        self.cross_entropy_fusion_impl = "native"

        # Hack to dynamically subclass other providers and still use their methods:
        # remember the first base class so provide() / to_cfg_dict() can delegate to it.
        self._super_class = self.__class__.__bases__[0]

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
        """Configure and instantiate a ModelOpt DistillationModel based on this configuration.

        The student is built first via the original provider class, then the
        teacher is finalized and built, and finally both are combined through
        ``mtd.convert`` in ``kd_loss`` mode.

        Args:
            pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage
            post_process: Whether to include post-processing in the model, defaults to last pipeline stage
            vp_stage: Virtual pipeline stage; must be ``None``.

        Returns:
            MCoreGPTModel: Configured ModelOpt DistillationModel instance

        Raises:
            ValueError: If ``vp_stage`` is not ``None``.
            AssertionError: If no teacher is attached.
        """
        if vp_stage is not None:
            raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.")

        assert self.teacher is not None, "Teacher model must be provided."
        # Delegate student construction to the provider class this instance was converted from.
        student_model = self._super_class.provide(self, pre_process, post_process, vp_stage)  # type: ignore[attr-defined]

        # Finalize teacher provider before creating model (required for heterogeneous models).
        #
        # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in
        # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during
        # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent
        # with current parallelism settings and distributed context. Student model creation (above)
        # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter
        # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341)
        # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to
        # determine layer architecture. Without finalize() in this context, architecture expectations
        # don't match checkpoint weights, causing:
        #   ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0
        #   (expected (2880, 4096), got (3584, 4096))
        #
        # NOTE: the explanation above has not been confirmed yet.
        self.teacher.finalize()

        # Hack to get teacher's pre-wrap hooks called to potentially load HF weights.
        # Only the first element of the returned sequence is used as the teacher model.
        teacher_model = self.teacher.provide_distributed_model(
            wrap_with_ddp=False, mixed_precision_wrapper=None
        )[0]

        kd_cfg = mtd_mcore.setup_distillation_config(
            self.kd_config, student_model.config, teacher_model.config
        )
        # Arguments for the "kd_loss" conversion mode.
        modelopt_cfg = {
            "teacher_model": teacher_model,
            "criterion": kd_cfg.criterion,
            "loss_balancer": kd_cfg.loss_balancer,
        }
        kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)])
        mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg)

        return kd_model

    def to_cfg_dict(self) -> dict[str, Any]:
        """Custom method to save equivalent to the original provider class.

        Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML.
        There is no need to restore a `DistillationProvider` from the run config file, as
        it can always be re-converted using the original student provider.

        The ``_target_`` entry names the original provider class, and the
        ``teacher`` / ``kd_config`` fields as well as fields whose name starts
        with an underscore are left out.

        Returns:
            Dictionary representation of this provider class
        """
        # Imported lazily, inside the method, rather than at module import time.
        from megatron.bridge.training.utils.config_utils import _ConfigContainerBase

        result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}

        # Include all fields from the original provider class (self._super_class), not just DistillationProvider
        # This ensures fields like heterogeneous_layers_config_encoded_json are preserved
        excluded_fields = {"teacher", "kd_config"}
        for field in fields(self._super_class):
            if field.name.startswith("_") or field.name in excluded_fields:
                continue
            # Only include if the field exists on this instance (it should, since we converted from the original provider)
            if hasattr(self, field.name):
                result[field.name] = _ConfigContainerBase._convert_value_to_dict(
                    getattr(self, field.name)
                )

        # Also include any additional fields from DistillationProvider itself (if any)
        for field in fields(self):
            if field.name.startswith("_") or field.name in excluded_fields:
                continue
            # Skip if already included from _super_class
            if field.name not in result:
                result[field.name] = _ConfigContainerBase._convert_value_to_dict(
                    getattr(self, field.name)
                )

        return result

    def __setattr__(self, name, value):
        """Set the attribute on this provider and mirror it onto the teacher.

        The value is copied to the teacher only when the teacher already has
        an attribute of that name.
        """
        super().__setattr__(name, value)
        # Mirror to teacher if it has that attribute
        if hasattr(self.teacher, name):
            setattr(self.teacher, name, value)
def export_to_hf_and_copy_config(
    student_hf_path: str,
    checkpoint_dir: str,
    train_iters: int,
    hf_export_path: str,
    hf_model: str,
    trust_remote_code: bool = False,
) -> None:
    """Export a Megatron checkpoint to HuggingFace format, then overwrite its config.json.

    The checkpoint at ``<checkpoint_dir>/iter_<train_iters, zero-padded to 7 digits>``
    is exported through an ``AutoBridge`` created from ``hf_model``. Afterwards
    the student's ``config.json`` is copied over the exported one so that the
    student's ``block_configs`` are preserved.

    TODO: The manual config.json copy should become unnecessary once
    ``export_to_hf()`` in AutoBridge can preserve the student's config.json;
    remove it then.

    Args:
        student_hf_path: Path to the original student HuggingFace model (source of config.json)
        checkpoint_dir: Base directory where Megatron checkpoints are stored
        train_iters: Number of training iterations (used to construct final checkpoint path)
        hf_export_path: Directory path where the HuggingFace model will be saved
        hf_model: HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct)
        trust_remote_code: Whether to trust remote modeling code when loading the HF template model
    """
    rule = "=" * 80
    config_name = "config.json"

    print_rank_0(f"\n{rule}")
    print_rank_0("Exporting to HuggingFace format...")
    print_rank_0(f"{rule}\n")

    # Final checkpoint directory, e.g. iter_0000100 for 100 iterations.
    # Its existence is not checked here; that is left to export_ckpt().
    megatron_path = str(Path(checkpoint_dir) / f"iter_{train_iters:07d}")
    print_rank_0(f"📂 Using final checkpoint: {megatron_path}")

    # The bridge is created from the standard model ID (not the AnyModel
    # checkpoint) to avoid sharding structure issues.
    print_rank_0("🌉 Creating bridge...")
    print_rank_0(f" Using model ID: {hf_model}")
    bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=trust_remote_code)

    print_rank_0("📤 Exporting to HuggingFace format...")
    bridge.export_ckpt(
        megatron_path=megatron_path,
        hf_path=hf_export_path,
        show_progress=True,
        strict=True,
    )
    print_rank_0(f"✅ Successfully exported model to: {hf_export_path}")

    # Replace the exported config.json with the student's (preserves block_configs).
    source_config = Path(student_hf_path) / config_name
    target_config = Path(hf_export_path) / config_name

    print_rank_0(f"📋 Copying config.json from student model: {source_config}")
    shutil.copy(source_config, target_config)
    print_rank_0(f"✅ Copied config.json to: {target_config}")

    print_rank_0(f"\n{rule}")
    print_rank_0("Export complete!")
@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel)
class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge):
    """
    Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints.

    Registered as the bridge from ``Qwen3ForCausalLM`` to ``GPTModel``.
    Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs).
    The class defines no members of its own: everything comes from its two bases.
    """

    # provider_bridge() is inherited from HeterogeneousBridgeMixin, which is listed
    # first in the bases; it calls Qwen3Bridge.provider_bridge() through super() and
    # adds the heterogeneous layer config on top.
    # mapping_registry() is inherited from Qwen3Bridge.
def launch_mip_and_realize_model(cfg: DictConfig) -> list[str]:
    """Run the MIP search on the master rank and realize each solution on all ranks.

    The master rank computes the solution paths with ``launch_mip``. Unless
    ``cfg.skip_realize_model`` is set, the paths are broadcast to the other
    ranks and every rank then realizes each solution in turn, writing the
    current path into ``cfg.realize_model.solutions_path`` beforehand.

    Args:
        cfg: Run configuration; reads ``cfg.mip`` (via ``launch_mip``),
            ``cfg.skip_realize_model`` and ``cfg.realize_model``.

    Returns:
        The solution paths. NOTE: when ``cfg.skip_realize_model`` is true no
        broadcast takes place, so non-master ranks return ``None``.
    """
    # Determine device for distributed operations (NCCL requires CUDA tensors)
    device = "cpu"
    if dist.size() > 1:
        if torch.distributed.get_backend() == "nccl":
            device = torch.cuda.current_device()

    if dist.is_master():
        solution_paths = launch_mip(cfg)
        length_tensor = torch.tensor([len(solution_paths)], dtype=torch.long, device=device)
    else:
        # Placeholders; filled in by the broadcasts below when realization is enabled.
        solution_paths = None
        length_tensor = torch.tensor([0], dtype=torch.long, device=device)

    if not cfg.skip_realize_model:
        # Step 1: share the number of solutions so receivers can size their list.
        if dist.size() > 1:
            torch.distributed.broadcast(length_tensor, src=0)

        list_length = length_tensor.item()

        if not dist.is_master():
            solution_paths = [None] * list_length

        # Step 2: share the actual paths (the list is filled in place on receivers).
        if dist.size() > 1:
            torch.distributed.broadcast_object_list(solution_paths, src=0)

        for solution_path in solution_paths:
            mprint(f"Realize model for the solution: {solution_path}")
            # The config object is mutated so launch_realize_model() picks up the current solution.
            cfg.realize_model.solutions_path = Path(solution_path)
            launch_realize_model(cfg)
            dist.barrier()

    return solution_paths
b/modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solves multi-layer replacement optimization using Mixed Integer Programming.""" + +# mypy: ignore-errors +import math +import warnings +from collections import defaultdict +from collections.abc import Hashable, Iterable +from copy import deepcopy +from random import random +from typing import Any, TypeAlias + +from mip import BINARY, Model, maximize, minimize, xsum + +from modelopt.torch.puzzletron.mip.utils import ( + consecutive_ngrams, + get_nested_key, + sort_replacements, +) + +ReplacementID: TypeAlias = Hashable +Replacement: TypeAlias = dict[str, Any] +ChosenReplacements: TypeAlias = list[Replacement] + + +def run_mip( + replacements: dict[ReplacementID, Replacement], + objective: str, + constraints: dict[str, float], + bigger_is_better: bool, + max_seconds_per_solution: float | None = None, +) -> tuple[ChosenReplacements, float, dict[str, float]]: + orig_num_replacements = len(replacements) + replacements = { + replacement_id: deepcopy(replacement) + for replacement_id, replacement in replacements.items() + if math.isfinite(get_nested_key(replacement, objective)) + } + if len(replacements) < orig_num_replacements: + print("\n\n\n") + warnings.warn( + f"mip: removed 
{orig_num_replacements - len(replacements)} replacements with NaN/inf objective value" + ) + print("\n\n\n") + + mip_model = Model() + + objective_vars = [] + constraint_vars = {constraint_key: [] for constraint_key in constraints} + choice_indicators_by_layer = defaultdict(list) + for replacement_id, replacement in replacements.items(): + is_chosen = mip_model.add_var(var_type=BINARY) + replacement["is_chosen"] = is_chosen + + for parent_layer_idx in replacement["parent_layer_indices"]: + choice_indicators_by_layer[parent_layer_idx].append(is_chosen) + + objective_vars.append(is_chosen * get_nested_key(replacement, objective)) + + for constraint_key in constraints: + constraint_vars[constraint_key].append( + is_chosen * get_nested_key(replacement, constraint_key) + ) + + # MIP constraints: each parent layer must come from exactly one chosen replacement + for parent_layer_idx, curr_choice_indicators in choice_indicators_by_layer.items(): + mip_model += xsum(curr_choice_indicators) == 1 + + # MIP constraints: the sum of chosen replacement costs must be lower than the max cost + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) <= max_cost + if min_cost is not None: + mip_model += xsum(constraint_vars[constraint_key]) >= min_cost + + # MIP objective + mip_model.objective = ( + maximize(xsum(objective_vars)) if bigger_is_better else minimize(xsum(objective_vars)) + ) + + if max_seconds_per_solution is not None: + mip_model.max_seconds = max_seconds_per_solution + + mip_model.optimize() + + if is_chosen.x is None: + return [] + # raise InfeasibleError() + + # Trust But Verify: calculate total value and costs, and check that all the constraints are filled + total_value = 0.0 + total_costs = dict.fromkeys(constraints.keys(), 0) + chosen_replacements: ChosenReplacements = [] + chosen_layers = [] + for 
replacement_id, replacement in replacements.items(): + is_chosen = replacement["is_chosen"].x >= 0.99 + if is_chosen: + assert replacement not in chosen_replacements + chosen_replacements.append(replacement) + total_value += get_nested_key(replacement, objective) + for constraint_key in constraints: + total_costs[constraint_key] += get_nested_key(replacement, constraint_key) + for parent_layer_idx in replacement["parent_layer_indices"]: + assert parent_layer_idx not in chosen_layers + chosen_layers.append(parent_layer_idx) + + missing_layers = set(choice_indicators_by_layer.keys()) - set(chosen_layers) + assert len(missing_layers) == 0, ( + f"The following layers were not chosen by any replacement:\n{missing_layers=}\n{chosen_replacements}" + ) + + for constraint_key, max_cost in constraints.items(): + min_cost = None + if isinstance(max_cost, Iterable): + min_cost, max_cost = max_cost + + if max_cost is not None: + assert total_costs[constraint_key] < max_cost or math.isclose( + total_costs[constraint_key], max_cost, rel_tol=1e-9 + ), ( + f"This max_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} > {max_cost=}" + ) + if min_cost is not None: + assert total_costs[constraint_key] > min_cost or math.isclose( + total_costs[constraint_key], min_cost, rel_tol=1e-9 + ), ( + f"This min_cost was violated {constraint_key} in the solution, sol val={total_costs[constraint_key]} < {min_cost=}" + ) + + chosen_replacements = sort_replacements(chosen_replacements) + for cr in chosen_replacements: + del cr["is_chosen"] # not copyable, will cause errors in deep copy + if "block_config" in cr: + cr["child_block_configs"] = cr["block_config"] + # del cr['block_config'] for now the dump includes both keys (duplicated values) # we might wanna either delete one of them or keep both + # I prefer keeping block_config and deleting 'child_block_configs' from previous puzzle steps + + return [ + { + "chosen_replacements": chosen_replacements, + 
def usage_example():
    """Solve a randomly generated toy problem with ``run_mip`` and print the result.

    For every span size in (1, 2, 3) and every option index, one candidate replacement is
    created per item yielded by ``consecutive_ngrams(num_layers, span_size)``. Each candidate
    gets a random loss, memory and runtime. The total loss is minimized subject to caps on the
    summed memory and runtime.
    """
    num_layers = 32
    num_options_per_parent_replacement = 5

    def _random_candidate(layer_indices, option_idx):
        # Random draws happen in the order loss -> memory -> runtime.
        return {
            "parent_layer_indices": layer_indices,
            "metrics": {"loss": random()},
            "stats": {"memory_mib": random() * 100, "runtime_ms": random() * 10},
            "replacement_id": f"parent layers {layer_indices} child config {option_idx}",
        }

    replacements = {}
    for span_size in (1, 2, 3):
        for option_idx in range(num_options_per_parent_replacement):
            for layer_indices in consecutive_ngrams(num_layers, span_size):
                candidate = _random_candidate(layer_indices, option_idx)
                replacements[candidate["replacement_id"]] = candidate

    constraints = {
        "stats.memory_mib": num_layers * 15.0,
        "stats.runtime_ms": num_layers * 1.5,
    }
    # Exactly one solution is expected; the tuple-unpacking fails otherwise.
    (result,) = run_mip(
        replacements,
        objective="metrics.loss",
        constraints=constraints,
        bigger_is_better=False,
    )
    total_value = result["total_value"]
    total_costs = result["total_costs"]
    chosen_replacements = result["chosen_replacements"]

    print()
    print()
    print(f"{total_value=}")
    print(f"{total_costs=}")
    print(f"{constraints=}")
    print("chosen_replacements=")
    print("\n".join(rep["replacement_id"] for rep in chosen_replacements))
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main entry point for running the puzzle optimization to find optimal layer configurations.""" + +# mypy: ignore-errors +import argparse +import dataclasses +import enum +import json +import sys +from collections.abc import Hashable, Iterable +from copy import deepcopy +from pathlib import Path +from typing import Any, Literal, TypeAlias + +import numpy as np +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch.puzzletron.mip.mip_with_multi_layer_replacements import ( + run_mip as run_multi_layer_replacement_mip, +) +from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( + extract_block_configs_and_locations, + parse_layer_replacement, + replacement_is_teacher, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_json, parse_path +from modelopt.torch.puzzletron.utils.utils import block_config_to_str, solution_to_str + +""" +Usage: +Must specify either --single_block_replacement_validation_dir and --subblock_stats_path (in which case the metrics will +be gathered from the validation output files) or --gathered_metrics_path (in 
@dataclasses.dataclass
class PuzzleConstraints:
    """A named set of puzzle constraints.

    Constraints are given either directly as MIP constraints (e.g. 'stats.memory_mib') or as
    "human" constraints (e.g. 'target_throughput'); the allowed human keys are listed in
    ``_ALLOWED_HUMAN_CONSTRAINTS``. ``name`` is derived from the constraints in
    ``__post_init__`` and is not an ``__init__`` argument.
    """

    class Type(enum.Enum):
        MIP = "mip"
        HUMAN = "human"

    # Not annotated on purpose: a plain class attribute, not a dataclass field.
    _ALLOWED_HUMAN_CONSTRAINTS = {
        "target_memory",
        "target_throughput",
        "target_latency",
        "target_time_to_first_token",
        "num_params",
        "stats.has_attention",
    }
    type: Type
    name: str = dataclasses.field(init=False)
    constraints: dict[str, Any]

    @staticmethod
    def sizeof_fmt(num, suffix=""):
        """Format ``num`` with a decimal (base-1000) unit prefix, e.g. 7e9 -> '7G'."""
        for unit in ("", "K", "M", "G", "T"):
            if abs(num) < 1000.0:
                return f"{num:g}{unit}{suffix}"
            num /= 1000.0
        return f"{num:.1f}P{suffix}"

    def _validate_human_constraints(self):
        """Raise ``ValueError`` if any constraint key is not an allowed human constraint."""
        illegal_constraints = set(self.constraints.keys()) - self._ALLOWED_HUMAN_CONSTRAINTS
        if illegal_constraints:
            raise ValueError(
                f"The following human_constraints are illegal: {','.join(illegal_constraints)}"
            )

    def format_num_params_to_float(self, num_params):
        """Convert strings such as '3B' to a float count; lists are converted element-wise."""
        if isinstance(num_params, list):
            return [self.format_num_params_to_float(x) for x in num_params]
        if isinstance(num_params, str):
            # we only deal with Billions of params scale
            return float(num_params.replace("B", "")) * 1e9
        return num_params

    def format_num_params_to_str(self, num_params):
        """Convert numeric counts to strings in billions ('3.0B'); lists element-wise."""
        if isinstance(num_params, list):
            return [self.format_num_params_to_str(x) for x in num_params]
        if isinstance(num_params, (float, int)):
            return f"{num_params / 1e9}B"
        return num_params

    def __post_init__(self):
        """Validate human constraints, normalize 'stats.active_params' and build ``name``."""
        if self.type == PuzzleConstraints.Type.HUMAN:
            self._validate_human_constraints()

        if "stats.active_params" in self.constraints:
            self.constraints["stats.active_params"] = self.format_num_params_to_float(
                self.constraints["stats.active_params"]
            )

        # Set self.name from a copy whose values are replaced by "human readable" versions.
        constraints = deepcopy(self.constraints)
        if "stats.active_params" in constraints:
            constraints["stats.active_params"] = self.format_num_params_to_str(
                constraints["stats.active_params"]
            )

        if self.type == PuzzleConstraints.Type.HUMAN:
            # change values to a more human string form
            if "target_memory" in constraints:
                constraints["target_memory"] = str(constraints["target_memory"]) + "MiB"
            if "num_params" in constraints:
                constraints["num_params"] = self.sizeof_fmt(constraints["num_params"])

        def build_constraint_name(constraint_name, constraint_value):
            # Iterable (non-string) values contribute one "<name>_<item>" part per item.
            if isinstance(constraint_value, Iterable) and not isinstance(constraint_value, str):
                return "-".join(f"{constraint_name}_{x}" for x in constraint_value)
            return f"{constraint_name}_{constraint_value}"

        self.name = "-".join(build_constraint_name(k, v) for k, v in constraints.items()).replace(
            ".", "_"
        )

    def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]:
        """Return the constraints in MIP form.

        MIP-type constraints are returned as a deep copy, so callers may adjust the result in
        place without changing this object. Human constraints are translated using
        ``batch_size`` and ``generation_seq_len`` from ``subblock_stats_args``.
        """
        if self.type == PuzzleConstraints.Type.MIP:
            # Hand out a copy: the caller edits the returned mapping in place, and this
            # object is reused for several puzzle runs.
            return deepcopy(self.constraints)

        assert all(key in subblock_stats_args for key in ("batch_size", "generation_seq_len")), (
            "Can't realize human constraints without 'batch_size' and 'generation_seq_len' in subblock_stats_args."
        )
        batch_size = subblock_stats_args["batch_size"]
        generation_seq_len = subblock_stats_args["generation_seq_len"]

        mip_constraints = {}

        # Memory constraints
        if "target_memory" in self.constraints:
            mip_constraints["stats.memory_mib"] = self.constraints["target_memory"]

        # Throughput constraints: both are turned into a runtime and the tighter one wins.
        throughput_constraints = []
        if "target_throughput" in self.constraints:
            throughput_constraints.append(
                batch_size * generation_seq_len / self.constraints["target_throughput"]
            )
        if "target_latency" in self.constraints:
            throughput_constraints.append(self.constraints["target_latency"])
        if throughput_constraints:
            mip_constraints["stats.runtime_ms"] = 1000 * min(throughput_constraints)

        # Prefill runtime constraint
        if "target_time_to_first_token" in self.constraints:
            mip_constraints["stats.prefill_runtime_ms"] = (
                1000 * self.constraints["target_time_to_first_token"]
            )

        # Num params
        if "num_params" in self.constraints:
            mip_constraints["stats.num_params"] = self.constraints["num_params"]
        if "stats.has_attention" in self.constraints:
            mip_constraints["stats.has_attention"] = self.constraints["stats.has_attention"]
        return mip_constraints
def run_single_puzzle_config(
    args: DictConfig,
    gathered_metrics: dict,
    subblock_stats: dict,
    subblock_stats_args: dict,
    constraints: PuzzleConstraints,
    output_folder,
) -> Path:
    """Solve one puzzle (one measurement scenario and one constraint set) and dump the result.

    Args:
        args: Run configuration; a deep copy is annotated with the effective constraints and
            stored in each solution under "puzzle_args".
        gathered_metrics: Replacement metrics; updated in place with per-replacement "stats"
            and "block_repr" entries.
        subblock_stats: All measured scenarios; narrowed to the one matching
            ``subblock_stats_args``.
        subblock_stats_args: Selector for the measurement scenario.
        constraints: Constraints to solve for. This object is left unmodified.
        output_folder: Folder (created if missing) receiving all output files.

    Returns:
        Path of the written ``solutions.json``.
    """
    # we override the constraints and subblock_stats_args for this run to keep reporting out the same way.
    args = deepcopy(args)

    subblock_stats = filter_subblock_stats_by_args(subblock_stats, subblock_stats_args)
    _add_block_stats_to_gathered_metrics(gathered_metrics, subblock_stats)

    output_folder.mkdir(parents=True, exist_ok=True)
    _dump_gathered_metrics(gathered_metrics, output_folder)

    non_block_stats = {"stats": _get_block_stats(subblock_stats, "non_block")}
    batch_size = subblock_stats["args"]["batch_size"]
    generation_seq_len = subblock_stats["args"]["generation_seq_len"]

    def _non_block_cost(stat_name):
        # Stats that have no non-block component contribute nothing.
        try:
            return get_nested_key(non_block_stats, stat_name)
        except KeyError:
            return 0

    requested_constraints = constraints.to_mip_constraints(subblock_stats["args"])
    orig_mip_constraints = deepcopy(requested_constraints)
    # Adjust a private copy: the mapping returned above may be the very object stored on
    # `constraints`, which the caller reuses for further runs.
    mip_constraints = deepcopy(requested_constraints)
    mprint(f"Solving for the following MIP constraints: {mip_constraints}")
    args.mip_constraints = orig_mip_constraints
    args.human_constraints = (
        constraints.constraints if constraints.type == PuzzleConstraints.Type.HUMAN else None
    )
    args.subblock_stats_args = subblock_stats_args

    # The MIP only sees per-block costs, so remove the fixed non-block share from each limit.
    for stat_name, max_cost in mip_constraints.items():
        non_block_cost = _non_block_cost(stat_name)

        is_min_max = isinstance(max_cost, Iterable)
        min_cost = None
        if is_min_max:
            min_cost, max_cost = max_cost

        min_cost = min_cost - non_block_cost if (min_cost is not None) else None
        max_cost = max_cost - non_block_cost if (max_cost is not None) else None

        if is_min_max:
            mip_constraints[stat_name] = (min_cost, max_cost)
        else:
            mip_constraints[stat_name] = max_cost

    # If there's an additional cost that is not a constraint - set it to "inf" so MIP report the actual value of it.
    for cost in set(args.report_additional_costs) - set(orig_mip_constraints.keys()):
        mip_constraints[cost] = np.inf

    mprint(f"After non-block adjustments: {mip_constraints=}")

    solutions = run_multi_layer_replacement_mip(
        replacements=gathered_metrics,
        objective=args.objective,
        constraints=mip_constraints,
        bigger_is_better=args.bigger_is_better,
        max_seconds_per_solution=args.max_seconds_per_solution,
    )

    for solution in solutions:
        # Report whole-model costs again by adding the non-block share back.
        for stat_name in {*orig_mip_constraints.keys(), *args.report_additional_costs}:
            solution["total_costs"][stat_name] += _non_block_cost(stat_name)

        # Calculate throughput from runtime_ms
        if "stats.runtime_ms" in solution["total_costs"]:
            total_runtime = solution["total_costs"]["stats.runtime_ms"]
            solution["total_costs"]["throughput"] = (
                1000 * batch_size * generation_seq_len / total_runtime
            )

        solution["total_value"] = {args.objective: solution["total_value"]}
        solution["puzzle_args"] = (
            OmegaConf.to_container(args, resolve=True)
            if isinstance(args, DictConfig)
            else vars(args)
        )
        solution["subblock_stats"] = subblock_stats["args"]
        chosen_block_configs, _ = extract_block_configs_and_locations(
            solution["chosen_replacements"]
        )
        solution["chosen_block_configs"] = chosen_block_configs
        solution["solution_repr"] = solution_to_str(chosen_block_configs)

    if len(solutions) > 0:
        solution_repr_0 = solutions[0]["solution_repr"]
        mprint(f"\n{solution_repr_0}")
        mprint(f"Total costs: {solutions[0]['total_costs']}")
        (output_folder / "solution_repr_0.txt").write_text(solution_repr_0)

    solutions_file = output_folder / "solutions.json"
    json_dump(solutions, solutions_file)
    mprint(solutions_file)
    return solutions_file


def _dump_gathered_metrics(gathered_metrics: PuzzleMetrics, output_folder: Path) -> None:
    """Add a "block_repr" string to every entry and dump all entries as JSON.

    The entries of ``gathered_metrics`` are modified in place. The output file is
    ``replacement_metrics_and_stats.json`` inside ``output_folder``.
    """
    for replacement_info in gathered_metrics.values():
        replacement_info["block_repr"] = block_config_to_str(replacement_info["block_config"])

    json_dump(gathered_metrics, output_folder / "replacement_metrics_and_stats.json")
+ ) + + +def _load_all_subblock_stats_args(args, puzzle_profile): + # If given explicitely in args + if args.subblock_stats_args is not None: + if isinstance(args.subblock_stats_args, dict): + return [args.subblock_stats_args] + else: + return args.subblock_stats_args + + # Or can be given inside puzzle_profile + if "subblock_stats_args" in puzzle_profile: + return puzzle_profile["subblock_stats_args"] + + raise ValueError( + "subblock_stats_args must be given either explicitely by the --subblock_stats_args argument, or through the puzzle_profile." + ) + + +def _override_args_from_profile(args, puzzle_profile): + for arg_name in vars(args): + if arg_name in puzzle_profile: + if arg_name not in ("mip_constraints", "human_constraints", "subblock_stats_args"): + setattr(args, arg_name, puzzle_profile[arg_name]) + + +def _assert_valid_config(args, puzzle_profile): + required_args = ( + "subblock_stats_path", + "objective", + "output_path", + ) + missing_args = [arg for arg in required_args if arg not in args or getattr(args, arg) is None] + if missing_args: + mprint(f"error: The following arguments are required: {', '.join(missing_args)}") + sys.exit(1) + + # Make sure we have specified subblock_stats_args + if "subblock_stats_args" not in args and "subblock_stats_args" not in puzzle_profile: + mprint( + "error: Must specify `subblock_stats_args` in either puzzle_profile or as a commandline arg." + ) + sys.exit(1) + + # Make sure we have specified constraints + if ( + "mip_constraints" not in args + and "human_constraints" not in args + and "mip_constraints" not in puzzle_profile + and "human_constraints" not in puzzle_profile + ): + mprint( + "error: Must specify either `mip_constraints` or `human_constraints` in one of puzzle_profile or as a commandline argument." 
+ ) + sys.exit(1) + + +def _get_minimal_unique_names(dicts: list[dict]) -> list[str]: + all_keys = set(k for d in dicts for k in d.keys()) + all_values = {k: set(d[k] for d in dicts if k in d) for k in all_keys} + non_common_keys = [k for k, values in all_values.items() if len(values) > 1] + + return ["-".join(f"{k}_{d[k]}".replace(".", "_") for k in non_common_keys) for d in dicts] + + +def run_puzzle(args: DictConfig) -> list[str]: + # Loads config from args/puzzle_profile + if args.puzzle_profile is not None: + with open(args.puzzle_profile) as f: + puzzle_profile = yaml.safe_load(f) + _override_args_from_profile(args, puzzle_profile) + mprint(f"Loaded Puzzle profile from {args.puzzle_profile}") + else: + puzzle_profile = {} + _assert_valid_config(args, puzzle_profile) + + # Read Metrics and Stats + if args.gathered_metrics_path is not None: + gathered_metrics = json.loads(args.gathered_metrics_path.read_text()) + else: + gathered_metrics = gather_multi_layer_puzle_metrics( + args.single_block_replacement_validation_dir + ) + + if args.metric_overrides is not None: + gathered_metrics = {**gathered_metrics, **args.metric_overrides} + + subblock_stats = json.loads(args.subblock_stats_path.read_text()) + + all_subblock_args = _load_all_subblock_stats_args(args, puzzle_profile) + all_subblock_output_folders = [ + args.output_path / unique_name + for unique_name in _get_minimal_unique_names(all_subblock_args) + ] + all_constraints = _load_all_constraints(args, puzzle_profile) + + # Running all puzzles + solution_paths = [] + for subblock_stats_args, subblock_stats_output_folder in zip( + all_subblock_args, all_subblock_output_folders + ): + for constraints in all_constraints: + output_folder = subblock_stats_output_folder / constraints.name + _solution_path = run_single_puzzle_config( + args, + gathered_metrics, + subblock_stats, + subblock_stats_args, + constraints, + output_folder, + ) + solution_paths.append(_solution_path) + return solution_paths + + +def 
def gather_multi_layer_puzle_metrics(
    single_replacement_validation_dir: Path,
) -> MultiLayerPuzzleMetrics:
    """Collect replacement and teacher metrics into one ``replacement_<i>`` keyed dict.

    Every ``*solution*.json`` file in the directory is parsed as one replacement. The metric
    names of the first parsed file are passed to the teacher parser. Entries are numbered
    with the parsed files first and the teacher entries after them.
    """
    replacement_entries = []
    for metrics_file in single_replacement_validation_dir.glob("*solution*.json"):
        replacement_entries.append(_parse_single_sequence_replacement_metrics(metrics_file))

    # Indexing the first entry fails with IndexError if no file matched the pattern.
    metric_names = tuple(replacement_entries[0]["metrics"].keys())
    teacher_entries = _parse_teacher_block_metrics(
        single_replacement_validation_dir, metric_names
    )

    gathered = {}
    for index, entry in enumerate(replacement_entries + teacher_entries):
        gathered[f"replacement_{index}"] = entry
    return gathered
def _parse_single_sequence_replacement_metrics(metrics_path: Path) -> dict:
    """Parse one validation JSON file describing a single-sequence replacement.

    Returns a dict with the child ``BlockConfig``, the parent layer indices, the averaged
    metrics, the parsed layer replacement and ``is_teacher=False``.

    Raises:
        NotImplementedError: If the replacement has more than one child block config.
    """
    raw = json.loads(metrics_path.read_text())
    replacement = raw["puzzle_solution"]["single_sequence_replacement"]

    child_configs = replacement["child_block_configs"]
    # Only replacements with a single child block are handled here.
    if len(child_configs) > 1:
        raise NotImplementedError(
            "Currently we only support many-to-1 replacements, but we can support many-to-many! "
        )

    return {
        "block_config": BlockConfig(**child_configs[0]),
        "parent_layer_indices": replacement["parent_layer_indices"],
        "metrics": _extract_average_metrics(raw),
        "layer_replacement": parse_layer_replacement(replacement),
        "is_teacher": False,
    }
layer_replacement = parse_layer_replacement(layer_replacement) + if replacement_is_teacher( + layer_replacement, teacher_model_config, teacher_checkpoint_dir + ): + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_replacements[block_idx] = layer_replacement + + teacher_metrics = [ + { + "block_config": block_config, + "block_idx": block_idx, + "parent_layer_indices": [block_idx], + "metrics": { + **dict.fromkeys(all_metric_names, 0.0), # default value 0. for teacher + **_extract_average_metrics(raw_metrics), # override with real value if exists + }, + **( + {"layer_replacement": teacher_replacements[block_idx]} + if teacher_replacements is not None + else {} + ), + "is_teacher": True, + } + for block_idx, block_config in enumerate(teacher_model_config.block_configs) + ] + return teacher_metrics + + +def _extract_average_metrics(raw_metrics: dict[str, dict]) -> dict[str, float]: + average_metrics = dict() + for metric_name in raw_metrics: + metric_dict = raw_metrics[metric_name] + if isinstance(metric_dict, dict) and ("avg" in metric_dict.keys()): + metric_value = raw_metrics[metric_name]["avg"] + average_metrics[metric_name] = metric_value + average_metrics[f"one_minus_{metric_name}"] = 1 - metric_value + return average_metrics + + +def filter_subblock_stats_by_args( + all_subblock_stats: list[dict], + subblock_stats_args: dict[str, Any], + convert_dicts_to_dataclasses: bool = True, +) -> dict[str, dict]: + matching_subblock_stats = [ + subblock_stats + for subblock_stats in all_subblock_stats + if _dict_is_subset(subblock_stats_args, subblock_stats["args"]) + ] + assert len(matching_subblock_stats) == 1, ( + "The provided subblock_stats_args should match exactly one measurement " + f"scenario, instead matched {len(matching_subblock_stats)}:\n" + f"{[m['args'] for m in matching_subblock_stats]}" + ) + subblock_stats = deepcopy(matching_subblock_stats[0]) + + if convert_dicts_to_dataclasses: + class_name_to_class = {klass.__name__: klass for klass 
in [AttentionConfig, FFNConfig]} + subblocks_dict = dict() + for substats in subblock_stats["subblocks"]: + subblock_config_class = class_name_to_class[substats.pop("subblock_config_class")] + subblock_config = subblock_config_class(**substats.pop("subblock_config")) + dict_key = (subblock_config, None) + if "parent_layer_index" in substats: + dict_key = (subblock_config, substats["parent_layer_index"]) + subblocks_dict[dict_key] = substats + subblock_stats["subblocks"] = subblocks_dict + return subblock_stats + + +def _dict_is_subset(dict1: dict, dict2: dict) -> bool: + return all(item in dict2.items() for item in dict1.items()) + + +def _add_block_stats_to_gathered_metrics( + gathered_metrics: PuzzleMetrics, subblock_stats: dict +) -> None: + for block_name, block_variants in gathered_metrics.items(): + parent_layer_index = None + if "parent_layer_indices" in block_variants: + parent_layer_index = block_variants["parent_layer_indices"][0] + + if "metrics" in block_variants: + # this is a sequence stats object for multi-layer puzzle + block_variants["stats"] = _get_block_stats( + subblock_stats, block_variants["block_config"], parent_layer_index + ) + else: + for block_config, variant_metrics in block_variants.items(): + variant_metrics["stats"] = _get_block_stats(subblock_stats, block_config) + + +def _get_block_stats( + subblock_stats: dict, + block_config: BlockConfig | Literal["non_block"], + parent_layer_index: int = None, +) -> dict[str, float]: + if block_config == "non_block": + return subblock_stats["non_block"] + + if block_config.parallel_blocks is None: + attention_key = (block_config.attention, parent_layer_index) + ffn_key = (block_config.ffn, parent_layer_index) + attention_stats = subblock_stats["subblocks"][attention_key] + ffn_stats = subblock_stats["subblocks"][ffn_key] + assert set(attention_stats.keys()) == set(ffn_stats.keys()) + + block_stats = dict() + for k in attention_stats.keys(): + block_stats[k] = _none_add(attention_stats[k], 
ffn_stats[k]) + block_stats[f"attention_{k}"] = attention_stats[k] + block_stats[f"ffn_{k}"] = ffn_stats[k] + + block_stats["has_attention"] = int( + not block_config.attention.no_op and block_config.attention.mamba is None + ) + block_stats["has_ffn"] = int(not block_config.ffn.no_op) + block_stats["has_moe"] = int(block_config.ffn.moe is not None) + block_stats["not_no_op"] = int( + not (block_config.attention.no_op and block_config.ffn.no_op) + ) + block_stats["num_kv_heads"] = ( + block_config.attention.num_key_value_heads if block_stats["has_attention"] else 0 + ) + block_stats["num_local_experts"] = ( + block_config.ffn.moe.num_local_experts if block_stats["has_moe"] else 0 + ) + + return block_stats + + # this is a parallel block + ADDITIVE_METRICS = ("memory_mib", "num_params", "kv_cache_memory_mib") + ADDITIVE_METRICS = [ + f"{prefix}{metric}" for prefix in ("", "attention_", "ffn_") for metric in ADDITIVE_METRICS + ] + block_stats = [ + _get_block_stats(subblock_stats, sub_parallel) + for sub_parallel in block_config.parallel_blocks + ] + block_stats = { + k: _none_add_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + if k in ADDITIVE_METRICS + else _none_max_list([sub_parallel_stat[k] for sub_parallel_stat in block_stats]) + for k in block_stats[0].keys() + } + + return block_stats + + +def _none_add(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return a + b + + +def _none_max(a: float | int | None, b: float | int | None) -> float | int | None: + if a is None or b is None: + return None + return max(a, b) + + +def _none_add_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_add(r, e) + return r + + +def _none_max_list(l) -> float | int | None: + r = l[0] + if len(l) == 1: + return r + for e in l[1:]: + r = _none_max(r, e) + return r + + +if __name__ == "__main__": + args = parse_args() + run_puzzle(args) diff --git 
def _load_teacher_subblock_stats(hydra_cfg: DictConfig) -> tuple[dict[str, Any], PretrainedConfig]:
    """Return the matching subblock_stats entry and the teacher model config.

    The teacher config is loaded from ``hydra_cfg.teacher_dir``. The selector is the first
    item of ``hydra_cfg.mip.subblock_stats_args`` extended with ``n_embd`` set to the
    teacher's hidden size. The stats are read from ``<puzzle_dir>/subblock_stats.json``.

    Raises:
        FileNotFoundError: If ``subblock_stats.json`` does not exist.
        ValueError: If the filter step fails with an ``AssertionError``.
    """
    puzzle_dir = Path(hydra_cfg.puzzle_dir)
    teacher_dir = Path(hydra_cfg.teacher_dir)

    descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
    model_config = load_model_config(
        teacher_dir, trust_remote_code=descriptor.requires_trust_remote_code()
    )
    hidden_size = descriptor.get_language_model_config(model_config).hidden_size

    # Subblock_stats.json can list multiple runs that share batch/dtypes but differ by hidden size;
    # filter_subblock_stats_by_args needs n_embd so exactly one row matches the teacher.
    selector = {
        **OmegaConf.to_container(hydra_cfg.mip.subblock_stats_args[0], resolve=True),
        "n_embd": hidden_size,
    }

    # Read up front; these values are only used in the error message below.
    batch_size = selector["batch_size"]
    weights_dtype = str(selector["weights_dtype"])
    activations_dtype = str(selector["activations_dtype"])
    kv_cache_dtype = str(selector["kv_cache_dtype"])

    stats_file = puzzle_dir / "subblock_stats.json"
    if not stats_file.exists():
        raise FileNotFoundError(
            f"subblock_stats.json not found at {stats_file}. "
            "Please run the full pipeline first without --mip-only flag."
        )

    with open(stats_file) as stats_handle:
        all_stats = json.load(stats_handle)

    try:
        matching_stats = filter_subblock_stats_by_args(all_stats, selector)
    except AssertionError as e:
        raise ValueError(
            f"No unique subblock_stats entry for batch_size={batch_size}, "
            f"dtypes=({weights_dtype}, {activations_dtype}, {kv_cache_dtype}), "
            f"n_embd={hidden_size}"
        ) from e

    return matching_stats, model_config
+ + Sums ``non_block`` and per-layer ``_get_block_stats(subblock_stats, block_config, layer_index)`` + over ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher memory in MiB + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) + + total_memory = subblock_stats.get("non_block", {}).get("memory_mib", 0.0) + + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_memory += block_stats["memory_mib"] + + return total_memory + + +def get_teacher_num_params_from_subblock_stats(hydra_cfg: DictConfig) -> int: + """Calculate total teacher parameter count from subblock_stats.json. + + Sums ``non_block`` and per-layer ``_get_block_stats(...)["num_params"]`` over + ``model_config.block_configs``, matching :func:`run_puzzle._get_block_stats`. + + Args: + hydra_cfg: Hydra configuration object + + Returns: + Total teacher parameter count (same units as subblock_stats JSON). + """ + subblock_stats, model_config = _load_teacher_subblock_stats(hydra_cfg) + + total_params = subblock_stats.get("non_block", {}).get("num_params", 0) + + for layer_idx, block_config in enumerate(model_config.block_configs): + block_stats = _get_block_stats(subblock_stats, block_config, layer_idx) + total_params += block_stats["num_params"] + + return int(total_params) + + +def extract_solution_results( + solution_path: Path, + target_memory_mib: float, + teacher_memory_mib: float, + compression_rate: float, +) -> dict: + """Extract results from a completed MIP solution. + + Args: + solution_path: Path to the solutions.json file (not the directory!) 
+ target_memory_mib: Target memory constraint used for MIP + teacher_memory_mib: Teacher model memory in MiB + compression_rate: Compression rate applied + + Returns: + Dictionary containing extracted metrics + """ + result = { + "compression_rate": compression_rate, + "target_memory_mib": target_memory_mib, + "teacher_memory_mib": teacher_memory_mib, + } + + # solution_path is the path to solutions.json file, get parent directory + solution_dir = solution_path.parent + + # Load solutions.json for actual memory and parameters + solutions_file = solution_dir / "solutions.json" + with open(solutions_file) as f: + solutions_data = json.load(f) + solution = solutions_data[0] # First solution + total_costs = solution.get("total_costs", {}) + result["actual_memory_mib"] = total_costs.get("stats.memory_mib", None) + result["num_params"] = total_costs.get("stats.num_params", None) + + # Load solution_0.json for accuracy metrics + validation_dir = solution_dir / "solutions--validation" + # TODO: There could be multiple solutions, but we only need the first one. Is it the best solution? + solution_0_file = validation_dir / "solution_0.json" + + with open(solution_0_file) as f: + validation_data = json.load(f) + result["lm_loss"] = validation_data.get("lm_loss", {}).get("avg", None) + result["token_accuracy_top_1"] = validation_data.get("token_accuracy_top_1", {}).get( + "avg", None + ) + result["token_accuracy_top_5"] = validation_data.get("token_accuracy_top_5", {}).get( + "avg", None + ) + result["token_accuracy_top_10"] = validation_data.get("token_accuracy_top_10", {}).get( + "avg", None + ) + + return result + + +def write_results_to_csv(results: list, output_csv: str): + """Write sweep results to CSV file. 
+ + Args: + results: List of result dictionaries + output_csv: Path to output CSV file + """ + import csv + + # Define CSV columns in desired order + columns = [ + "compression_rate", + "target_memory_mib", + "actual_memory_mib", + "teacher_memory_mib", + "num_params", + "lm_loss", + "token_accuracy_top_1", + "token_accuracy_top_5", + "token_accuracy_top_10", + ] + + # Write CSV + output_path = Path(output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=columns) + writer.writeheader() + writer.writerows(results) + + mprint(f"Results written to: {output_path}") + + +def run_mip_sweep(hydra_cfg): + """Run MIP for multiple memory compression rates and generate CSV with results. + + This function is called when mip.sweep.enabled is True in the config. + + Args: + hydra_cfg: Hydra configuration object with mip.sweep settings + """ + mprint("=" * 80) + mprint("MIP Sweep Mode Enabled") + mprint("=" * 80) + + # Get sweep configuration + sweep_cfg = hydra_cfg.mip.sweep + compression_rates = sweep_cfg.memory_compression_rates + output_csv = sweep_cfg.output_csv + puzzle_dir = Path(hydra_cfg.puzzle_dir) + + mprint(f"Compression rates: {compression_rates}") + mprint(f"Output CSV: {output_csv}") + mprint(f"Puzzle directory: {puzzle_dir}") + + # Calculate teacher memory from subblock_stats + teacher_memory = get_teacher_memory_from_subblock_stats(hydra_cfg) + mprint( + f"Teacher memory (from subblock_stats): {teacher_memory:.1f} MiB ({teacher_memory / 1024:.1f} GiB)" + ) + + # Collect results + all_results = [] + + # Run MIP for each compression rate + for compression_rate in compression_rates: + target_memory_mib = teacher_memory * compression_rate + mprint("\n" + "=" * 80) + mprint( + f"Running MIP for compression_rate={compression_rate:.2f} " + f"(target={target_memory_mib:.1f} MiB = {target_memory_mib / 1024:.1f} GiB)" + ) + mprint("=" * 80) + + # Modify config dynamically 
+ hydra_cfg.mip.human_constraints.target_memory = target_memory_mib + + # Run MIP and realize models (reuse existing distributed logic!) + solution_paths = mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + + # Extract results (only on master rank) + if dist.is_master(): + for solution_path in solution_paths: + result = extract_solution_results( + solution_path=Path(solution_path), + target_memory_mib=target_memory_mib, + teacher_memory_mib=teacher_memory, + compression_rate=compression_rate, + ) + all_results.append(result) + + mprint( + f"✓ Results: actual_memory={result['actual_memory_mib']:.1f} MiB, " + f"lm_loss={result['lm_loss']:.4f}" + ) + + # Write results to CSV (only on master rank) + if dist.is_master(): + mprint("\n" + "=" * 80) + mprint("MIP Sweep Complete - Writing Results") + mprint("=" * 80) + write_results_to_csv(all_results, output_csv) + mprint(f"Completed {len(all_results)} sweep runs") + mprint("=" * 80) diff --git a/modelopt/torch/puzzletron/mip/utils.py b/modelopt/torch/puzzletron/mip/utils.py new file mode 100644 index 0000000000..b276ff33b1 --- /dev/null +++ b/modelopt/torch/puzzletron/mip/utils.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class InfeasibleError(Exception):
    """Signals that an optimization problem has no feasible solution."""


def sort_replacements(layer_replacements: list[dict]) -> list[dict]:
    """Return the replacements ordered by their ``"parent_layer_indices"`` value.

    The input list is left untouched; a new list is returned.

    Args:
        layer_replacements: Replacement dictionaries, each holding "parent_layer_indices"

    Returns:
        New list with the same dictionaries in ascending order of parent layer indices
    """

    def _parent_indices(entry: dict):
        return entry["parent_layer_indices"]

    ordered = list(layer_replacements)
    ordered.sort(key=_parent_indices)
    return ordered


def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any:
    """Look up a value in nested dictionaries through a dot-separated path.

    For ``nested_key == "a.b.c"`` the result is ``dictionary["a"]["b"]["c"]``.

    Args:
        dictionary: Outermost dictionary to descend into
        nested_key: Dot-separated key path (e.g., "a.b.c")

    Returns:
        The value stored at the end of the path
    """
    node = dictionary
    remaining = nested_key
    while True:
        # Peel off one path component at a time; `separator` is empty on the last one.
        head, separator, remaining = remaining.partition(".")
        node = node[head]
        if not separator:
            return node


def consecutive_ngrams(l: int, n: int) -> list[list[int]]:
    """Enumerate every window of ``n`` consecutive integers inside ``range(l)``.

    Args:
        l: Length of the range
        n: Size of each n-gram

    Returns:
        List of windows, each a list of ``n`` consecutive integers, in ascending order

    Example:
        consecutive_ngrams(7, 2) = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]
    """
    window_count = l - n + 1
    return [list(range(start, start + n)) for start in range(window_count)]
class PuzzletronModel(nn.Module):
    """Empty ``nn.Module`` placeholder for the puzzletron mode."""

    pass  # No model implementation is needed for the puzzletron mode


class PuzzletronConfig(ModeloptBaseConfig):
    """Configuration for Puzzletron NAS algorithm.

    All fields are strings that default to the empty string.
    """

    # Input model path to compress in the HF format
    input_model_path: str = ModeloptField(
        default="",
        title="",
        description="",
    )

    # Hydra config directory containing the search space definition
    hydra_config_dir: str = ModeloptField(
        default="",
        title="",
        description="",
    )

    # Hydra config name containing the search space definition
    hydra_config_name: str = ModeloptField(
        default="",
        title="",
        description="",
    )

    # Directory to save the compressed model and intermediate results
    puzzle_dir: str = ModeloptField(
        default="",
        title="",
        description="",
    )

    # Dataset path to use for scoring in pruning and NAS search
    dataset_path: str = ModeloptField(
        default="",
        title="",
        description="",
    )


def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType:
    """Run the conversion stage of the Puzzletron pipeline.

    1. Convert the model from HF format to AnyModel format.
    2. Score the pruning activations.
    3. Prune the model and save pruned checkpoints

    The output of this step will be used by mtn.search() to perform the NAS search.

    Args:
        model: Module that receives the ``hydra_config_dir``, ``hydra_config_name``,
            ``puzzle_dir`` and ``dataset_path`` attributes copied from ``config``.
        config: Puzzletron configuration.

    Returns:
        The same ``model`` object and an empty metadata dict.
    """
    # Required for mtn.search() to read NAS configuration
    model.hydra_config_dir = config.hydra_config_dir
    model.hydra_config_name = config.hydra_config_name
    model.puzzle_dir = config.puzzle_dir
    model.dataset_path = config.dataset_path

    # Load hydra config
    hydra_cfg = initialize_hydra_config_for_dir(
        config_dir=config.hydra_config_dir,
        config_name=config.hydra_config_name,
        overrides=[
            f"puzzle_dir={config.puzzle_dir}",
            f"dataset_path={config.dataset_path}",
        ],
    )
    # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
    hydra_cfg = hydra.utils.instantiate(hydra_cfg)

    # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config)
    # Only the master rank converts; the barrier below makes every rank wait for it.
    if dist.is_master():
        mprint(
            "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)"
        )
        hf_ckpt_teacher_dir = "ckpts/teacher"  # TODO: make it configurable

        # Get descriptor and converter from the hydra config
        descriptor_name = hydra_cfg.descriptor
        descriptor = ModelDescriptorFactory.get(descriptor_name)
        converter = ConverterFactory.get(descriptor_name)

        converter.convert(
            descriptor=descriptor,
            input_dir=Path(config.input_model_path),
            output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir,
        )
    dist.barrier()

    # Score_pruning_activations (distributed processing)
    mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)")
    score_pruning_activations.launch_score_activations(hydra_cfg)

    # Prune the model and save pruned checkpoints
    # Again master-only, followed by a barrier so all ranks leave together.
    if dist.is_master():
        mprint(
            "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)"
        )
        pruning_ckpts.launch_prune_ckpt(hydra_cfg)
    dist.barrier()

    return model, {}


def restore_puzzletron_model(
    model: nn.Module, config: PuzzletronConfig, metadata: MetadataDict
) -> nn.Module:
    """Restore is not needed for the puzzletron mode as we are not saving any model state.

    Returns ``model`` unchanged; ``config`` and ``metadata`` are ignored.
    """
    return model


@NASModeRegistry.register_mode
class PuzzletronDescriptor(ModeDescriptor):
    """Descriptor for the Puzzletron mode."""

    @property
    def name(self) -> str:
        """String identifier for this mode."""
        return "puzzletron"

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """Configuration class for this mode."""
        return PuzzletronConfig

    @property
    def search_algorithm(self) -> type[BaseSearcher]:
        """Return the associated searcher implementation."""

        return PuzzletronSearcher

    @property
    def convert(self) -> ConvertEntrypoint:
        """Entrypoint to convert a model."""
        return convert_puzzletron_model

    @property
    def restore(self) -> RestoreEntrypoint:
        """Entrypoint to restore a model."""
        return restore_puzzletron_model

    @property
    def export_mode(self) -> str | None:
        """The mode that corresponds to the export mode.

        For now, this will be a no-op as there is no modelopt's concept of search space defined
        for the puzzletron algorithm.
        """
        return "export_nas"


class PuzzletronSearcher(BaseSearcher):
    """Runs NAS search for the Puzzletron mode."""

    @property
    def default_state_dict(self) -> SearchStateDict:
        """Not needed for the puzzletron mode as we are not saving any model state."""
        return {}

    def run_search(self) -> None:
        """Run the search stage: build library/stats, score blocks, then MIP and realize models.

        Reads ``hydra_config_dir``, ``hydra_config_name``, ``puzzle_dir`` and ``dataset_path``
        from ``self.model`` (set by ``convert_puzzletron_model``).
        """
        # Load hydra config
        hydra_cfg = initialize_hydra_config_for_dir(
            config_dir=self.model.hydra_config_dir,
            config_name=self.model.hydra_config_name,
            overrides=[
                f"puzzle_dir={self.model.puzzle_dir}",
                f"dataset_path={self.model.dataset_path}",
            ],
        )
        # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
        hydra_cfg = hydra.utils.instantiate(hydra_cfg)

        # Build_library_and_stats (single process)
        # Master-only step; the barrier keeps the other ranks waiting until it is done.
        if dist.is_master():
            mprint(
                "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)"
            )
            build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
        dist.barrier()

        # Calc_one_block_scores (distributed processing)
        mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)")
        scoring.launch_scoring(hydra_cfg)

        # mip_and_realize_models (distributed processing)
        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
@dataclass
class ExpertRemovalLayerDescriptor(LayerDescriptor):
    """Names of the MoE router/expert tensors that expert-removal pruning operates on.

    TODO - Add shared expert weights in case they are prunable.
    TODO - Consider removing the split between weight and bias; it does not seem to affect
        the pruning algorithm.

    Attributes:
        target_name: Module name used to register hooks for scoring activations; can be a
            regex if it starts with the prefix ``regex:``.
        moe_prefix_name: MoE prefix layer name; should include a ``{layer_idx}`` placeholder so
            it can be repeated for all layers, e.g. ``model.layers.{layer_idx}.moe``.
        expert_prefix_name: Expert prefix layer name relative to ``moe_prefix_name``; should
            include an ``{expert_idx}`` placeholder so it can be repeated for all experts,
            e.g. ``experts.{expert_idx}``.
        router_weights: Router weight names relative to the MoE prefix.
        router_biases: Router bias names relative to the MoE prefix.
        expert_weights: Expert weight names relative to the expert prefix (per-expert format).
        expert_biases: Expert bias names relative to the expert prefix (per-expert format).
        is_fused_experts: If True, experts are stored as single fused tensors with shape
            [num_experts, ...]. If False (default), experts are stored as separate tensors
            per expert.
        fused_expert_weights: Fused expert weight names relative to the MoE prefix (fused
            format), e.g. ["experts.gate_up_proj", "experts.down_proj"].
    """

    target_name: str
    moe_prefix_name: str
    expert_prefix_name: str = ""
    router_weights: List[str] = field(default_factory=list)
    router_biases: List[str] = field(default_factory=list)
    expert_weights: List[str] = field(default_factory=list)
    expert_biases: List[str] = field(default_factory=list)
    is_fused_experts: bool = False
    fused_expert_weights: List[str] = field(default_factory=list)

    def module_name_regex(self) -> str:
        """Return ``target_name``."""
        return self.target_name

    def moe_prefix(self, layer_idx: int) -> str:
        """Return ``moe_prefix_name`` with ``{layer_idx}`` filled in."""
        return self.moe_prefix_name.format(layer_idx=layer_idx)

    def expert_prefix(self, layer_idx: int, expert_idx: int) -> str:
        """Return ``<moe_prefix_name>.<expert_prefix_name>`` with both placeholders filled in."""
        _expert_prefix = self.moe_prefix_name + "." + self.expert_prefix_name
        return _expert_prefix.format(layer_idx=layer_idx, expert_idx=expert_idx)


class ExpertRemovalPruningMixIn(PruningMixIn):
    """Pruning mix-in that initializes a child MoE layer with a different number of experts."""

    def __init__(self, layer_descriptor: ExpertRemovalLayerDescriptor):
        assert isinstance(layer_descriptor, ExpertRemovalLayerDescriptor)
        super().__init__(layer_descriptor)

    def supported_hooks(self) -> List[Type[ForwardHook]]:
        """Return the forward-hook classes this mix-in supports."""
        return [
            RankedChoiceVotingHook,
            RankedChoiceVotingHookNemotronH,
            NemotronHRemoveExpertsIndependentHook,
            Qwen3VLRemoveExpertsIndependentHook,
        ]

    def prune_single_layer(
        self,
        layer_idx: int,
        parent_state_dict: dict,
        new_state_dict: dict,
        original_config: PretrainedConfig,
        new_config: PretrainedConfig,
        mlp_init_mode: MlpInitMode,
        mlp_init_config: Optional[dict[str, Any]],
        keys: dict,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Build the router/expert tensors of one child MoE layer from its parent layer.

        Args:
            layer_idx: Index into ``block_configs`` of both configs.
            parent_state_dict: Source of the parent router/expert tensors.
            new_state_dict: Source of the child's current router/expert tensors.
            original_config: Parent config; its block config decides whether the layer is MoE
                and gives the original expert count.
            new_config: Child config; gives the new expert count.
            mlp_init_mode: Forwarded to ``_init_moe_module``.
            mlp_init_config: Forwarded to ``_init_moe_module``.
            keys: Mutated in place: every parent router/expert key found in it is popped.
                # NOTE(review): appears to be the caller's list of keys still to be copied
                # verbatim - confirm against the caller.
            **kwargs: Ignored.

        Returns:
            Mapping from child state-dict key to tensor; empty when the parent layer is not MoE.
        """
        layer_out_state_dict = {}

        child_block_config = new_config.block_configs[layer_idx]
        parent_block_config = original_config.block_configs[layer_idx]

        # Nothing to do for non-MoE parent layers.
        if not parent_block_config.ffn.is_moe:
            return layer_out_state_dict

        new_num_experts = child_block_config.ffn.moe.num_local_experts
        orig_num_experts = parent_block_config.ffn.moe.num_local_experts

        child_router_keys, new_experts_keys = self._generate_moe_keys(layer_idx, new_num_experts)
        parent_router_keys, orig_experts_keys = self._generate_moe_keys(layer_idx, orig_num_experts)

        # Pop parent's router keys from copy list; child-only router keys will be initialized below
        # (sum(..., []) flattens the dict's lists of keys into one list)
        for rk in sum(parent_router_keys.values(), []):
            if rk in keys:
                keys.pop(rk)
        for key in sum(orig_experts_keys.values(), []):
            if key in keys:
                keys.pop(key)

        if self.layer_descriptor.is_fused_experts:
            # Fused format: unbundle single tensor [num_experts, ...] into list of per-expert tensors
            orig_experts_weights = {}
            for name, fused_keys in orig_experts_keys.items():
                fused_tensor = parent_state_dict[fused_keys[0]]  # Single fused tensor
                orig_experts_weights[name] = [fused_tensor[i] for i in range(orig_num_experts)]

            new_experts_weights = {}
            for name, fused_keys in new_experts_keys.items():
                fused_tensor = new_state_dict[fused_keys[0]]  # Single fused tensor
                new_experts_weights[name] = [fused_tensor[i] for i in range(new_num_experts)]
        else:
            # Per-expert format: load each expert tensor separately
            orig_experts_weights = {
                name: [parent_state_dict[key] for key in orig_experts_module_keys]
                for name, orig_experts_module_keys in orig_experts_keys.items()
            }
            new_experts_weights = {
                name: [new_state_dict[key] for key in new_experts_module_keys]
                for name, new_experts_module_keys in new_experts_keys.items()
            }

        orig_router_weights = {
            name: [parent_state_dict[key] for key in _module_router_keys]
            for name, _module_router_keys in parent_router_keys.items()
        }
        new_router_weights = {
            name: [new_state_dict[key] for key in _module_router_keys]
            for name, _module_router_keys in child_router_keys.items()
        }

        out_router_weights, out_experts_weights = _init_moe_module(
            layer_idx=layer_idx,
            mlp_init_mode=mlp_init_mode,
            mlp_init_config=mlp_init_config,
            orig_router_weights=orig_router_weights,
            orig_experts_weights=orig_experts_weights,
            new_router_weights=new_router_weights,
            new_experts_weights=new_experts_weights,
            orig_num_experts=orig_num_experts,
            new_num_experts=new_num_experts,
        )
        # The helper must return exactly the groups that were requested.
        assert new_experts_keys.keys() == out_experts_weights.keys(), (
            "new_experts_keys and out_experts_weights must have the same keys"
        )
        assert child_router_keys.keys() == out_router_weights.keys(), (
            "child_router_keys and out_router_weights must have the same keys"
        )

        # Pair each child router key with the tensor at the same position.
        for name in child_router_keys.keys():
            layer_out_state_dict.update(zip(child_router_keys[name], out_router_weights[name]))

        if self.layer_descriptor.is_fused_experts:
            # Fused format: rebundle list of per-expert tensors into single fused tensor
            for name in new_experts_keys.keys():
                fused_key = new_experts_keys[name][0]  # Single key for fused tensor
                fused_tensor = torch.stack(out_experts_weights[name], dim=0)  # [num_experts, ...]
                layer_out_state_dict[fused_key] = fused_tensor
        else:
            # Per-expert format: each expert has its own key
            for name in new_experts_keys.keys():
                layer_out_state_dict.update(zip(new_experts_keys[name], out_experts_weights[name]))

        return layer_out_state_dict

    def _generate_moe_keys(
        self, layer_idx: int, num_experts: int
    ) -> Tuple[Dict[str, List[str]], dict[str, list[str]]]:
        """Generate MoE weight keys for router and experts.

        TODO simplify or better define the data structure of the moe keys returned.

        Args:
            layer_idx: Layer index substituted into the descriptor's prefixes.
            num_experts: Number of per-expert keys to generate for each expert tensor name
                (unused in the fused format).

        Returns:
            Tuple ``(router_keys, expert_keys)``; every key is a full name built from the
            descriptor's MoE prefix for ``layer_idx``.

            * router_keys structure::

                {"weight": [router weight keys], "bias": [router bias keys]}

            * expert_keys structure (per-expert format) - one entry per expert tensor name,
              holding one key per expert::

                {
                    "down_proj.weight": ["model...experts.0.down_proj.weight", ..., "model...experts.N.down_proj.weight"],
                    ...
                }

            * expert_keys structure (fused format) - one entry per fused tensor name, holding
              a single key::

                {
                    "experts.gate_up_proj": ["model...experts.gate_up_proj"],
                    "experts.down_proj": ["model...experts.down_proj"],
                }
        """
        # Bare annotation: narrows the attribute's type for type checkers, no runtime effect.
        self.layer_descriptor: ExpertRemovalLayerDescriptor
        moe_prefix = self.layer_descriptor.moe_prefix(layer_idx)

        router_keys = {
            "weight": [
                f"{moe_prefix}.{_weight}" for _weight in self.layer_descriptor.router_weights
            ],
            "bias": [f"{moe_prefix}.{_bias}" for _bias in self.layer_descriptor.router_biases],
        }

        if self.layer_descriptor.is_fused_experts:
            # Fused format: single tensor per weight type with shape [num_experts, ...]
            experts_module_names = {}
            for fused_weight in self.layer_descriptor.fused_expert_weights:
                experts_module_names[fused_weight] = [f"{moe_prefix}.{fused_weight}"]
        else:
            # Per-expert format: separate tensor for each expert
            expert_key_names = (
                self.layer_descriptor.expert_weights + self.layer_descriptor.expert_biases
            )
            experts_module_names = {}
            for key_name in expert_key_names:
                experts_module_names[key_name] = [
                    f"{self.layer_descriptor.expert_prefix(layer_idx, expert_idx)}.{key_name}"
                    for expert_idx in range(num_experts)
                ]

        return router_keys, experts_module_names
@dataclass
class FFNIntermediateLayerDescriptor(LayerDescriptor):
    """Names of the FFN modules that intermediate-size pruning operates on.

    ``ffn_prefix_name`` is formatted with ``layer_idx``; ``linear_weight_names`` are module
    names appended to that prefix (followed by ``.weight``) by the pruning mix-in.
    """

    down_proj_name: str
    ffn_prefix_name: str
    linear_weight_names: List[str] = field(default_factory=list)

    def module_name_regex(self) -> str:
        """Return ``down_proj_name``."""
        return self.down_proj_name

    def ffn_prefix(self, layer_idx: int) -> str:
        """Return ``ffn_prefix_name`` with ``{layer_idx}`` filled in."""
        return self.ffn_prefix_name.format(layer_idx=layer_idx)


class FFNIntermediatePruningMixIn(PruningMixIn):
    """Pruning mix-in that initializes a child layer's FFN linear weights from its parent."""

    def __init__(self, layer_descriptor: FFNIntermediateLayerDescriptor):
        assert isinstance(layer_descriptor, FFNIntermediateLayerDescriptor)
        super().__init__(layer_descriptor)

    def supported_hooks(self) -> List[Type[ForwardHook]]:
        """Return the forward-hook classes this mix-in supports."""
        return [IndependentChannelContributionHook, IterativeChannelContributionHook]

    def prune_single_layer(
        self,
        layer_idx: int,
        parent_state_dict: dict,
        new_state_dict: dict,
        original_config: PretrainedConfig,
        new_config: PretrainedConfig,
        mlp_init_mode: MlpInitMode,
        mlp_init_config: Optional[dict[str, Any]],
        keys: dict,
        keys_to_remove: dict,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Initialize the FFN linear weights of one child layer.

        Args:
            layer_idx: Layer index substituted into the descriptor's FFN prefix.
            parent_state_dict: Source of the parent weights.
            new_state_dict: Source of the child's current weights; keys absent from it are
                skipped.
            original_config: Parent config, forwarded to ``_init_mlp_module``.
            new_config: Child config, forwarded to ``_init_mlp_module``.
            mlp_init_mode: Forwarded to ``_init_mlp_module``.
            mlp_init_config: Forwarded to ``_init_mlp_module``.
            keys: Looked up (not modified) to translate FFN weight names into state-dict keys.
                # NOTE(review): presumably maps child key names to parent keys - confirm
                # against the caller.
            keys_to_remove: Mutated in place: one entry is added per FFN key found in ``keys``.
            **kwargs: Ignored.

        Returns:
            Mapping from state-dict key to the initialized weight for every handled FFN key.
        """
        layer_out_state_dict = {}
        # Build the expected "<ffn_prefix>.<linear name>.weight" names from the descriptor
        mlp_prefix = self.layer_descriptor.ffn_prefix(layer_idx)
        mlp_key_names = [
            f"{mlp_prefix}.{name}.weight" for name in self.layer_descriptor.linear_weight_names
        ]
        # Keep only the names that are present in `keys`.
        mlp_keys = [keys.get(module_name) for module_name in mlp_key_names]
        mlp_keys = [k for k in mlp_keys if k is not None]

        # key.split('.')[-2] is the module name immediately before the trailing ".weight".
        for key in mlp_keys:
            keys_to_remove[f"{mlp_prefix}.{key.split('.')[-2]}.weight"] = key

        # Carried from one _init_mlp_module call to the next within this layer.
        pruned_filters = None
        projection_matrix = None

        for mlp_key in mlp_keys:
            # The down projection is handled along dim 1, every other linear along dim 0.
            expanded_dim = 1 if self.layer_descriptor.down_proj_name in mlp_key else 0
            if mlp_key in new_state_dict.keys():
                mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module(
                    mlp_init_mode,
                    mlp_prefix,
                    expanded_dim,
                    layer_idx,
                    new_state_dict[mlp_key],
                    new_config,
                    parent_state_dict[mlp_key],
                    original_config,
                    mlp_init_config,
                    pruned_filters,
                    projection_matrix,
                )
                layer_out_state_dict[mlp_key] = mlp_module_weight

        return layer_out_state_dict
@dataclass
class KVHeadsLayerDescriptor(LayerDescriptor):
    """Names of the attention modules that KV-head pruning operates on.

    ``attn_prefix_name`` is formatted with ``layer_idx``. ``qkvo_weight_names`` is unpacked by
    the pruning mix-in into exactly four names, in q, k, v, o order.
    """

    o_proj_name: str
    attn_prefix_name: str
    qkvo_weight_names: List[str] = field(default_factory=list)

    def module_name_regex(self) -> str:
        """Return ``o_proj_name``."""
        return self.o_proj_name

    def attn_prefix(self, layer_idx: int) -> str:
        """Return ``attn_prefix_name`` with ``{layer_idx}`` filled in."""
        return self.attn_prefix_name.format(layer_idx=layer_idx)


class KVHeadsPruningMixIn(PruningMixIn):
    """Pruning mix-in that initializes a child layer's q/k/v/o tensors from its parent."""

    def __init__(self, layer_descriptor: KVHeadsLayerDescriptor):
        assert isinstance(layer_descriptor, KVHeadsLayerDescriptor)
        super().__init__(layer_descriptor)

    def supported_hooks(self) -> List[Type[ForwardHook]]:
        """Return the forward-hook classes this mix-in supports."""
        return [IndependentKvHeadContributionHook]

    def prune_single_layer(
        self,
        layer_idx: int,
        parent_state_dict: dict,
        new_state_dict: dict,
        original_config: PretrainedConfig,
        new_config: PretrainedConfig,
        gqa_init_mode: GQAInitMode,
        mlp_init_config: Optional[dict[str, Any]],
        is_original_mha: bool,
        keys: dict,
        keys_to_remove: dict,
        **kwargs,
    ):
        """Initialize the attention weights and biases of one child layer.

        The same procedure runs twice, once for the ``.weight`` tensors and once for the
        ``.bias`` tensors of the q/k/v/o projections.

        Args:
            layer_idx: Layer index substituted into the descriptor's attention prefix.
            parent_state_dict: Passed to the init helpers as ``original_state_dict``.
            new_state_dict: Child state dict; only attention keys present in it are handled.
            original_config: Parent config, forwarded to the init helpers.
            new_config: Child config; ``head_dim`` is read from it.
            gqa_init_mode: Forwarded to the init helpers.
            mlp_init_config: Forwarded to the init helpers.
            is_original_mha: Forwarded to the init helpers.
            keys: Looked up (not modified) for each handled attention key.
            keys_to_remove: Mutated in place: ``keys_to_remove[key] = keys[key]`` for every
                handled attention key.
            **kwargs: Ignored.

        Returns:
            Mapping from state-dict key to the initialized tensor.
        """
        layer_out_state_dict = {}

        attn_prefix = self.layer_descriptor.attn_prefix(layer_idx)
        # Unpacking requires exactly four names in the descriptor.
        q_name, k_name, v_name, o_name = [
            f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names
        ]

        head_size = new_config.head_dim
        for part in ["weight", "bias"]:
            attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]]
            q_key, k_key, v_key, o_key = attn_keys

            # Drop attn keys that don't exist and required to be in the new state_dict
            attn_keys = [key for key in attn_keys if key in new_state_dict.keys()]
            if len(attn_keys) > 0 and all(key in keys for key in attn_keys):
                for key in attn_keys:
                    keys_to_remove[key] = keys[key]
                # NOTE(review): attn_keys was just filtered to keys present in new_state_dict,
                # so this check looks like it is always True here - confirm the intended
                # comparison (e.g. against the parent state dict).
                is_student_and_teacher_have_same_attention_implementation = all(
                    key in new_state_dict.keys() for key in attn_keys
                )
                if is_student_and_teacher_have_same_attention_implementation:
                    if part == "weight":
                        wq, wk, wv, wo = _init_attention_weights(
                            gqa_init_mode=gqa_init_mode,
                            layer_idx=layer_idx,
                            new_state_dict=new_state_dict,
                            new_config=new_config,
                            original_state_dict=parent_state_dict,
                            original_config=original_config,
                            q_key=q_key,
                            k_key=k_key,
                            v_key=v_key,
                            o_key=o_key,
                            is_original_mha=is_original_mha,
                            head_size=head_size,
                            mlp_init_config=mlp_init_config,
                        )
                        layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk
                        layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo
                    else:
                        bias_sd = _init_attention_biases(
                            gqa_init_mode=gqa_init_mode,
                            layer_idx=layer_idx,
                            new_state_dict=new_state_dict,
                            new_config=new_config,
                            original_state_dict=parent_state_dict,
                            original_config=original_config,
                            q_key=q_key,
                            k_key=k_key,
                            v_key=v_key,
                            o_key=o_key,
                            is_original_mha=is_original_mha,
                            head_size=head_size,
                            mlp_init_config=mlp_init_config,
                        )
                        # bias_sd is keyed by the single letters "q", "k", "v", "o"; copy over
                        # only the biases the helper actually returned.
                        for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]):
                            if bias_key in bias_sd.keys():
                                layer_out_state_dict[sd_key] = bias_sd[bias_key]

        return layer_out_state_dict
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating pruned model checkpoints.

This module provides functions to generate pruned checkpoints by modifying model architectures
(FFN intermediate sizes, attention head groups, hidden dimensions) and initializing child pruned models
from parent checkpoints.
"""

# mypy: ignore-errors
import json
import os
import time
from typing import Optional

from omegaconf import DictConfig

from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn
from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
    FFNIntermediatePruningMixIn,
)
from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn
from modelopt.torch.puzzletron.pruning.pruning_utils import (
    GQAInitMode,
    HiddenSizeInitMode,
    LinearInitMode,
    MlpInitMode,
    resolve_pruning_mixin,
)
from modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent import (
    init_child_from_parent,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config
from modelopt.torch.puzzletron.tools.logger import mprint


def _ckpt_link_path(cfg: DictConfig, link_name: str) -> str:
    """Return the path of the ``link_name`` entry under ``<puzzle_dir>/ckpts``."""
    return os.path.join(cfg.puzzle_dir, "ckpts", link_name)


def _link_pruned_ckpt(cfg: DictConfig, output_dir: str, link_name: str) -> None:
    """Symlink ``output_dir`` into ``<puzzle_dir>/ckpts`` under ``link_name``."""
    ckpt_path = os.path.join(cfg.puzzle_dir, "ckpts")
    os.makedirs(ckpt_path, exist_ok=True)
    link_path = os.path.join(ckpt_path, link_name)
    # os.path.exists() follows symlinks, so a dangling link passes the launchers'
    # "already pruned" check; remove it, otherwise os.symlink would raise
    # FileExistsError only after the expensive pruning step has finished.
    if os.path.islink(link_path) and not os.path.exists(link_path):
        os.remove(link_path)
    os.symlink(output_dir, link_path)


def _prune_and_link(
    cfg: DictConfig,
    *,
    model_config_overrides: dict,
    dirname: str,
    link_name: str,
    max_save_workers: Optional[int],
    max_layer_workers: Optional[int],
) -> float:
    """Run ``init_child_from_parent`` from the teacher checkpoint and link the result.

    Shared by the FFN, attention, expert and MoE-FFN launchers, which differ only in
    their config overrides and directory names.

    Returns:
        Wall-clock seconds spent inside ``init_child_from_parent``.
    """
    mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
    output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)

    # Profile the overall init_child_from_parent call with optimizations
    mprint("Starting init_child_from_parent...")
    start_time = time.time()
    init_child_from_parent(
        descriptor=cfg.descriptor,
        pruning_mixin=cfg.pruning.pruning_mixin,
        parent_checkpoint_dir=cfg.teacher_dir,
        model_config_overrides_dict=model_config_overrides,
        output_checkpoint_dir=output_dir,
        gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
        mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
        mlp_init_config_yaml=mlp_init_config_yaml,
        linear_init_mode=LinearInitMode.FromTeacher,  # dummy default value
        max_workers=max_save_workers,  # Will auto-calculate if None
        max_layer_workers=max_layer_workers,  # Will auto-calculate if None
    )
    init_child_from_parent_time = time.time() - start_time
    mprint(f"init_child_from_parent completed in {init_child_from_parent_time:.2f} seconds")

    # Create symlink in puzzle_dir/ckpts
    _link_pruned_ckpt(cfg, output_dir, link_name)
    return init_child_from_parent_time


def launch_ffn_intermediates_prune_ckpt(
    cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
):
    """Create one pruned checkpoint per entry of ``cfg.pruning.intermediate_size_list``.

    Sizes whose ``ffn_<size>_attn_no_op`` entry already exists under
    ``<puzzle_dir>/ckpts`` are skipped.
    """
    for intermediate_size in cfg.pruning.intermediate_size_list:
        dirname = f"ffn_{intermediate_size}_attn_no_op"

        if os.path.exists(_ckpt_link_path(cfg, dirname)):
            mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved")
            continue

        mprint(f"Process intermediate_size {intermediate_size}")

        elapsed = _prune_and_link(
            cfg,
            model_config_overrides={"ffn": [{"intermediate_size": intermediate_size}]},
            dirname=dirname,
            link_name=dirname,
            max_save_workers=max_save_workers,
            max_layer_workers=max_layer_workers,
        )

        mprint(f"=== COMPLETED FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===")
        mprint(f"Total processing time: {elapsed:.2f} seconds\n")


def launch_attn_groups_prune_ckpt(
    cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
):
    """Create one pruned checkpoint per entry of ``cfg.pruning.n_heads_in_group_list``.

    The child's ``num_key_value_heads`` is the teacher's ``num_attention_heads``
    floor-divided by the requested group size. ``cfg.descriptor`` must already be a
    resolved descriptor, since ``requires_trust_remote_code()`` is called on it.
    """
    descriptor = cfg.descriptor
    parent_model_config = load_model_config(
        cfg.teacher_dir, trust_remote_code=descriptor.requires_trust_remote_code()
    )
    num_attention_heads = parent_model_config.num_attention_heads

    for n_heads_in_group in cfg.pruning.n_heads_in_group_list:
        dirname = f"n_heads_in_group{n_heads_in_group}"

        if os.path.exists(_ckpt_link_path(cfg, dirname)):
            mprint(f"Process n_heads_in_group {n_heads_in_group} has already been pruned & saved")
            continue

        mprint(f"Process n_heads_in_group {n_heads_in_group}")
        mprint(f"=== STARTING ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===")

        # NOTE(review): floor division silently truncates when n_heads_in_group does
        # not divide num_attention_heads -- confirm whether that should be rejected.
        num_key_value_heads = num_attention_heads // n_heads_in_group

        elapsed = _prune_and_link(
            cfg,
            model_config_overrides={"attention": [{"num_key_value_heads": num_key_value_heads}]},
            dirname=dirname,
            link_name=dirname,
            max_save_workers=max_save_workers,
            max_layer_workers=max_layer_workers,
        )

        mprint(f"=== COMPLETED ATTENTION PRUNING FOR n_heads_in_group={n_heads_in_group} ===")
        mprint(f"Total processing time: {elapsed:.2f} seconds\n")


def launch_hidden_dim_prune_ckpt(cfg: DictConfig):
    """Launch hidden dimension pruning using channel importance ranking.

    Creates one checkpoint per entry of ``cfg.pruning.hidden_size_list``, keeping the
    teacher's per-block FFN intermediate sizes. Sizes whose ``hidden_size_<size>``
    entry already exists under ``<puzzle_dir>/ckpts`` are skipped, like in the other
    launchers; previously a rerun redid the work and then failed in ``os.symlink``.

    Raises:
        FileNotFoundError: If ``channel_importance_results.json`` is missing from
            ``cfg.pruning.activations_log_dir``.
    """
    # Get channel importance results from the activations log directory
    activations_log_dir = cfg.pruning.activations_log_dir
    channel_importance_path = os.path.join(activations_log_dir, "channel_importance_results.json")

    if not os.path.exists(channel_importance_path):
        raise FileNotFoundError(
            f"Channel importance results not found at {channel_importance_path}. "
            f"Make sure to run the activation collection step first."
        )

    # Load parent model config to get FFN configuration
    descriptor = ModelDescriptorFactory.get(cfg.descriptor)
    trust_remote_code = descriptor.requires_trust_remote_code()
    parent_model_config = load_model_config(
        cfg.pruning.model_name_or_path, trust_remote_code=trust_remote_code
    )
    parent_hidden_size = parent_model_config.hidden_size

    # Get teacher's FFN configuration (None entries are kept as None)
    intermediate_sizes = [
        block_config.ffn.intermediate_size for block_config in parent_model_config.block_configs
    ]

    mprint("Teacher config:")
    mprint(f" - hidden_size: {parent_hidden_size}")
    mprint(f" - intermediate_sizes: {intermediate_sizes}")
    os.makedirs(os.path.join(cfg.puzzle_dir, "ckpts"), exist_ok=True)

    for hidden_size in cfg.pruning.hidden_size_list:
        dirname = f"hidden_size_{hidden_size}"

        if os.path.exists(_ckpt_link_path(cfg, dirname)):
            mprint(f"Process hidden_size {hidden_size} has already been pruned & saved")
            continue

        mprint("\n######################################################################")
        mprint(f"Hidden Size = {hidden_size}")
        mprint("######################################################################\n")

        mprint("Child config:")
        mprint(f" - hidden_size: {hidden_size}")

        # Create model config overrides with proper FFN configuration
        # NOTE(review): this is a JSON *string*, while every other launcher passes a
        # dict to `model_config_overrides_dict` -- confirm init_child_from_parent
        # accepts both before changing it.
        model_config_overrides_json = json.dumps(
            {
                "hidden_size": hidden_size,
                "ffn": [
                    {
                        "intermediate_size": intermediate_size,
                    }
                    for intermediate_size in intermediate_sizes
                ],
            }
        )

        mlp_init_config_yaml = cfg.pruning.mlp_init_config_yaml
        output_dir = os.path.join(cfg.pruning.pruned_ckpts_output_dir, dirname)

        mprint(f"Creating checkpoint with hidden_size={hidden_size}")
        mprint(f"Model config overrides: {model_config_overrides_json}")

        init_child_from_parent(
            descriptor=cfg.descriptor,
            pruning_mixin=cfg.pruning.pruning_mixin,
            parent_checkpoint_dir=cfg.pruning.model_name_or_path,
            model_config_overrides_dict=model_config_overrides_json,
            output_checkpoint_dir=output_dir,
            gqa_init_mode=GQAInitMode(cfg.pruning.gqa_init_mode),
            mlp_init_mode=MlpInitMode(cfg.pruning.mlp_init_mode),
            mlp_init_config_yaml=mlp_init_config_yaml,
            linear_init_mode=LinearInitMode(cfg.pruning.linear_init_mode),
            hidden_size_init_mode=HiddenSizeInitMode(cfg.pruning.hidden_size_init_mode),
            channel_importance_path=channel_importance_path,
        )

        # Create symlink in puzzle_dir/ckpts
        _link_pruned_ckpt(cfg, output_dir, dirname)
        mprint(f"Created pruned checkpoint at: {output_dir}")


def launch_experts_prune_ckpt(
    cfg: DictConfig,
    max_save_workers: Optional[int] = None,
    max_layer_workers: Optional[int] = None,
    symlink_suffix: Optional[str] = None,
):
    """Create one pruned checkpoint per entry of ``cfg.pruning.num_experts_to_keep_list``.

    The checkpoint is written to ``num_experts_<n>``; the symlink under
    ``<puzzle_dir>/ckpts`` additionally gets ``_<symlink_suffix>`` when a suffix is
    given, and that symlink name is what the "already pruned" check looks at.
    """
    for num_experts in cfg.pruning.num_experts_to_keep_list:
        dirname = f"num_experts_{num_experts}"
        # Create symlink name with optional suffix
        symlink_name = f"{dirname}_{symlink_suffix}" if symlink_suffix else dirname
        if os.path.exists(_ckpt_link_path(cfg, symlink_name)):
            mprint(
                f"Process num_experts {num_experts} (symlink: {symlink_name}) has already been pruned & saved"
            )
            continue
        mprint(f"Process num_experts {num_experts}")
        mprint(f"=== STARTING EXPERT PRUNING FOR num_experts={num_experts} ===")

        elapsed = _prune_and_link(
            cfg,
            model_config_overrides={"ffn": [{"moe": {"num_local_experts": num_experts}}]},
            dirname=dirname,
            link_name=symlink_name,
            max_save_workers=max_save_workers,
            max_layer_workers=max_layer_workers,
        )

        mprint(f"=== COMPLETED EXPERT PRUNING FOR num_experts={num_experts} ===")
        mprint(f"Total processing time: {elapsed:.2f} seconds\n")


def launch_moe_ffn_intermediates_prune_ckpt(
    cfg: DictConfig, max_save_workers: Optional[int] = None, max_layer_workers: Optional[int] = None
):
    """Create one MoE-FFN-pruned checkpoint per entry of ``cfg.pruning.intermediate_size_list``.

    Attention is overridden to ``no_op`` and each expert's intermediate dimension is
    set to the requested size. The ``<puzzle_dir>/ckpts`` directory is now created
    before linking; the original called ``os.symlink`` without ensuring it exists.
    """
    for intermediate_size in cfg.pruning.intermediate_size_list:
        dirname = f"moe_ffn_{intermediate_size}_attn_no_op"

        if os.path.exists(_ckpt_link_path(cfg, dirname)):
            mprint(f"Process intermediate_size {intermediate_size} has already been pruned & saved")
            continue

        mprint(f"Process intermediate_size {intermediate_size}")

        elapsed = _prune_and_link(
            cfg,
            model_config_overrides={
                "attention": [{"no_op": True, "llama4": None}],
                "ffn": [{"moe": {"expert_intermediate_dim": intermediate_size}}],
            },
            dirname=dirname,
            link_name=dirname,
            max_save_workers=max_save_workers,
            max_layer_workers=max_layer_workers,
        )

        mprint(f"=== COMPLETED MOE FFN PRUNING FOR FFN INTERMEDIATE SIZE={intermediate_size} ===")
        mprint(f"Total processing time: {elapsed:.2f} seconds\n")


def _optional_int_env(name: str) -> Optional[int]:
    """Return ``int(os.environ[name])`` if the variable is set, else ``None``."""
    raw_value = os.environ.get(name)
    return None if raw_value is None else int(raw_value)


def launch_prune_ckpt(cfg: DictConfig):
    """Resolve the descriptor and pruning mixin in ``cfg`` and run the matching launcher.

    ``cfg.descriptor`` and ``cfg.pruning.pruning_mixin`` are overwritten in place with
    their resolved objects. Worker counts may be overridden through the
    ``PRUNING_SAVE_WORKERS`` and ``PRUNING_LAYER_WORKERS`` environment variables.

    Raises:
        NotImplementedError: If the resolved mixin has no launcher.
        ValueError: If one of the worker environment variables is not an integer.
    """
    cfg.descriptor = ModelDescriptorFactory.get(cfg.descriptor)
    # Resolve pruning_mixin from config (could be string, enum, or PruningMixIn)
    cfg.pruning.pruning_mixin = resolve_pruning_mixin(cfg.pruning.pruning_mixin, cfg.descriptor)
    pruning_mixin = cfg.pruning.pruning_mixin

    # I/O optimization settings; None means "auto-calculate" downstream.
    max_save_workers = _optional_int_env("PRUNING_SAVE_WORKERS")
    max_layer_workers = _optional_int_env("PRUNING_LAYER_WORKERS")

    if isinstance(pruning_mixin, FFNIntermediatePruningMixIn):
        launch_ffn_intermediates_prune_ckpt(cfg, max_save_workers, max_layer_workers)
    elif isinstance(pruning_mixin, KVHeadsPruningMixIn):
        launch_attn_groups_prune_ckpt(cfg, max_save_workers, max_layer_workers)
    elif isinstance(pruning_mixin, ExpertRemovalPruningMixIn):
        launch_experts_prune_ckpt(cfg, max_save_workers, max_layer_workers)
    # NOTE: launch_hidden_dim_prune_ckpt is currently not dispatched from here; the
    # original kept a disabled branch for it:
    # elif target_layer == "layernorm":
    #     launch_hidden_dim_prune_ckpt(cfg)
    else:
        raise NotImplementedError(
            f"checkpoint pruning is not currently supported for pruning mixin: {pruning_mixin.__class__.__name__}"
        )


# Patch metadata and licence header of the next file in this diff, carried over from
# the end of the original collapsed line:
# diff --git a/modelopt/torch/puzzletron/pruning/pruning_mixin.py b/modelopt/torch/puzzletron/pruning/pruning_mixin.py
# new file mode 100644
# index 0000000000..21685848bf
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/pruning/pruning_mixin.py
# @@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

import re
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type

from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook


class LayerDescriptor:
    """Describes which modules of a model are targeted and how to find their block index."""

    def module_name_regex(self) -> str:
        """Return the module selector.

        A value starting with ``"regex:"`` is treated as a regular expression (searched
        anywhere in the module name); any other value is matched as a name suffix. The
        base implementation returns ``""``, a suffix that every module name ends with.
        """
        return ""

    def block_idx_from_module_name(self, module_name: str) -> Optional[int]:
        """Return the first ``.<digits>.`` component of ``module_name`` as an int, or None."""
        found = re.search(r"\.(\d+)\.", module_name)
        return int(found.group(1)) if found else None

    def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]:
        """Return ``(block_idx, module_name)`` for every module of ``model`` matching the selector.

        ``block_idx`` is whatever ``block_idx_from_module_name`` yields, so it is None
        for names without a numeric component. Order follows ``model.named_modules()``.
        """
        selector = self.module_name_regex()
        regex_marker = "regex:"

        if selector.startswith(regex_marker):
            compiled = re.compile(selector[len(regex_marker) :])

            def is_target(name):
                # Truthy match object when the pattern occurs anywhere in the name.
                return compiled.search(name)

        else:

            def is_target(name):
                return name.endswith(selector)

        return [
            (self.block_idx_from_module_name(name), name)
            for name, _ in model.named_modules()
            if is_target(name)
        ]


class PruningMixIn(ABC):
    """Abstract base for pruning strategies built around a ``LayerDescriptor``."""

    def __init__(self, layer_descriptor: LayerDescriptor):
        self.layer_descriptor = layer_descriptor

    def get_module_names_to_hook(self, model) -> List[Tuple[int, str]]:
        """Delegate to the descriptor's ``get_modules_names_to_hook``."""
        return self.layer_descriptor.get_modules_names_to_hook(model)

    @abstractmethod
    def supported_hooks(self) -> List[Type[ForwardHook]]:
        """Return the hook classes this strategy can work with (must be overridden)."""
        raise NotImplementedError

    # @abstractmethod
    # def prune_single_layer(
    #     self,
    #     layer_idx: int,
    #     parent_state_dict: dict,
    #     new_state_dict: dict,
    #     original_config: PretrainedConfig,
    #     new_config: PretrainedConfig,
    #     **kwargs
    # ):
    #     raise NotImplementedError


# Patch metadata of the next file in this diff, carried over from the end of the
# original collapsed line (the hunk header continues on the following line):
# diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py
# new file mode 100644
# index 0000000000..82ba675c94
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py
# @@
-0,0 +1,652 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import json +import math +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn + + +class GQAInitMode(Enum): + RandomKV = "RandomKV" + AverageKV = "AverageKV" + FirstKV = "FirstKV" + RandomBlock = "RandomBlock" + CopyAsIs = "CopyAsIs" + Degrouping = "Degrouping" + PruneKVHeads = "PruneKVHeads" + + +class MlpInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + CopyAsIs = "CopyAsIs" + PruneByActivationsLog = "PruneByActivationsLog" + ExpertRemoval = "ExpertRemoval" + ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + + +class LinearInitMode(Enum): + Random = "Random" + FromTeacher = "FromTeacher" + + +class HiddenSizeInitMode(Enum): + Random = "Random" + Truncate = "Truncate" + PruneByChannelRanking = "PruneByChannelRanking" + CopyAsIs = "CopyAsIs" + + +def resolve_pruning_mixin( + pruning_mixin, descriptor: Type[ModelDescriptor] +) -> PruningMixIn | List[PruningMixIn]: + """ + Convert pruning_mixin argument to PruningMixIn instance(s). 
+ + Args: + pruning_mixin: Can be a string identifier, PruningMixIn instance, + or a list of any of those types. + descriptor: ModelDescriptor class that provides the pruning_mixins() mapping. + + Returns: + PruningMixIn or List[PruningMixIn] depending on input type. + """ + # Handle list of values recursively + if isinstance(pruning_mixin, list): + return [resolve_pruning_mixin(item, descriptor) for item in pruning_mixin] + + # Handle single value + # If it's already a PruningMixIn, return as is + if isinstance(pruning_mixin, PruningMixIn): + return pruning_mixin + + # Get the pruning mixins mapping from the descriptor + mixins_dict = descriptor.pruning_mixins() + + if isinstance(pruning_mixin, str): + if pruning_mixin not in mixins_dict: + available_methods = list(mixins_dict.keys()) + raise ValueError( + f"Pruning method '{pruning_mixin}' is not supported by {descriptor.__name__}. " + f"Available methods: {available_methods}" + ) + return mixins_dict[pruning_mixin] + + raise ValueError(f"Unsupported pruning_mixin type: {type(pruning_mixin)}") + + +def _init_mlp_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_prefix: str, + expanded_dim: int, + layer_idx: int, + new_item: torch.Tensor, + new_config: PretrainedConfig, + orig_item: torch.Tensor, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + pruned_filters: Optional[torch.Tensor] = None, + projection_matrix: Optional[dict[str, torch.Tensor]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict[str, torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + assert orig_item.ndim == 2, f"{orig_item.ndim=}" + assert new_item.ndim == 2, f"{new_item.ndim=}" + + assert new_config.num_hidden_layers == original_config.num_hidden_layers, ( + f"({new_config.num_hidden_layers=}) != ({original_config.num_hidden_layers=})" + ) + + new_intermediate_size = new_config.block_configs[layer_idx].ffn.intermediate_size + 
original_intermediate_size = original_config.block_configs[layer_idx].ffn.intermediate_size + + if mlp_init_mode == MlpInitMode.CopyAsIs: + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." + ) + mlp_module_weight = orig_item + + elif mlp_init_mode == MlpInitMode.Random: + mlp_module_weight = new_item + + elif new_intermediate_size == original_intermediate_size: + mlp_module_weight = orig_item + + elif mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + assert original_intermediate_size >= new_intermediate_size, ( + f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated." + ) + orig_ffn_size = orig_item.shape[expanded_dim] + new_ffn_size = new_item.shape[expanded_dim] + + if mlp_init_mode == MlpInitMode.Truncate: + truncated_weight = torch.narrow( + orig_item, dim=expanded_dim, start=0, length=new_ffn_size + ) + mlp_module_weight = truncated_weight + + elif mlp_init_mode == MlpInitMode.PruneByActivationsLog: + if pruned_filters is None: + filter_importance = _load_activations_log( + mlp_init_config, module_name=f"{mlp_prefix}.down_proj" + ) + filters_sorted_by_importance = torch.argsort(filter_importance, descending=True) + pruned_filters = filters_sorted_by_importance[:new_ffn_size].to(orig_item.device) + + pruned_weight = torch.index_select(orig_item, dim=expanded_dim, index=pruned_filters) + if mlp_init_config.get("scale_pruned_weights", False) and expanded_dim == 1: + pruned_weight = pruned_weight * (orig_ffn_size / new_ffn_size) + mlp_module_weight = pruned_weight + + elif ( + mlp_init_mode == MlpInitMode.ExpertRemoval + ): # the case of mlp layers of maverick. for now we only support copy as is + assert new_intermediate_size == original_intermediate_size, ( + f"({new_intermediate_size=}) != ({original_intermediate_size=}), can't be copied as is." 
+ ) + mlp_module_weight = orig_item + + else: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + return mlp_module_weight, pruned_filters, projection_matrix + + +def _load_activations_log(mlp_init_config: dict[str, Any], module_name: str) -> torch.Tensor: + _cache_activations_log(mlp_init_config) + module_log = ACTIVATIONS_LOG[module_name] + filter_importance = module_log["score"] + return filter_importance + + +ACTIVATIONS_LOG = dict() + + +def _cache_activations_log(mlp_init_config: dict[str, Any]) -> None: + if len(ACTIVATIONS_LOG) == 0: + assert "activations_log_dir" in mlp_init_config + activations_log_dir = mlp_init_config["activations_log_dir"] + print(f"Loading activations_log from {activations_log_dir}") + # Only load rank_*.pth files to avoid loading hook_states_*.pth checkpoint files + ACTIVATIONS_LOG.update( + { + module_name: module_log + for p in Path(activations_log_dir).glob("rank_*.pth") + for module_name, module_log in torch.load(p).items() + } + ) + + +def _init_attention_weights( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + + # new_w* are typically randomly initialized + new_wq = new_state_dict[q_key] + new_wk = new_state_dict[k_key] + new_wv = new_state_dict[v_key] + new_wo = new_state_dict[o_key] + + # w* are from the parent model + wq = original_state_dict[q_key] + wk = original_state_dict[k_key] + wv = original_state_dict[v_key] + wo = original_state_dict[o_key] + + if "bias" in k_key: + for tensor in 
[wq, wk, wv, wo, new_wq, new_wk, new_wv, new_wo]: + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + dim1 = wk.shape[1] # this is the hidden_size in case of matrix weights, and 1 in case of biases + + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock): + wk, wv = new_wk, new_wv + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV): + assert orig_num_kv_heads % num_kv_heads == 0, ( + f"({orig_num_kv_heads=}) % ({num_kv_heads=}) != 0" + ) + n_heads_to_aggregate = orig_num_kv_heads // num_kv_heads + + wk = wk.view(-1, n_heads_to_aggregate, head_size, dim1) + wv = wv.view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + wk = wk.mean(dim=1) + wv = wv.mean(dim=1) + else: + wk = wk[:, 0] + wv = wv[:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + assert new_wk.shape == wk.shape, f"({new_wk.shape=}) != ({wk.shape=})" + assert new_wv.shape == wv.shape, f"({new_wv.shape=}) != ({wv.shape=})" + assert new_wq.shape == wq.shape, f"({new_wq.shape=}) != ({wq.shape=})" + assert new_wo.shape == wo.shape, f"({new_wo.shape=}) != ({wo.shape=})" + + elif gqa_init_mode == GQAInitMode.Degrouping: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." 
+ ) + n_groups = num_kv_heads + orig_n_groups = orig_num_kv_heads + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + wk = degroup_w(wk) + wv = degroup_w(wv) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + wk = wk.view(orig_num_kv_heads, head_size, dim1) + wv = wv.view(orig_num_kv_heads, head_size, dim1) + wq = wq.view(orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size, dim1) + wo = wo.view(dim1, orig_num_kv_heads, num_q_heads // orig_num_kv_heads, head_size) + + o_proj_module_name = o_key.replace(".weight", "") + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + wk = wk[kv_heads_to_keep] + wv = wv[kv_heads_to_keep] + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. + wq = wq[kv_heads_to_keep] + wq = torch.repeat_interleave(wq, repeats=reduction_factor, dim=0) + + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. 
+ wo = wo[:, kv_heads_to_keep] + wo = torch.repeat_interleave(wo, repeats=reduction_factor, dim=1) + wo = wo / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + wq = wq[kv_head_ordering] + + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. 
+ wo = wo[:, kv_head_ordering] + wo[:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + wk = wk.reshape(-1, dim1) + wv = wv.reshape(-1, dim1) + wq = wq.reshape(-1, dim1) + wo = wo.reshape(dim1, -1) + return wq, wk, wv, wo + + +def _init_attention_biases( + gqa_init_mode, + layer_idx, + new_state_dict, + new_config, + original_state_dict, + q_key, + k_key, + v_key, + o_key, + original_config, + is_original_mha, + head_size, + mlp_init_config, +): + assert new_config.num_attention_heads == original_config.num_attention_heads, ( + f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + ) + num_q_heads = new_config.num_attention_heads + num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads + orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = num_q_heads // num_kv_heads + orig_n_heads_in_group = num_q_heads // orig_num_kv_heads + + o_proj_bias = new_config.o_proj_bias + attention_bias = new_config.attention_bias + + # If no biases + if not (o_proj_bias or attention_bias): + return {} + + new_bias_sd = {} + bias_sd = {} + # new_w* are typically randomly initialized + if o_proj_bias: + new_bias_sd["o"] = new_state_dict[o_key] + bias_sd["o"] = original_state_dict[o_key] + if attention_bias: + for bias_key, key in zip("qkv", [q_key, k_key, v_key]): + new_bias_sd[bias_key] = new_state_dict[key] + bias_sd[bias_key] = original_state_dict[key] + + # maybe unsqueeze all tensors + for tensor in list(new_bias_sd.values()) + list(bias_sd.values()): + assert tensor.ndim == 1 + tensor.unsqueeze_(1) + + dim1 = 1 # this is the hidden_size in case of matrix weights, and 1 in case of biases + if gqa_init_mode in (GQAInitMode.RandomKV, GQAInitMode.RandomBlock) and attention_bias: + bias_sd["k"] = torch.zeros( + new_bias_sd["k"].shape, dtype=bias_sd["k"].dtype, device=bias_sd["k"].device + ) + bias_sd["v"] = torch.zeros( + 
new_bias_sd["v"].shape, dtype=bias_sd["v"].dtype, device=bias_sd["v"].device + ) + elif gqa_init_mode in (GQAInitMode.AverageKV, GQAInitMode.FirstKV) and attention_bias: + assert n_heads_in_group % orig_n_heads_in_group == 0, ( + f"({n_heads_in_group=}) % ({orig_n_heads_in_group=}) != 0" + ) + n_heads_to_aggregate = n_heads_in_group // orig_n_heads_in_group + + bias_sd["k"] = bias_sd["k"].view(-1, n_heads_to_aggregate, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(-1, n_heads_to_aggregate, head_size, dim1) + + if gqa_init_mode == GQAInitMode.AverageKV: + bias_sd["k"] = bias_sd["k"].mean(dim=1) + bias_sd["v"] = bias_sd["v"].mean(dim=1) + else: + bias_sd["k"] = bias_sd["k"][:, 0] + bias_sd["v"] = bias_sd["v"][:, 0] + elif gqa_init_mode == GQAInitMode.CopyAsIs: + for key in bias_sd.keys(): + assert new_bias_sd[key].shape == bias_sd[key].shape, ( + f"({new_bias_sd[key].shape=}) != ({bias_sd[key].shape=})" + ) + + elif gqa_init_mode == GQAInitMode.Degrouping and attention_bias: + assert not is_original_mha, ( + "Degrouping can only be done on original models that are GQA themselves." 
+ ) + n_groups = new_config.num_attention_heads // n_heads_in_group + orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" + n_repeats = n_groups // orig_n_groups + if n_repeats > 1: + print(f"Degrouping {orig_n_groups} into {n_groups}") + + def degroup_w(w): + w = w.view(orig_n_groups, head_size, dim1) + w = torch.repeat_interleave(w, repeats=n_repeats, dim=0) + w = w.reshape(n_groups * head_size, dim1) + return w + + bias_sd["k"] = degroup_w(bias_sd["k"]) + bias_sd["v"] = degroup_w(bias_sd["v"]) + + elif gqa_init_mode == GQAInitMode.PruneKVHeads: + if o_proj_bias: + o_proj_module_name = o_key.rsplit(".", 1)[0] + else: + # Here we assume that the o_proj layer is called "o_proj" + o_proj_module_name = k_key.rsplit(".", 2)[0] + ".o_proj" + + kv_head_importance = _load_activations_log(mlp_init_config, module_name=o_proj_module_name) + kv_heads_sorted_by_importance = torch.argsort(kv_head_importance, descending=True) + kv_heads_to_keep = kv_heads_sorted_by_importance[:num_kv_heads] + kv_heads_to_remove = kv_heads_sorted_by_importance[num_kv_heads:] + + # view as KV groups + if attention_bias: + bias_sd["k"] = bias_sd["k"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["v"] = bias_sd["v"].view(orig_num_kv_heads, head_size, dim1) + bias_sd["q"] = bias_sd["q"].view( + orig_num_kv_heads, orig_n_heads_in_group, head_size, dim1 + ) + # Keep important KV heads and prune the others + bias_sd["k"] = bias_sd["k"][kv_heads_to_keep] + bias_sd["v"] = bias_sd["v"][kv_heads_to_keep] + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].view( + dim1, orig_num_kv_heads, orig_n_heads_in_group, head_size + ) + + reduction_factor = orig_num_kv_heads // num_kv_heads + + prune_via_duplication = False + if prune_via_duplication: + if attention_bias: + ## Wq option 1 - replicate the query groups to match the total number of attention heads. Queries work with familiar kv heads. 
+ bias_sd["q"] = bias_sd["q"][kv_heads_to_keep] + bias_sd["q"] = torch.repeat_interleave( + bias_sd["q"], repeats=reduction_factor, dim=0 + ) + + if o_proj_bias: + ## Wo option 1 - replicate the groups of the original Wo. Multiple by the reduction factor to mimic pruning of the other groups. + ## This makes sense with Wq option 1, but it will not be more expressive than true pruning due to symmetry, unless we add noise. + bias_sd["o"] = bias_sd["o"][:, kv_heads_to_keep] + bias_sd["o"] = torch.repeat_interleave( + bias_sd["o"], repeats=reduction_factor, dim=1 + ) + bias_sd["o"] = bias_sd["o"] / reduction_factor + + else: # prune via zeroing out + ## Wq option 2 - keep the original queries. At init they will not be used (see the Wo zeroing), during training they can adapt to new kv heads like in variable GQA. + ## We need to interleave them to keep the matching between queries and kv heads. + kv_heads_to_keep = kv_heads_to_keep.tolist() + kv_heads_to_remove = kv_heads_to_remove.tolist() + kv_head_ordering = [] + zero_out_mask = [] + for i_head in range(orig_num_kv_heads): + if i_head % reduction_factor == 0: + kv_head_ordering.append(kv_heads_to_keep.pop(0)) + zero_out_mask.append(False) + else: + kv_head_ordering.append(kv_heads_to_remove.pop(0)) + zero_out_mask.append(True) + + if attention_bias: + bias_sd["q"] = bias_sd["q"][kv_head_ordering] + + if o_proj_bias: + ## Wo option 2 - zero-out the contribution of queries that do not belong to chosen kv heads. + ## At initialization it's exactly like pruning, but the extra weights will have the chance to adapt to new kv heads if we train the model. + ## Even though the weight is 0 it can still train, like initializing biases to 0 does not prevent them from training. + ## Matmul backprop: if Y = AB and dY is the gradient of Y, then dA = dY @ B.T and dB = A.T @ dY, so the gradient of the zeroed-out weights depends on the gradient of what multiplies them. 
+ bias_sd["o"] = bias_sd["o"][:, kv_head_ordering] + bias_sd["o"][:, zero_out_mask] = 0.0 + + else: + raise ValueError(f"{gqa_init_mode=} not supported") + + if attention_bias: + for bias_key in "qkv": + bias_sd[bias_key] = bias_sd[bias_key].reshape(-1) + if o_proj_bias: + bias_sd["o"] = bias_sd["o"].reshape(-1) + return bias_sd + + +def _init_moe_module( + mlp_init_mode: Union[MlpInitMode, str], + mlp_init_config: Optional[Dict[str, Any]], + layer_idx: int, + orig_router_weights: Dict[str, List[torch.Tensor]], + orig_experts_weights: Dict[str, List[torch.Tensor]], + new_router_weights: Dict[str, List[torch.Tensor]], + new_experts_weights: Dict[str, List[torch.Tensor]], + orig_num_experts: int, + new_num_experts: int, +) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, List[torch.Tensor]]]: + if isinstance(mlp_init_mode, str): + mlp_init_mode = MlpInitMode(mlp_init_mode) + + if mlp_init_mode != MlpInitMode.ExpertRemoval: + raise ValueError(f"Unsupported {mlp_init_mode=}") + + selected_experts = _select_expert_indices( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_num_experts=orig_num_experts, + new_num_experts=new_num_experts, + ) + + # Router: prefer parent tensors when available; if child has bias only, slice from child + result_router_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_router_weights.items(): + result_router_weights[name] = [ + tensor_to_slice[selected_experts] for tensor_to_slice in orig_router_weights[name] + ] + + # Experts: for each name present in the child, take from parent if available, else from child + result_experts_weights: dict[str, list[torch.Tensor]] = {} + for name, new_list in new_experts_weights.items(): + if name in orig_experts_weights: + src_list = orig_experts_weights[name] + else: + src_list = new_list + result_experts_weights[name] = [src_list[i] for i in selected_experts] + + # Validate shapes + assert result_router_weights.keys() == new_router_weights.keys(), ( + 
"result_router_weights and new_router_weights must have the same keys" + ) + for name in new_router_weights.keys(): + assert len(new_router_weights[name]) == len(result_router_weights[name]) + for new_router_weight, result_router_weight in zip( + new_router_weights[name], result_router_weights[name] + ): + assert new_router_weight.shape == result_router_weight.shape + + assert result_experts_weights.keys() == new_experts_weights.keys(), ( + "result_experts_weights and new_experts_weights must have the same keys" + ) + for name in result_experts_weights.keys(): + assert len(new_experts_weights[name]) == len(result_experts_weights[name]) + for new_expert_weight, result_expert_weight in zip( + new_experts_weights[name], result_experts_weights[name] + ): + assert new_expert_weight.shape == result_expert_weight.shape + + return result_router_weights, result_experts_weights + + +def _select_expert_indices( + *, mlp_init_config: dict[str, Any], layer_idx: int, orig_num_experts: int, new_num_experts: int +) -> list[int]: + expert_scores = _load_expert_scores(mlp_init_config, layer_idx) + assert len(expert_scores) == orig_num_experts + higher_is_better = mlp_init_config.get("higher_is_better", True) + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: ( + expert_scores[i] + if not math.isnan(expert_scores[i]) + else (float("-inf") if higher_is_better else float("inf")) + ), + reverse=higher_is_better, + )[:new_num_experts] + return selected_experts + + +def _load_expert_scores( + mlp_init_config: Optional[dict[str, Any]], layer_idx: int +) -> list[list[int | float]]: + assert mlp_init_config is not None + if "expert_scores_file" in mlp_init_config: + expert_scores_file = mlp_init_config["expert_scores_file"] + with open(expert_scores_file, "r") as f: + expert_scores = json.load(f) + elif "activations_log_dir" in mlp_init_config: + _cache_activations_log(mlp_init_config) + # Use layer_prefix_template from pruning config, or fall back to legacy nemotron_h 
format + # TODO - get from descriptors + layer_prefix_template = mlp_init_config.get( + "layer_prefix_template", "backbone.layers.{layer_idx}." + ) + layer_prefix = layer_prefix_template.format(layer_idx=layer_idx) + candidate_layer_keys = [ + key for key in ACTIVATIONS_LOG.keys() if key.startswith(layer_prefix) + ] + if len(candidate_layer_keys) == 0: + raise ValueError(f"No layer keys found for {layer_prefix=}. {ACTIVATIONS_LOG.keys()=}") + elif len(candidate_layer_keys) > 1: + if "layer_suffix" not in mlp_init_config: + raise ValueError( + f"Multiple candidate layer keys found for {layer_prefix=}, you must specify a layer_suffix in the mlp_init_config. {candidate_layer_keys=}" + ) + layer_suffix = mlp_init_config["layer_suffix"] + layer_key = f"{layer_prefix}{layer_suffix}" + else: + layer_key = candidate_layer_keys[0] + layer_log = ACTIVATIONS_LOG[layer_key] + + expert_scores_key = mlp_init_config.get("expert_scores_key", "expert_ranks") + if expert_scores_key not in layer_log: + raise ValueError( + f"Expert scores key {expert_scores_key=} not found in {layer_log.keys()=}" + ) + expert_scores = layer_log[expert_scores_key] + else: + raise ValueError(f"Unsupported {mlp_init_config=}") + return expert_scores diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py new file mode 100644 index 0000000000..5a1484e07a --- /dev/null +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides the main compression function for a model using MIP-based NAS search algorithm.""" + +import hydra +from omegaconf import DictConfig + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.build_library_and_stats as build_library_and_stats +import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.puzzletron.scoring.scoring as scoring +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + + +def puzzletron( + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str +) -> DictConfig: + """Compress a model using the MIP-based NAS search algorithm from Puzzletron. + + Args: + hydra_config_dir (str): path to a hydra_config_dir that defines the search space + hydra_config (str): the corresponding hydra config file + puzzle_dir (str): directory with a puzzletron model to compress + dataset_path (str): dataset used for scoring and distillation + + Returns: + Hydra config object after compressing the model. + The same hydra configuration object is used across all compression steps. 
+ TODO: Investigate if this config object is immutable across steps and clarify + """ + # Step 0: Load puzzletron hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Step 1: score_pruning_activations (distributed processing) + score_pruning_activations.launch_score_activations(hydra_cfg) + + # Step 2: pruning_ckpts (single process) + if dist.is_master(): + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + # Step 4: build_library_and_stats (single process) + if dist.is_master(): + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + dist.barrier() + + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg) + + # Step 6: mip_and_realize_models (distributed processing) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + + return hydra_cfg diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py new file mode 100644 index 0000000000..a7ed3f7d37 --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -0,0 +1,616 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module constructs the replacement library JSON files from a puzzle directory containing +multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock +configurations, builds a library of available replacements, and generates solutions for layer +replacement in compressed models. The resulting replacement library can then be used by +ReplacementLibrary to efficiently load models with mixed teacher/student layers. + +Standard Puzzle Usage: +====================== +python -m modelopt.torch.puzzletron.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any, Type + +import pandas as pd +from omegaconf import DictConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( + is_replacement_identical_to_teacher, + replacement_is_teacher, + sort_replacements, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + is_valid_decilm_checkpoint, + load_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import format_global_config +from modelopt.torch.puzzletron.utils.utils import block_config_to_str, subblock_config_to_str + +UNIQUE_SUBBLOCK_IDENTIFIER = 
["block_config", "attention_config", "ffn_config", "block_idx"] +CHECKPOINTS_DIR_NAME = "ckpts" + + +def build_replacement_library( + master_puzzle_dir: Path | str, + descriptor: ModelDescriptor, + teacher_checkpoint_dir: Path | str | None = None, + add_ffn_no_ops: bool = True, + add_attention_no_ops: bool = True, +) -> None: + """ + For normal puzzle runs, use default values. + For advanced use cases, see the Usage section. + """ + master_puzzle_dir = Path(master_puzzle_dir) + (master_puzzle_dir / "ckpts").mkdir(exist_ok=True) + teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir) + trust_remote_code = descriptor.requires_trust_remote_code() + subblocks_df = _build_subblocks_df( + master_puzzle_dir, + teacher_checkpoint_dir, + add_ffn_no_ops, + add_attention_no_ops, + trust_remote_code=trust_remote_code, + ) + block_library_df = _build_block_library_from_subblocks(subblocks_df) + + layer_replacements = _build_layer_replacements( + block_library_df, master_puzzle_dir, teacher_checkpoint_dir, trust_remote_code + ) + + single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions( + layer_replacements, teacher_checkpoint_dir, trust_remote_code + ) + + json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), master_puzzle_dir / "subblock_library.json") + json_dump(layer_replacements, master_puzzle_dir / "replacement_library.json") + json_dump( + single_sequence_replacement_solutions, + master_puzzle_dir / "single_sequence_replacement_solutions.json", + ) + mprint("done") + + +def launch_build_replacement_library(cfg: DictConfig) -> None: + """ + Launch the build replacement library function with Hydra configuration. 
+ """ + mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}" + ) + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + build_replacement_library( + master_puzzle_dir=cfg.puzzle_dir, + teacher_checkpoint_dir=cfg.teacher_dir, + add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops, + add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops, + descriptor=descriptor, + ) + + +def infer_teacher_dir( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, +) -> Path: + if teacher_checkpoint_dir is None: + teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" + if not teacher_checkpoint_dir.exists(): + raise ValueError( + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." 
+ ) + teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() + return teacher_checkpoint_dir + + +def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy() + constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df) + + is_constructed_block_has_joint_variant = pd.Series( + map(tuple, constructed_blocks_df[["block_config", "block_idx"]].values) + ).isin(pd.Series(map(tuple, joint_blocks_df[["block_config", "block_idx"]].values))) + constructed_blocks_df = constructed_blocks_df[~is_constructed_block_has_joint_variant] + + block_library_df = pd.concat([joint_blocks_df, constructed_blocks_df]) + block_library_df["block_repr"] = block_library_df["block_config"].apply(block_config_to_str) + + dups = block_library_df.loc[ + block_library_df[["block_config", "block_idx"]].duplicated() + ].sort_values(by=["block_config", "block_idx"]) + if len(dups) > 0: + mprint(f"Found {len(dups)} duplicate blocks in the block library. Here are some examples:") + dup_block_idx = dups["block_idx"].iloc[0] + dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx] + for _, row in dups_with_same_block_idx.head(10).iterrows(): + mprint(row.to_dict()) + json_dump(block_library_df.to_dict(orient="records"), "ERROR_block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), "ERROR_subblock_library.json") + raise ValueError( + f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details." 
+ ) + + return block_library_df + + +def _construct_blocks_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + columns = subblocks_df.columns + decomp_blocks_df = subblocks_df[subblocks_df["block_config"].isna()].drop( + columns=columns[columns.str.contains("block_config|joint|block_repr")] + ) + + attention_df = decomp_blocks_df.dropna(subset="attention_config").drop( + columns=columns[columns.str.contains("ffn")] + ) + ffn_df = decomp_blocks_df.dropna(subset="ffn_config").drop( + columns=columns[columns.str.contains("attention")] + ) + constructed_blocks_df = pd.merge(attention_df, ffn_df, on="block_idx") + + constructed_blocks_df["block_config"] = constructed_blocks_df.apply( + lambda row: BlockConfig(ffn=row["ffn_config"], attention=row["attention_config"]), axis=1 + ) + + return constructed_blocks_df + + +def _build_subblocks_df( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str, + add_ffn_no_ops: bool, + add_attention_no_ops: bool, + trust_remote_code: bool = False, +) -> pd.DataFrame: + teacher_checkpoint_dir = Path(teacher_checkpoint_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) + checkpoints_to_split = [teacher_checkpoint_dir] + + subblock_rows = [] + for checkpoint_dir in checkpoint_dirs: + subblocks_to_extract = _infer_subblocks_to_extract(checkpoint_dir, checkpoints_to_split) + if len(subblocks_to_extract) > 0: + subblock_rows_from_current_checkpoint = ( + _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir, subblocks_to_extract, trust_remote_code=trust_remote_code + ) + ) + subblock_rows.extend(subblock_rows_from_current_checkpoint) + + subblocks_df = pd.DataFrame(subblock_rows) + + subblocks_df = _drop_duplicates_of_decomp_no_op(subblocks_df) + assert subblocks_df.duplicated().sum() == 0 + + if add_ffn_no_ops or 
add_attention_no_ops: + subblocks_df = _add_no_op_subblock_rows(subblocks_df, add_ffn_no_ops, add_attention_no_ops) + + subblocks_df = _drop_duplicates_of_teacher(subblocks_df, teacher_checkpoint_dir) + + subblocks_that_have_multiple_sources = list( + subblocks_df[subblocks_df.duplicated(UNIQUE_SUBBLOCK_IDENTIFIER, keep=False)].groupby( + UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False + ) + ) + if len(subblocks_that_have_multiple_sources) > 0: + mprint( + f"Found {len(subblocks_that_have_multiple_sources)} subblock types with multiple sources. Dropping duplicates..." + ) + for subblock_identifier, duplicates_df in subblocks_that_have_multiple_sources: + mprint("\n================================") + mprint(dict(zip(UNIQUE_SUBBLOCK_IDENTIFIER, subblock_identifier))) + for _, row in duplicates_df.iterrows(): + mprint(row.to_dict()) + + # Drop duplicates, keeping the first occurrence (which should be from teacher) + mprint(f"Dropping duplicates. Original count: {len(subblocks_df)}") + subblocks_df = subblocks_df.drop_duplicates(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + mprint(f"After dropping duplicates: {len(subblocks_df)}") + + subblocks_df["ffn_repr"] = subblocks_df["ffn_config"].apply(subblock_config_to_str) + subblocks_df["attention_repr"] = subblocks_df["attention_config"].apply(subblock_config_to_str) + subblocks_df["block_repr"] = subblocks_df["block_config"].apply(block_config_to_str) + + return subblocks_df + + +def _drop_duplicates_of_teacher( + subblocks_df: pd.DataFrame, + teacher_checkpoint_dir: Path | str, +) -> pd.DataFrame: + orig_subblocks_df = subblocks_df.copy() + + attention_is_teacher = subblocks_df["attention_checkpoint_dir"] == str(teacher_checkpoint_dir) + ffn_is_teacher = subblocks_df["ffn_checkpoint_dir"] == str(teacher_checkpoint_dir) + is_joint_teacher = attention_is_teacher & ffn_is_teacher + + is_decomp_attention = subblocks_df["ffn_config"].isna() + is_decomp_ffn = subblocks_df["attention_config"].isna() + is_joint_block = 
~is_decomp_attention & ~is_decomp_ffn + + student_indices_that_have_teacher_dups = [] + + for current_subset, is_teacher in [ + (is_decomp_attention, attention_is_teacher), + (is_decomp_ffn, ffn_is_teacher), + (is_joint_block, is_joint_teacher), + ]: + subblocks_df = orig_subblocks_df.copy().loc[current_subset] + + subblocks_df["is_student"] = ~is_teacher.loc[current_subset] + + def get_student_indices_that_have_teacher_dups(grouped_is_student: pd.Series) -> list: + if grouped_is_student.all(): + return [] + return grouped_is_student.index[grouped_is_student].tolist() + + current_student_indices_that_have_teacher_dups = [ + dup_index + for dup_list in subblocks_df.groupby(UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False)[ + "is_student" + ].apply(get_student_indices_that_have_teacher_dups) + for dup_index in dup_list + ] + student_indices_that_have_teacher_dups.extend( + current_student_indices_that_have_teacher_dups + ) + + dedup_subblocks_df = orig_subblocks_df.drop(index=student_indices_that_have_teacher_dups) + return dedup_subblocks_df + + +def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame: + is_decomp = subblocks_df["block_config"].isna() + is_ffn_no_op = subblocks_df["ffn_config"].apply(lambda conf: conf is not None and conf.no_op) + is_attention_no_op = subblocks_df["attention_config"].apply( + lambda conf: conf is not None and conf.no_op + ) + is_duplicated = subblocks_df.duplicated(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + is_dup_of_decomp_no_op = is_duplicated & is_decomp & (is_ffn_no_op | is_attention_no_op) + subblocks_df = subblocks_df[~is_dup_of_decomp_no_op] + return subblocks_df + + +def _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir: Path, subblocks_to_extract: list[str], trust_remote_code: bool = False +) -> list[dict[str, Any]]: + subblock_rows_from_current_checkpoint = [] + model_config = load_model_config(checkpoint_dir, trust_remote_code=trust_remote_code) + for block_idx, block_config in 
enumerate(model_config.block_configs): + for subblock_to_extract in subblocks_to_extract: + subblock_row = _init_empty_subblock_row(block_idx) + + if subblock_to_extract == "block": + subblock_row["block_config"] = block_config + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "ffn": + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "attention": + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + else: + raise ValueError() + + subblock_rows_from_current_checkpoint.append(subblock_row) + return subblock_rows_from_current_checkpoint + + +def _add_no_op_subblock_rows( + subblocks_df: pd.DataFrame, + add_ffn_no_op: bool, + add_attention_no_op: bool, +) -> pd.DataFrame: + n_layer = subblocks_df["block_idx"].max() + 1 + + no_op_subblocks = [] + if add_ffn_no_op: + no_op_subblocks.append("ffn") + if add_attention_no_op: + no_op_subblocks.append("attention") + + additional_no_op_rows = [] + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, subblock_cls = _get_rows_with_no_op_subblock( + subblocks_df, no_op_subblock + ) + existing_no_op_indices = rows_with_no_op_subblock["block_idx"].values + missing_no_op_indices = list(set(range(n_layer)) - set(existing_no_op_indices)) + for block_idx in missing_no_op_indices: + no_op_subblock_row = { + **_init_empty_subblock_row(block_idx), + f"{no_op_subblock}_config": subblock_cls(no_op=True), + } + additional_no_op_rows.append(no_op_subblock_row) + + 
subblocks_df = pd.concat([subblocks_df, pd.DataFrame(additional_no_op_rows)]) + + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, _ = _get_rows_with_no_op_subblock(subblocks_df, no_op_subblock) + assert len(rows_with_no_op_subblock) == n_layer, ( + f"Got {len(rows_with_no_op_subblock)} rows with {no_op_subblock}=no_op, but we have {n_layer} layers" + ) + return subblocks_df + + +def _get_rows_with_no_op_subblock( + subblocks_df: pd.DataFrame, no_op_subblock: str +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: + other_subblock = "ffn" if no_op_subblock == "attention" else "attention" + subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig + no_op_subblock_config = subblock_cls(no_op=True) + rows_with_no_op_subblock = subblocks_df[ + (subblocks_df[f"{no_op_subblock}_config"] == no_op_subblock_config) + & subblocks_df[f"{other_subblock}_config"].isna() + ] + return rows_with_no_op_subblock, subblock_cls + + +def _get_last_checkpoint_from_each_experiment( + master_puzzle_dir: Path | str, trust_remote_code: bool = False +) -> set[Path]: + master_puzzle_dir = Path(master_puzzle_dir) + master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME + subdirs_of_master_checkpoints_dir = [ + p.resolve() for p in master_checkpoints_dir.iterdir() if p.is_dir() + ] + checkpoint_dirs = [ + p.parent + for subdir in subdirs_of_master_checkpoints_dir + for p in subdir.rglob("config.json") + ] + + for checkpoint_dir in checkpoint_dirs: + if checkpoint_dir == master_checkpoints_dir: + raise ValueError( + f"We need at least 1 hierarchy level under the '{CHECKPOINTS_DIR_NAME}' dir. " + "Name your checkpoints, preferably with meaningful names. " + "If you are Ido Galil, tell Tomer that you got this exception ;) " + ) + + # Filter out checkpoints without block_configs (e.g. 
unconverted raw HF layouts) + valid_checkpoint_dirs = [ + cp + for cp in checkpoint_dirs + if is_valid_decilm_checkpoint(cp, trust_remote_code=trust_remote_code) + ] + + experiment_dirs = [ + p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs + ] + + deduped_checkpoint_dirs = set( + pd.DataFrame({"checkpoint_dir": valid_checkpoint_dirs, "experiment_dir": experiment_dirs}) + .sort_values("checkpoint_dir") + .drop_duplicates(subset="experiment_dir", keep="last")["checkpoint_dir"] + .tolist() + ) + return deduped_checkpoint_dirs + + +def _infer_subblocks_to_extract( + checkpoint_dir: Path, + checkpoints_to_split: list[Path], +) -> list[str]: + if (checkpoint_dir / "replacement_library.json").exists(): + return [] + bypass_config_path = checkpoint_dir / "bypass_config.json" + if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()): + subblocks_to_extract = ["block", "attention", "ffn"] + else: + bypass_config = json.loads(bypass_config_path.read_text()) + keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + if keys_to_learn == "entire_block": + subblocks_to_extract = ["block"] + elif "mlp" in keys_to_learn and "attn" not in keys_to_learn: + subblocks_to_extract = ["ffn"] + elif "attn" in keys_to_learn and "mlp" not in keys_to_learn: + subblocks_to_extract = ["attention"] + else: + raise ValueError(f"Unrecognized {keys_to_learn=}") + return subblocks_to_extract + + +def _init_empty_subblock_row(block_idx: int) -> dict[str, Any]: + return { + "attention_checkpoint_dir": None, + "ffn_checkpoint_dir": None, + "block_config": None, + "attention_config": None, + "ffn_config": None, + "block_idx": block_idx, + "block_repr": None, + "attention_repr": None, + "ffn_repr": None, + } + + +def _build_layer_replacements( + block_library_df: pd.DataFrame, + master_puzzle_dir: Path, + teacher_checkpoint_dir: Path, + trust_remote_code: bool = False, +) -> list[dict]: + layer_replacements_from_blocks = 
_build_layer_replacements_from_block_library(block_library_df) + layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints + layer_replacements = _filter_duplicate_teacher_replacements( + layer_replacements, teacher_checkpoint_dir, trust_remote_code + ) + return layer_replacements + + +def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) -> list[dict]: + layer_replacements = [] + for _, row in block_library_df.iterrows(): + block_idx = row["block_idx"] + block_config = row["block_config"] + weight_paths = [] + for subblock_name in ["attention", "ffn"]: + checkpoint_dir = row[f"{subblock_name}_checkpoint_dir"] + if checkpoint_dir is not None: + subblock_path = ( + Path(checkpoint_dir) + / SAFETENSORS_SUBBLOCKS_DIR_NAME + / f"block_{block_idx}_{subblock_name}.safetensors" + ) + weight_paths.append(subblock_path) + weight_paths = sorted(set(weight_paths)) + layer_replacement = { + "parent_layer_indices": [block_idx], + "child_block_configs": [block_config], + "weight_paths": weight_paths, + } + layer_replacements.append(layer_replacement) + return layer_replacements + + +def _gather_layer_replacements_from_checkpoints( + master_puzzle_dir: str | Path, trust_remote_code: bool = False +) -> list[dict]: + gathered_layer_replacements = [] + checkpoint_dirs = _get_last_checkpoint_from_each_experiment( + master_puzzle_dir, trust_remote_code=trust_remote_code + ) + for checkpoint_dir in checkpoint_dirs: + if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists(): + layer_replacements = json.loads(layer_replacements_path.read_text()) + for layer_replacement in layer_replacements: + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config_dict) + for block_config_dict in layer_replacement["child_block_configs"] + ] + 
def _filter_duplicate_teacher_replacements(
    layer_replacements: list[dict],
    teacher_checkpoint_dir: Path,
    trust_remote_code: bool = False,
) -> list[dict]:
    """Drop replacements that merely duplicate a teacher block.

    A replacement is kept when it *is* the teacher (per ``replacement_is_teacher``)
    or when it is not identical to the teacher block (per
    ``is_replacement_identical_to_teacher``). Input order is preserved.
    """
    teacher_config = load_model_config(
        teacher_checkpoint_dir, trust_remote_code=trust_remote_code
    )

    def _should_keep(candidate: dict) -> bool:
        if replacement_is_teacher(candidate, teacher_config, teacher_checkpoint_dir):
            return True
        return not is_replacement_identical_to_teacher(candidate, teacher_config)

    return [candidate for candidate in layer_replacements if _should_keep(candidate)]
set(layer_replacement["parent_layer_indices"]) + ) + chosen_replacements = sort_replacements( + [layer_replacement] + + [ + teacher_replacements[block_idx] + for block_idx in block_indices_not_represented_in_replacement + ] + ) + + block_configs = [ + block_config + for replacement in chosen_replacements + for block_config in replacement["child_block_configs"] + ] + + solutions.append( + { + "single_sequence_replacement": layer_replacement, + "chosen_replacements": chosen_replacements, + "block_configs": block_configs, + } + ) + + return solutions diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py new file mode 100644 index 0000000000..c1eb0b9b48 --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Replacement library for loading models with layer replacements (AnyModel / sharded HF checkpoints). 
class ReplacementLibrary:
    """Collection of layer replacements that can be assembled into a loadable model.

    The library is read from a JSON file; every entry is normalised with
    ``parse_layer_replacement``. All checkpoints referenced by the entries must
    already contain a ``SAFETENSORS_SUBBLOCKS_DIR_NAME`` directory.
    """

    def __init__(
        self,
        replacement_library_path: str | Path,
        descriptor,
        model_config_overrides: Optional[dict] = None,
    ):
        """Load the library and verify that every referenced checkpoint is split.

        Args:
            replacement_library_path: JSON file holding a list of layer replacements.
            descriptor: Model descriptor; this class calls its
                ``requires_trust_remote_code()`` and forwards it to
                ``load_and_shard_model``.
            model_config_overrides: Optional overrides passed to ``load_model_config``;
                stored as an ``immutabledict``.
        """
        self.descriptor = descriptor
        self.replacement_library = self._load_replacement_library(replacement_library_path)
        self._ensure_all_checkpoints_are_split()
        self.model_config_overrides = (
            immutabledict(model_config_overrides) if (model_config_overrides is not None) else None
        )

        # Lazily computed caches, see ``model_config`` / ``get_arbitrary_checkpoint_dir``.
        self._model_config = None
        self._arbitrary_checkpoint_dir = None

    @staticmethod
    def _load_replacement_library(replacement_library_path: str | Path) -> list[dict]:
        """Read the JSON file and parse every entry with ``parse_layer_replacement``."""
        replacement_library = json.loads(Path(replacement_library_path).read_text())
        replacement_library = [
            parse_layer_replacement(layer_replacement) for layer_replacement in replacement_library
        ]
        return replacement_library

    def _ensure_all_checkpoints_are_split(self) -> None:
        """Assert that each referenced checkpoint dir contains the subblocks directory."""
        checkpoint_dirs = self._get_all_checkpoint_dirs()
        unsplit_checkpoints = []
        for checkpoint_dir in checkpoint_dirs:
            if not (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists():
                unsplit_checkpoints.append(checkpoint_dir)
        assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}"

    @property
    def model_config(self) -> PretrainedConfig:
        """Model config loaded (once) from an arbitrary checkpoint of the library."""
        if self._model_config is None:
            trust_remote_code = self.descriptor.requires_trust_remote_code()
            self._model_config = load_model_config(
                self.get_arbitrary_checkpoint_dir(),
                self.model_config_overrides,
                ignore_unexpected_config_keys=True,
                trust_remote_code=trust_remote_code,
            )
        return self._model_config

    def create_model_config(self, layer_replacements: list[dict]):
        """Return a deep copy of ``model_config`` describing the given replacements.

        ``block_configs`` and ``num_hidden_layers`` are overwritten with the block
        configs extracted from ``layer_replacements``.
        """
        block_configs, _ = extract_block_configs_and_locations(layer_replacements)
        model_config = copy.deepcopy(self.model_config)
        model_config.block_configs = block_configs
        model_config.num_hidden_layers = len(block_configs)
        return model_config

    def _get_arbitrary_non_block_checkpoint_paths(self):
        """List subblock safetensors files whose name does not contain ``block_``."""
        checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())
        subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME
        non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name]
        return non_block_paths

    def create_index_file_from_weights(self, weight_paths: List[str]):
        """Build a safetensors index mapping every tensor name to its relative file.

        Each file is opened on CPU only to enumerate its tensor names; the mapped
        value is ``SAFETENSORS_SUBBLOCKS_DIR_NAME/<file name>``. If two files contain
        the same tensor name, the later file wins.
        """
        weight_map = {}
        for weight_path in weight_paths:
            weight_path = Path(weight_path)
            with safe_open(str(weight_path), framework="pt", device="cpu") as f:
                for tensor_name in f.keys():
                    weight_map[tensor_name] = f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{weight_path.name}"
        index = {"metadata": {"format": "pt"}, "weight_map": weight_map}
        return index

    def prepare_tmp_checkpoint_dir(
        self,
        tmpdir: Path,
        model_config: PretrainedConfig,
        layer_replacements: List[dict],
    ):
        """Populate ``tmpdir`` so that it looks like a checkpoint for the given solution.

        Writes ``model.safetensors.index.json``, copies checkpoint files from an
        arbitrary library checkpoint via ``Converter.copy_checkpoint_files``, saves
        ``model_config`` and symlinks all weight files into the subblocks directory.
        """
        arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir())

        # Non-block weights come from the arbitrary checkpoint; block weights from the
        # chosen replacements.
        weight_paths = self._get_arbitrary_non_block_checkpoint_paths()
        for layer_replacement in layer_replacements:
            weight_paths += layer_replacement["weight_paths"]

        weights_index = self.create_index_file_from_weights(weight_paths)
        index_path = tmpdir / "model.safetensors.index.json"
        with index_path.open("w", encoding="utf-8") as out:
            json.dump(weights_index, out, indent=2, sort_keys=True)

        Converter.copy_checkpoint_files(arbitrary_checkpoint_dir, tmpdir)
        save_model_config(model_config, tmpdir)

        # create symlinks inside tmpdir
        subblocks_dir = tmpdir / SAFETENSORS_SUBBLOCKS_DIR_NAME
        subblocks_dir.mkdir(exist_ok=True)
        for weight_path in weight_paths:
            link_path = subblocks_dir / weight_path.name
            link_path.symlink_to(weight_path)

    def load_model(
        self,
        layer_replacements: list[dict],
    ) -> PreTrainedModel:
        """Load model using AnyModel approach with temporary checkpoint directory."""
        model_config = self.create_model_config(layer_replacements)
        with tempfile.TemporaryDirectory(prefix="replacement_solution_") as tmpdir:
            tmpdir = Path(tmpdir)
            self.prepare_tmp_checkpoint_dir(
                tmpdir, model_config=model_config, layer_replacements=layer_replacements
            )
            # The model is loaded while the temporary directory still exists.
            model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir)
        return model

    def get_arbitrary_checkpoint_dir(self) -> Path:
        """Return (and cache) the checkpoint dir of the first replacement with weights.

        Raises:
            ValueError: If no replacement in the library has any weight path.
        """
        if self._arbitrary_checkpoint_dir is None:
            self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir()
        return self._arbitrary_checkpoint_dir

    def _get_arbitrary_checkpoint_dir(self) -> Path:
        """Resolve the checkpoint dir that owns the first weight path in the library."""
        for layer_replacement in self.replacement_library:
            weight_paths = layer_replacement["weight_paths"]
            if len(weight_paths) > 0:
                return weights_path_to_checkpoint_dir(weight_paths[0])
        # Previously this fell through and returned None, which surfaced later as an
        # unrelated TypeError; fail here with an explanation instead.
        raise ValueError(
            "Replacement library contains no weight paths, "
            "so no checkpoint directory can be derived from it"
        )

    def _get_all_checkpoint_dirs(self) -> list[Path]:
        """Return the distinct checkpoint dirs referenced by any weight path."""
        checkpoint_dirs = set()
        for layer_replacement in self.replacement_library:
            weight_paths = layer_replacement["weight_paths"]
            for weights_path in weight_paths:
                checkpoint_dir = weights_path_to_checkpoint_dir(weights_path)
                checkpoint_dirs.add(checkpoint_dir)
        return list(checkpoint_dirs)
new file mode 100644 index 0000000000..269e5e63ea --- /dev/null +++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides helper functions for parsing, sorting, and analyzing layer replacement +configurations used in the replacement library for model compression. 
def weights_path_to_checkpoint_dir(weights_path: Path) -> Path:
    """Return the closest ancestor of ``weights_path`` that contains ``config.json``.

    The path itself is checked first, then each parent in turn up to and including
    the filesystem anchor (or ``.`` for a relative path).

    Args:
        weights_path: Path of a weights file (or directory) inside a checkpoint.

    Returns:
        The first directory on the way up that contains a ``config.json`` file.

    Raises:
        FileNotFoundError: If no candidate contains ``config.json``.
    """
    weights_path = Path(weights_path)
    # ``parents`` is finite for every path flavour. The previous
    # ``while dir != Path("/")`` walk never terminated for relative paths
    # (``Path(".").parent == Path(".")``) or on Windows drive roots.
    for candidate in (weights_path, *weights_path.parents):
        if (candidate / "config.json").exists():
            return candidate
    raise FileNotFoundError(f"Couldn't find checkpoint dir for weights path {weights_path}")
def extract_solution_id(filename):
    """Parse the integer id out of a ``solution_<digits>.json`` file name.

    Returns the id as ``int``. When the pattern is not found anywhere in
    ``filename``, a message is emitted through ``mprint`` and ``None`` is returned.
    """
    found = re.search(r"solution_(\d+)\.json", filename)
    if found is None:
        mprint(f"Couldn't extract solutions_id from file {filename}")
        return None
    return int(found.group(1))
def get_solutions_to_validate(cfg: DictConfig):
    """Resolve which solution indices should be validated.

    An explicit ``cfg.scoring.solutions_to_validate`` is returned unchanged.
    Otherwise the solutions file at ``cfg.scoring.solutions_path`` is read with
    ``pd.read_json`` and either the indices missing from ``cfg.scoring.output_dir``
    (when ``cfg.scoring.skip_existing_solutions`` is truthy) or all row indices
    ``0..n-1`` are returned as a list.
    """
    requested = cfg.scoring.solutions_to_validate
    if requested is not None:
        return requested

    solutions_df = pd.read_json(cfg.scoring.solutions_path)
    if cfg.scoring.skip_existing_solutions:
        return find_missing_solutions(solutions_df, cfg.scoring.output_dir)
    return np.arange(solutions_df.shape[0]).tolist()
+# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from .core import ( + CantResolveNodeDependenciesException, + ConstantTarget, + ExternalTarget, + FunctionTarget, + InputsLoopFoundException, + KnotException, + LoopFoundException, + ModuleTarget, + MultipleExternalNodesException, + Needle, + OnlyInternalNodesException, + OutputsLoopFoundException, + RemoteTarget, + StitchedModule, + StitchedModuleException, + StitchedModuleOutput, +) +from .passage import InputArgs, always_false_predicate, always_true_predicate diff --git a/modelopt/torch/puzzletron/sewing_kit/core.py b/modelopt/torch/puzzletron/sewing_kit/core.py new file mode 100644 index 0000000000..fb9055c3ed --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/core.py @@ -0,0 +1,883 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass
class InputReducer(IOReducer):
    """Callable wrapper around a function that folds several input overrides into one.

    Calling the instance forwards all five arguments, unchanged and in order, to
    ``reducer_fn``.
    """

    # Signature: (acc, input_override, original_input, index, all_input_overrides) -> acc
    reducer_fn: Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs] = (
        default_input_reducer_fn
    )

    def __call__(
        self,
        acc: InputArgs,
        input_override: InputArgs,
        original_input: InputArgs,
        index: int,
        all_input_overrides: list[InputArgs],
    ) -> InputArgs:
        """Apply ``reducer_fn`` to one override and return the new accumulator."""
        return self.reducer_fn(acc, input_override, original_input, index, all_input_overrides)

    @classmethod
    def default(cls) -> Self:
        """Create a reducer that uses ``default_input_reducer_fn``.

        Instantiates ``cls`` rather than ``InputReducer`` so that subclasses get an
        instance of their own type.
        """
        return cls()
class Singleton(type):
    """Metaclass that hands out one shared instance per class.

    The first call to a class constructs the instance with the supplied arguments
    and stores it; every later call returns the stored instance and ignores its
    arguments.
    """

    _instances = {}

    @override
    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
@dataclass
class ModuleTarget(TargetWithNamedInputs, TargetWithNamedOutputs):
    """Stitching target that refers to a named ``nn.Module``.

    The ``input(name, ...)`` and ``output(name, ...)`` descriptor factories are
    inherited from the two base classes.
    """

    # Label shown in the string form of this target.
    name: str
    # The module this target refers to.
    module: nn.Module

    @override
    def __str__(self) -> str:
        # Short form: only the name, not the module.
        return f"ModuleTarget({self.name})"

    @override
    def __repr__(self) -> str:
        # Defined explicitly, so the dataclass decorator does not generate a
        # field-by-field repr; the repr is the same short string as ``str``.
        return str(self)

    @override
    def __hash__(self) -> int:
        # ``@dataclass`` (with its default ``eq=True``) would set ``__hash__`` to
        # ``None`` unless the class defines it itself, hence the explicit method.
        # The call walks up the MRO to ``Target.__hash__``, which returns ``id(self)``.
        return super().__hash__()
    def __init__(
        self,
        nodes: dict[Target, Node],
        capture_cache_outputs_predicate: Predicate = always_false_predicate,
        early_exit=True,
        ignore_extra_overrides=False,
    ) -> None:
        """Index the stitching graph and create one ``Passage`` per module node.

        Args:
            nodes: Graph nodes keyed by their target. At most one node may have an
                ``ExternalTarget``, and there must be at least one external or
                remote node.
            capture_cache_outputs_predicate: Forwarded to every ``Passage.create``.
            early_exit: Forwarded to every ``Passage.create``.
            ignore_extra_overrides: Stored on the instance; not used in this method.

        Raises:
            AssertionError: If the node constraints above are violated.
        """
        super().__init__()
        self.nodes = nodes
        self.ignore_extra_overrides = ignore_extra_overrides
        external_nodes = [n for n in nodes.values() if isinstance(n.target, ExternalTarget)]
        remote_nodes = [n for n in nodes.values() if isinstance(n.target, RemoteTarget)]
        assert len(external_nodes) <= 1
        assert len(remote_nodes) + len(external_nodes) > 0
        self.external_node = external_nodes[0] if len(external_nodes) > 0 else None
        # "Internal" means every node that is not the external one; remote nodes are
        # therefore included.
        self.internal_nodes = [
            n for n in nodes.values() if not isinstance(n.target, ExternalTarget)
        ]
        # Per-node value stores keyed by descriptor; they start empty here.
        # NOTE(review): presumably filled during forward — confirm against the rest of the class.
        self.values_from_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict)
        self.values_to_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict)

        # One Passage per internal node whose target is a ModuleTarget. The names to
        # capture are taken from the source descriptors of the node's ``stitches_from``.
        self.node_passages: dict[Node, Passage] = {
            node: Passage.create(
                module=node.target.module,
                inputs_to_capture=set(
                    s.source_descriptor.input_name
                    for s in node.stitches_from
                    if isinstance(s.source_descriptor, InputDescriptor)
                ),
                outputs_to_capture=set(
                    s.source_descriptor.output_name
                    for s in node.stitches_from
                    if isinstance(s.source_descriptor, OutputDescriptor)
                ),
                capture_cache_outputs_predicate=capture_cache_outputs_predicate,
                early_exit=early_exit,
                name=getattr(node.target, "name", None),
            )
            for node in self.internal_nodes
            if isinstance(node.target, ModuleTarget)
        }

        # Register the passages as submodules. The key uses the node's position in
        # ``nodes.values()`` (not its position among passages), so key numbers may
        # have gaps.
        self.passage_modules = nn.ModuleDict(
            {
                f"node_{node_index}": self.node_passages[node]
                for node_index, node in enumerate(nodes.values())
                if node in self.node_passages
            }
        )
        # Register every adapter that is itself an ``nn.Module`` as a submodule.
        # ``for adapter in [...]`` over a one-element list binds a local name inside
        # the comprehension.
        self.adapter_modules = nn.ModuleDict(
            {
                f"node_{node_index}__stitch_{stitch_index}__{descriptor_name}": adapter
                for node_index, node in enumerate(nodes.values())
                for stitch_index, stitch in enumerate(node.stitches_from + node.stitches_to)
                for descriptor_name, descriptor in (
                    ("source", stitch.source_descriptor),
                    ("destination", stitch.destination_descriptor),
                )
                for adapter in [
                    descriptor.input_adapter
                    if isinstance(descriptor, InputDescriptor)
                    else descriptor.output_adapter
                ]
                if isinstance(adapter, nn.Module)
            }
        )
input_descriptors_by_group[io_descriptor.input_name].append(io_descriptor) + + input_overrides = PassageInputOverrides() + for group, input_descriptors in input_descriptors_by_group.items(): + reducers = [d.reducer for d in input_descriptors] + + def create_reducer(input_descriptors=input_descriptors, reducers=reducers): + inputs = [values_to_node[d] for d in input_descriptors] + + def reducer_fn( + original_input: InputArgs, + module_name: Optional[str], + module: Optional[nn.Module], + ) -> InputArgs: + acc = InputArgs() + for i, (input_, reducer) in enumerate(zip(inputs, reducers)): + acc = reducer(acc, input_, original_input, i, inputs) + return acc + + return reducer_fn + + input_override = PassageInputAdapter(create_reducer()) + input_overrides[group] = input_override + + return input_overrides + + def create_output_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageOutputOverrides: + output_descriptors_by_group = defaultdict[str, list[OutputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, OutputDescriptor): + output_descriptors_by_group[io_descriptor.output_name].append(io_descriptor) + + output_overrides = PassageOutputOverrides() + for group, output_descriptors in output_descriptors_by_group.items(): + reducers = [d.reducer for d in output_descriptors] + requires_original_output = any(r.requires_original_output for r in reducers) + + def create_reducer(reducers=reducers): + outputs = [values_to_node[d] for d in output_descriptors] + + def reducer_fn( + original_output: Optional[OutputValue], + module_name: Optional[str], + module: Optional[nn.Module], + ) -> OutputValue: + acc = None + for i, (output, reducer) in enumerate(zip(outputs, reducers)): + acc = reducer(acc, output, original_output, i, outputs) + return acc + + return reducer_fn + + reducer_fn = create_reducer() + if requires_original_output: + output_override = PassageOutputAdapter(reducer_fn) + else: + output_override = 
reducer_fn(None, None, None) + + output_overrides[group] = output_override + + return output_overrides + + @override + def __call__( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + return super().__call__(input_overrides, output_overrides, *args, **kwargs) + + @override + @dynamo_skip + def forward( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + input_overrides = {k: InputArgs.from_value(v) for k, v in input_overrides.items()} + + self.values_from_node.clear() + self.values_to_node.clear() + + unresolved_count: int = 0 + nodes_stack: list[Node] = ( + [] if self.external_node is None else [self.external_node] + ) + self.internal_nodes + while len(nodes_stack) > 0: + node = nodes_stack.pop(0) + values_from_node = self.values_from_node[node] + values_to_node = self.values_to_node[node] + + if isinstance(node.target, ExternalTarget): + assert self.external_node is not None + + if not self.ignore_extra_overrides: + input_override_names = set(input_overrides.keys()) + external_node_input_names = set( + s.source_descriptor.input_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ) + assert input_override_names == external_node_input_names + output_override_names = set(output_overrides.keys()) + external_node_output_names = set( + s.source_descriptor.output_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ) + assert output_override_names == external_node_output_names + + for stitch in self.external_node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + orig_input_override = input_overrides[stitch.source_descriptor.input_name] + input_override = stitch.source_descriptor.input_adapter(orig_input_override) + values_from_node[stitch.source_descriptor] = input_override + elif 
isinstance(stitch.source_descriptor, OutputDescriptor): + orig_output_override = output_overrides[ + stitch.source_descriptor.output_name + ] + output_override = stitch.source_descriptor.output_adapter( + orig_output_override + ) + values_from_node[stitch.source_descriptor] = output_override + else: + raise RuntimeError("Shouldn't happen") + + else: + if len(values_to_node) < len(node.stitches_to): + nodes_stack.append(node) + unresolved_count += 1 + if unresolved_count >= len(nodes_stack): + raise CantResolveNodeDependenciesException( + "Can't resolve nodes dependencies" + ) + continue + + if isinstance(node.target, ConstantTarget): + assert len(values_to_node) == 0 + + output_value = node.target.value + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(output_value) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, FunctionTarget): + assert all( + isinstance(v, InputDescriptor) and v.input_name == "" + for v in values_to_node + ) + + function_input_overrides = self.create_input_overrides(values_to_node)[""] + + if isinstance(function_input_overrides, InputArgs): + input_args = function_input_overrides + else: + input_args = function_input_overrides(InputArgs(), None, None) + + function_output = node.target.function(*input_args.args, **input_args.kwargs) + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(function_output) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, ModuleTarget): + passage = self.node_passages[node] + passage.input_overrides = self.create_input_overrides(values_to_node) + passage.output_overrides = self.create_output_overrides(values_to_node) + passage_output = 
passage(*args, **kwargs) + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + captured_input = passage_output.captured_inputs[ + stitch.source_descriptor.input_name + ] + value = stitch.source_descriptor.input_adapter(captured_input) + values_from_node[stitch.source_descriptor] = value + elif isinstance(stitch.source_descriptor, OutputDescriptor): + captured_output = passage_output.captured_outputs[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(captured_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + + elif isinstance(node.target, RemoteTarget): + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_from_node + ) + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_to_node + ) + + process_group = node.target.process_group + peers = node.target.peer_rank + if not isinstance(peers, Sequence): + peers = [peers] + + if len(values_to_node) > 0: + items_to_send = list(self.create_output_overrides(values_to_node).items()) + + data_descriptors: list[RemoteDataDescriptor] = [] + tensors_to_send: list[torch.Tensor] = [] + + for key, value in items_to_send: + if isinstance(value, torch.Tensor): + if value.is_cuda: + tensor_device = "cuda" + elif value.is_cpu: + tensor_device = "cpu" + else: + raise RuntimeError( + f"Invalid tensor device to send to remote target: {value.device}" + ) + + data_descriptor = RemoteTensorDataDescriptor( + key=key, + device=tensor_device, + dtype=value.dtype, + shape=value.shape, + ) + tensors_to_send.append(value) + + else: + data_descriptor = RemotePythonDataDescriptor( + key=key, + value=value, + ) + + data_descriptors.append(data_descriptor) + + works: list[Optional[torch.distributed.Work]] = [] + for peer in peers: + if process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + 
peer_works = distributed_isend_obj(data_descriptors, dst=peer) + works.extend(peer_works) + + for tensor in tensors_to_send: + work = torch.distributed.isend(tensor, dst=peer) + works.append(work) + + if node.target.blocking: + for work in works: + if work is not None: + work.wait() + + if len(node.stitches_from) > 0: + assert len(peers) == 1, ( + f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})" + ) + (peer,) = peers + + if process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + data_descriptors = distributed_recv_obj(src=peer) + assert isinstance(data_descriptors, list) + + tensors_to_recv: list[torch.Tensor] = [] + received_values: dict[str, Any] = {} + for data_descriptor in data_descriptors: + if isinstance(data_descriptor, RemoteTensorDataDescriptor): + tensor = torch.empty( + data_descriptor.shape, + dtype=data_descriptor.dtype, + device=data_descriptor.device, + ) + tensors_to_recv.append(tensor) + received_values[data_descriptor.key] = tensor + elif isinstance(data_descriptor, RemotePythonDataDescriptor): + received_values[data_descriptor.key] = data_descriptor.value + else: + raise RuntimeError( + f"Received invalid data descriptor from remote peer: {data_descriptor}" + ) + + works: list[Optional[torch.distributed.Work]] = [] + for tensor in tensors_to_recv: + work = torch.distributed.irecv(tensor, src=peer) + works.append(work) + + for work in works: + if work is not None: + work.wait() + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, OutputDescriptor): + remote_output = received_values[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(remote_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + else: + raise RuntimeError("Shouldn't happen") + + for stitch in node.stitches_from: + dst_node = self.nodes[stitch.destination_descriptor.target] + value = 
values_from_node[stitch.source_descriptor] + + if isinstance(stitch.destination_descriptor, InputDescriptor): + value = stitch.destination_descriptor.input_adapter(value) + elif isinstance(stitch.destination_descriptor, OutputDescriptor): + value = stitch.destination_descriptor.output_adapter(value) + else: + raise RuntimeError("Shouldn't happen") + + self.values_to_node[dst_node][stitch.destination_descriptor] = value + + unresolved_count = 0 + + values_to_external_node = ( + {} if self.external_node is None else self.values_to_node[self.external_node] + ) + output = StitchedModuleOutput( + captured_inputs={ + k.input_name: v + for k, v in values_to_external_node.items() + if isinstance(k, InputDescriptor) + }, + captured_outputs={ + k.output_name: v + for k, v in values_to_external_node.items() + if isinstance(k, OutputDescriptor) + }, + ) + + self.values_from_node.clear() + self.values_to_node.clear() + + return output + + +class KnotException(Exception): + pass + + +class LoopFoundException(KnotException): + pass + + +class InputsLoopFoundException(LoopFoundException): + pass + + +class OutputsLoopFoundException(LoopFoundException): + pass + + +class MultipleExternalNodesException(KnotException): + pass + + +class OnlyInternalNodesException(KnotException): + pass + + +class Needle: + def __init__(self) -> None: + self.nodes = dict[Target, Node]() + + def get_node_for_target(self, target: Target) -> Node: + if target not in self.nodes: + node = Node(target=target) + self.nodes[target] = node + else: + node = self.nodes[target] + + return node + + def stitch(self, src: IODescriptor, dst: IODescriptor) -> Self: + descriptor = StitchDescriptor(source_descriptor=src, destination_descriptor=dst) + + src_node = self.get_node_for_target(descriptor.source_descriptor.target) + dst_node = self.get_node_for_target(descriptor.destination_descriptor.target) + + if descriptor not in src_node.stitches_from: + src_node.stitches_from.append(descriptor) + + if descriptor not in 
dst_node.stitches_to: + dst_node.stitches_to.append(descriptor) + + return self + + def _search_loops( + self, + node: Node, + expand_fn: Callable[[Node], Iterable[IODescriptor]], + traversed_nodes: Optional[set[Node]] = None, + ) -> bool: + if isinstance(node.target, ExternalTarget): + return False + + if traversed_nodes is None: + traversed_nodes = set() + + if node in traversed_nodes: + found_loop = True + else: + traversed_nodes = traversed_nodes | {node} + found_loop = False + descriptors = expand_fn(node) + for descriptor in descriptors: + stitch_node = self.get_node_for_target(descriptor.target) + found_loop |= self._search_loops(stitch_node, expand_fn, traversed_nodes) + + return found_loop + + def _validate_nodes(self): + # internal_nodes = [n for n in self.nodes.values() if not isinstance(n.target, (ExternalTarget, RemoteTarget))] + external_nodes = [n for n in self.nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in self.nodes.values() if isinstance(n.target, RemoteTarget)] + + if len(external_nodes) + len(remote_nodes) == 0: + raise OnlyInternalNodesException(f"Has only internal nodes") + + if len(external_nodes) > 1: + raise MultipleExternalNodesException( + f"Expected no more than 1 external node, found {len(external_nodes)}" + ) + + for i, node in enumerate(self.nodes.values()): + found_inputs_loop = self._search_loops( + node, lambda n: [s.source_descriptor for s in n.stitches_to] + ) + if found_inputs_loop: + raise InputsLoopFoundException(f"Found a loop in inputs of node {i}: {node}") + + found_outputs_loop = self._search_loops( + node, lambda n: [s.destination_descriptor for s in n.stitches_from] + ) + if found_outputs_loop: + raise OutputsLoopFoundException(f"Found a loop in outputs of node {i}: {node}") + + def knot( + self, + capture_cache_outputs_predicate=always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> StitchedModule: + self._validate_nodes() + + module = StitchedModule( + 
nodes=self.nodes, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + ignore_extra_overrides=ignore_extra_overrides, + ) + + return module diff --git a/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py b/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py new file mode 100644 index 0000000000..02f4c19eac --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/passage/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .core import ( + InputArgs, + OutputValue, + Passage, + PassageInputAdapter, + PassageInputOverrides, + PassageOutput, + PassageOutputAdapter, + PassageOutputOverrides, + Predicate, + always_false_predicate, + always_true_predicate, +) diff --git a/modelopt/torch/puzzletron/sewing_kit/passage/core.py b/modelopt/torch/puzzletron/sewing_kit/passage/core.py new file mode 100644 index 0000000000..c0fcb4b123 --- /dev/null +++ b/modelopt/torch/puzzletron/sewing_kit/passage/core.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
@dataclass
class InputArgs:
    """Container for the positional and keyword arguments of a module call.

    ``args`` is always a list and ``kwargs`` always a dict; both are shallow
    copies of what was passed to the constructor.
    """

    args: list[Any]
    kwargs: dict[str, Any]

    def __init__(self, *args, **kwargs):
        # Explicit __init__ (kept by @dataclass) so instances are built call-style:
        # InputArgs(1, 2, x=3).
        self.args = list(args)
        self.kwargs = dict(kwargs)

    def __add__(self, other: Any) -> InputArgs:
        """Concatenate positional args; on kwarg name clashes ``other`` wins."""
        assert isinstance(other, InputArgs)
        result = InputArgs(*self.args, *other.args, **{**self.kwargs, **other.kwargs})
        return result

    def drop_args(self, index: int | slice | None = None) -> InputArgs:
        """Return a copy without the positional arg(s) at ``index`` (all when None)."""
        new_args = InputArgs(*self.args, **self.kwargs)
        if index is None:
            new_args.args.clear()
        else:
            del new_args.args[index]

        return new_args

    def drop_kwargs(self, keys: Sequence[str] | None = None) -> InputArgs:
        """Return a copy without the given kwargs (all when None); unknown keys are ignored."""
        new_args = InputArgs(*self.args, **self.kwargs)
        if keys is None:
            new_args.kwargs.clear()
        else:
            for key in keys:
                new_args.kwargs.pop(key, None)

        return new_args

    @classmethod
    def from_value(cls, v):
        """Coerce ``v`` into an instance of ``cls``.

        * an instance of ``cls`` is returned as-is;
        * any other ``InputArgs`` is copied into ``cls``;
        * a non-text sequence (list, tuple, ...) is spread into positional args;
        * anything else, including ``str``/``bytes``/``bytearray``, becomes the
          single positional argument.
        """
        if isinstance(v, cls):
            return v
        if isinstance(v, InputArgs):
            return cls(*v.args, **v.kwargs)
        # Text and byte strings are Sequences too, but spreading them would turn
        # one value into one positional argument per character.
        if isinstance(v, Sequence) and not isinstance(v, (str, bytes, bytearray)):
            return cls(*v)
        return cls(v)
class PassageInputOverrides(dict[str, Union[PassageInputAdapter, InputArgs]]):
    """Mapping from module name to the input override applied to that module.

    A value is either a ``PassageInputAdapter`` or a ready-made ``InputArgs``.
    """

    def __init__(
        self,
        input_overrides: Optional[Mapping[str, PassageInputAdapter | InputArgs]] = None,
    ):
        """Copy ``input_overrides`` (if given) into this dict, one item at a time."""
        super().__init__()
        if input_overrides is None:
            return
        # Item-by-item assignment (not dict.update) so a future __setitem__
        # override would see every entry.
        for k, v in input_overrides.items():
            self[k] = v
ActivityContext[None](max_depth=1) + active_passages_context = ActivityContext["Passage"](no_duplicates=True, reversed=True) + + def __init__( + self, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ): + super().__init__() + + if not self.create_fn_context.is_active(): + raise RuntimeError("Please use Passage.create(...) in order to create a new Passage") + + self.active_context_manager: Optional[ContextManager] = None + + self.name = name + self.module = module + self.module_to_name_mapping = {id(v): k for k, v in module.named_modules()} + self.inputs_to_capture = set(inputs_to_capture) + self.outputs_to_capture = set(outputs_to_capture) + self.input_overrides = input_overrides + self.output_overrides = output_overrides + self.outputs_cache = outputs_cache + self.capture_fake_outputs_predicate = capture_fake_outputs_predicate + self.capture_cache_outputs_predicate = capture_cache_outputs_predicate + self.early_exit = early_exit + + self.reset() + + @property + def input_overrides(self) -> PassageInputOverrides: + return self._input_overrides + + @input_overrides.setter + def input_overrides(self, value: Mapping[str, PassageInputAdapter | InputArgs]): + self._input_overrides = PassageInputOverrides(value) + + @property + def output_overrides(self) -> PassageOutputOverrides: + return self._output_overrides + + @output_overrides.setter + def output_overrides(self, value: Mapping[str, PassageOutputAdapter | Any]): + self._output_overrides = PassageOutputOverrides(value) + + def reset(self): + self.required_capture_count = ( + 
(len(self.inputs_to_capture) + len(self.outputs_to_capture)) + if self.early_exit + else None + ) + self.captured_outputs: dict[str, Any] = {} + self.captured_inputs: dict[str, InputArgs] = {} + self.captured_fake_outputs: dict[str, Any] = {} + + @classmethod + def module_name_relative_to_active_passage(cls, module: PatchedModule) -> str: + root_passage = Passage.active_passages_context.get_active() + assert root_passage is not None + module_name = root_passage.module_to_name_mapping[id(module)] + return module_name + + @classmethod + def create( + cls, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ) -> Passage: + with cls.create_fn_context(None): + passage = cls( + module=module, + inputs_to_capture=inputs_to_capture, + outputs_to_capture=outputs_to_capture, + input_overrides=input_overrides, + output_overrides=output_overrides, + outputs_cache=outputs_cache, + capture_fake_outputs_predicate=capture_fake_outputs_predicate, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=name, + ) + + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + patch_module(submodule_name, submodule) + + # register_passage_hooks(module, descriptor) + + return passage + + def is_active(self) -> bool: + result = self.active_context_manager is not None + return result + + def __enter__(self): + assert self.active_context_manager is None + self.active_context_manager = Passage.active_passages_context(self) + self.active_context_manager.__enter__() + self.module_to_name_mapping = 
{id(v): k for k, v in self.named_modules()} + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.active_context_manager is not None + self.active_context_manager.__exit__(exc_type, exc_val, exc_tb) + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def unfreeze(self): + self.train() + self.requires_grad_(True) + + def run(self, *args, **kwargs) -> PassageOutput: + return self(*args, **kwargs) + + @override + def __call__(self, *args, **kwargs) -> PassageOutput: + return super().__call__(*args, **kwargs) + + @dynamo_skip + @override + def forward(self, *args, **kwargs) -> PassageOutput: + self.reset() + + with Passage.active_passages_context(self): + try: + module_output = self.module(*args, **kwargs) + except RequiredPassageOutputsCapturedSignal: + module_output = None + + output = PassageOutput( + captured_inputs=self.captured_inputs, + captured_outputs=self.captured_outputs, + captured_fake_outputs=self.captured_fake_outputs, + module_output=module_output, + ) + + self.reset() + + return output + + +class PatchedModule: ... + + +def patch_module(module_name_: str, module: nn.Module): + # orig_forward = module.forward + + if isinstance(module, PatchedModule): + # if module_name != Passage.module_name_relative_to_active_passage(module): + # logger.warn(f'Module "{module_name}" already patched for module "{Passage.module_name_relative_to_active_passage(module)}". 
        # Defined as a static method to avoid potential collision with original class methods
        @staticmethod
        @dynamo_skip
        def run_passage(_self: PassageModuleWrapper, depth: int, args, kwargs):
            """Run the wrapped module through the active passage at ``depth``.

            The call recurses once per active passage (``depth + 1`` each time).
            Once ``depth`` is past the last active passage, the original class's
            ``__call__`` runs. At each level the passage's overrides are applied
            and its requested inputs/outputs are recorded.

            Raises:
                RequiredPassageOutputsCapturedSignal: when the passage's
                    ``required_capture_count`` reaches zero.
            """
            # Base case: no passage left at this depth, so run the real module.
            if depth + 1 > len(Passage.active_passages_context):
                output = super(PassageModuleWrapper, _self).__call__(*args, **kwargs)
                return output

            module_name = Passage.module_name_relative_to_active_passage(_self)
            passage = Passage.active_passages_context[depth]

            has_output_override = module_name in passage.output_overrides
            output_override = passage.output_overrides.get(module_name)

            if has_output_override and not isinstance(output_override, PassageOutputAdapter):
                # A plain value override replaces the module's output; the module is not run.
                output = output_override
            else:
                input_override = passage.input_overrides.get(module_name)
                if input_override is not None:
                    original_input_args = InputArgs(*args, **kwargs)

                    # NOTE: ``module`` is the instance captured by the enclosing
                    # ``patch_module`` call, not ``_self``.
                    if isinstance(input_override, PassageInputAdapter):
                        new_input_args = input_override(original_input_args, module_name, module)
                    else:
                        new_input_args = input_override

                    args, kwargs = new_input_args.args, new_input_args.kwargs

                # Serve from the outputs cache only when nothing overrides this output,
                # the module is skippable, and the inputs contain fake tensors.
                if (
                    output_override is None
                    and PassageModuleWrapper.can_be_skipped(_self, depth)
                    and (has_fake_tensor(args) or has_fake_tensor(kwargs))
                ):
                    cached_output = passage.outputs_cache[module_name]
                    return cached_output

                output = PassageModuleWrapper.run_passage(
                    _self=_self,
                    depth=depth + 1,
                    args=args,
                    kwargs=kwargs,
                )

                # An adapter override post-processes the computed output.
                if isinstance(output_override, PassageOutputAdapter):
                    output = output_override(output, module_name, module)

            if passage.capture_fake_outputs_predicate(module_name, module):
                fake_output = fake_tensors(output)
                passage.captured_fake_outputs[module_name] = fake_output

            # Populate the cache once per module name, when the predicate allows it.
            if not module_name in passage.outputs_cache and passage.capture_cache_outputs_predicate(
                module_name, module
            ):
                fake_output = fake_tensors(output)
                passage.outputs_cache[module_name] = fake_output

            if module_name in passage.inputs_to_capture:
                real_args, real_kwargs = real_tensors(args), real_tensors(kwargs)
                passage.captured_inputs[module_name] = InputArgs(*real_args, **real_kwargs)

                # The countdown is None unless the passage was created with early_exit.
                if passage.required_capture_count is not None:
                    passage.required_capture_count -= 1

            if module_name in passage.outputs_to_capture:
                real_output = real_tensors(output)
                output_value = real_output
                passage.captured_outputs[module_name] = output_value

                if passage.required_capture_count is not None:
                    passage.required_capture_count -= 1

            # Everything requested has been captured: abort the rest of the forward pass.
            if passage.required_capture_count == 0:
                raise RequiredPassageOutputsCapturedSignal()

            return output
class DynamoSkip(Protocol):
    """Static type of the decorator bound to ``dynamo_skip``.

    Two call shapes are declared: with no function it returns a decorator,
    and with a function it returns that function's type unchanged.
    """

    # Called without a function: acts as a decorator factory.
    @overload
    def __call__(self, fn: None = None) -> Callable[[Fn], Fn]: ...
    # Called with a function: the decorated callable keeps its type.
    @overload
    def __call__(self, fn: Fn) -> Fn: ...
+ + +try: + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.decorators).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.decorators).disable +except: + dynamo_skip: DynamoSkip = cast("Any", torch._dynamo.eval_frame).skip + dynamo_disable: DynamoDisable = cast("Any", torch._dynamo.eval_frame).disable + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class ModuleRef(Generic[TModule]): + def __init__(self, module: TModule): + self.module = module + + +class ActivityContextMaxDepthException(Exception): + pass + + +class ActivityContextDuplicateException(Exception): + pass + + +T = TypeVar("T") + + +class ActivityContext(Generic[T]): + def __init__(self, max_depth: Optional[int] = None, no_duplicates=False, reversed=False): + self.activity_stack: list[T] = [] + self.max_depth = max_depth + self.no_duplicates = no_duplicates + self.reversed = reversed + self.head_index = 0 if self.reversed else -1 + + def __contains__(self, value: T) -> bool: + result = value in self.activity_stack + return result + + def __call__(self, value: T) -> ContextManager: + @contextmanager + def fn(): + try: + if self.no_duplicates and value in self.activity_stack: + raise ActivityContextDuplicateException( + f"Activity stack cannot have a duplicate of item {value}" + ) + + self.activity_stack.insert(self.head_index, value) + + if self.max_depth is not None and len(self) > self.max_depth: + raise ActivityContextMaxDepthException( + f"Activity stack exceeds max depth of {self.max_depth}" + ) + + yield + finally: + assert self.is_active() + self.activity_stack.pop(self.head_index) + + return fn() + + def __len__(self) -> int: + result = len(self.activity_stack) + return result + + @overload + def __getitem__(self, key: int) -> T: ... + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... 
+ def __getitem__(self, key: int | slice) -> T | Sequence[T]: + result = self.activity_stack[key] + return result + + def is_active(self) -> bool: + result = len(self) > 0 + return result + + def get_active(self) -> Optional[T]: + if self.is_active: + return self.activity_stack[-1] + else: + return None + + +def is_submodule_of(module_name: str, other_module_name: str) -> bool: + result = module_name.startswith(f"{other_module_name}.") or ( + module_name != "" and other_module_name == "" + ) + return result + + +def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: + result = module_name == other_module_name or is_submodule_of(module_name, other_module_name) + return result + + +fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, + # allow_fallback_kernels=False, +) + + +@overload +def fake_tensor(t: Tensor, *, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +@overload +def fake_tensor( + size: Sequence[int] | torch.Size, *, dtype: Optional[torch.dtype] = None, use_meta=False +) -> Tensor: ... + + +@overload +def fake_tensor(*args: int, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... 
class MyFakeTensor(Tensor):
    """Tensor subclass that carries a ``FakeTensor`` in ``_t`` and dispatches through it.

    ``__torch_dispatch__`` unwraps every ``MyFakeTensor`` argument to its ``_t``, runs the
    op, and re-wraps tensor outputs. ``__torch_function__`` is disabled.
    """

    @dynamo_disable
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Declared here, assigned by ``create``.
        self._t: FakeTensor

    @override
    @dynamo_disable
    def __repr__(self, *, tensor_contents=None):
        return f"MyFakeTensor(shape={list(self._t.shape)}, dtype={self._t.dtype}, device={self._t.device})"

    @classmethod
    @override
    @dynamo_disable
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the wrapped fake tensors and wrap tensor results again."""
        if kwargs is None:
            kwargs = {}

        args, kwargs = pytree.tree_map_only(MyFakeTensor, lambda t: t._t, (args, kwargs))
        out = func(*args, **kwargs)
        return pytree.tree_map_only(Tensor, MyFakeTensor.create, out)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @property
    @dynamo_disable
    def device(self):
        """Device reported by the wrapped fake tensor."""
        return self._t.device

    @staticmethod
    @dynamo_disable
    def __new__(cls, elem, device) -> "MyFakeTensor":
        self = torch.Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        return cast("MyFakeTensor", self)

    @classmethod
    @dynamo_disable
    def create(cls, data: Tensor) -> "MyFakeTensor":
        """Wrap ``data`` in a ``MyFakeTensor``.

        A ``MyFakeTensor`` is returned unchanged; a ``FakeTensor`` is wrapped as is; any
        other tensor is first converted with ``FakeTensor.from_tensor`` under ``fake_mode``.
        """
        if isinstance(data, MyFakeTensor):
            return data

        if isinstance(data, FakeTensor):
            t = data
        else:
            t = FakeTensor.from_tensor(data, fake_mode=fake_mode)

        # The subclass element is an empty meta tensor of matching shape/dtype; the
        # fake tensor's device is passed separately for the backend keys.
        my_fake_tensor = MyFakeTensor(
            torch.empty(t.shape, dtype=t.dtype, device="meta"),
            t.device,
        )
        my_fake_tensor._t = t

        return my_fake_tensor


@dynamo_disable
def fake_tensor(*args, **kwargs) -> Tensor:
    """Build a fake (or meta) tensor from an existing tensor or from size arguments.

    Keyword arguments read: ``dtype``, ``use_meta`` (default ``False``) and ``device``
    (default ``"meta"``, only used when building from sizes). With ``use_meta`` a plain
    meta tensor is returned, otherwise a ``MyFakeTensor``.
    """
    dtype: Optional[torch.dtype] = kwargs.get("dtype")
    use_meta = kwargs.get("use_meta", False)
    device = kwargs.get("device", "meta")

    if len(args) == 1 and isinstance(args[0], Tensor):
        source = args[0]
        if use_meta:
            return torch.empty(source.size(), dtype=dtype or source.dtype, device="meta")
        return MyFakeTensor.create(source)

    result = torch.empty(*args, dtype=dtype, device=device)
    if not use_meta:
        result = MyFakeTensor.create(result)
    return result


@dynamo_skip
def fake_tensor_like(t: Tensor, use_meta=False) -> Tensor:
    """Return the fake counterpart of ``t`` (see ``fake_tensor``)."""
    return fake_tensor(t, use_meta=use_meta)


T = TypeVar("T")


@dynamo_skip
def fake_tensors(value: T, use_meta=False) -> T:
    """Replace every tensor inside the pytree ``value`` with its fake counterpart."""
    return pytree.tree_map_only(Tensor, lambda t: fake_tensor_like(t, use_meta), value)


@dynamo_skip
def real_tensors(value: Any) -> Any:
    """Replace every fake tensor inside the pytree ``value`` with ``None``; keep real ones."""
    return pytree.tree_map_only(Tensor, lambda t: None if is_fake_tensor(t) else t, value)


@dynamo_skip
def is_fake_tensor(t: Any) -> bool:
    """Return True for ``MyFakeTensor``/``FakeTensor`` instances and for meta tensors."""
    return isinstance(t, (MyFakeTensor, FakeTensor)) or (isinstance(t, Tensor) and t.is_meta)


@dynamo_skip
def has_fake_tensor(v: Any) -> bool:
    """Return True when any leaf of the pytree ``v`` is a fake tensor."""
    return pytree.tree_any(is_fake_tensor, v)


def _get_device_for_distributed(
    group: Optional[torch.distributed.ProcessGroup] = None,
) -> str:
    """
    Determine the appropriate device for distributed communication based on the backend.
    NCCL backend requires CUDA tensors, while Gloo supports both CPU and CUDA.
    """
    if not torch.distributed.is_initialized():
        return "cpu"

    if torch.distributed.get_backend(group) == "nccl":
        # NCCL requires CUDA tensors; return a device string to honour the annotation
        # (``current_device()`` alone is an int index).
        return f"cuda:{torch.cuda.current_device()}"

    # Gloo and other backends support CPU tensors
    return "cpu"


def distributed_isend_obj(
    obj: Any,
    dst: int = 0,
    group: Optional[torch.distributed.ProcessGroup] = None,
) -> list[Optional[torch.distributed.Work]]:
    """Start an asynchronous send of ``obj`` to rank ``dst``.

    The object is serialized to a tensor; its size and its payload are sent as two
    separate ``isend`` calls, in that order. Returns the two work handles.
    """
    device = _get_device_for_distributed(group)
    obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor(
        obj, device=device, **_get_group_kwarg_if_necessary()
    )
    works: list[Optional[torch.distributed.Work]] = [
        torch.distributed.isend(obj_size_tensor, dst, group),
        torch.distributed.isend(obj_tensor, dst, group),
    ]
    return works


def distributed_send_obj(
    obj: Any,
    dst: int = 0,
    group: Optional[torch.distributed.ProcessGroup] = None,
):
    """Send ``obj`` to rank ``dst`` and block until both transfers complete."""
    for work in distributed_isend_obj(obj=obj, dst=dst, group=group):
        if work is not None:
            work.wait()


def distributed_recv_obj(
    src: Optional[int] = None,
    group: Optional[torch.distributed.ProcessGroup] = None,
) -> Any:
    """Receive an object sent with ``distributed_isend_obj`` / ``distributed_send_obj``.

    Reads the size tensor first, then a byte buffer of that size, and deserializes it.
    """
    device = _get_device_for_distributed(group)

    # Both buffers are fully overwritten by ``recv``, so uninitialized storage is fine.
    obj_size_tensor = torch.empty(1, dtype=torch.long, device=device)
    torch.distributed.recv(obj_size_tensor, src=src, group=group)
    obj_size = int(obj_size_tensor.item())

    obj_tensor = torch.empty(obj_size, dtype=torch.uint8, device=device)
    torch.distributed.recv(obj_tensor, src=src, group=group)

    return torch.distributed.distributed_c10d._tensor_to_object(
        obj_tensor, obj_size, **_get_group_kwarg_if_necessary()
    )


def _get_group_kwarg_if_necessary() -> dict:
    """For newer versions of torch: pass ``group=None`` when ``_object_to_tensor`` accepts it."""
    arg_names = inspect.signature(
        torch.distributed.distributed_c10d._object_to_tensor
    ).parameters.keys()
    return {"group": None} if "group" in arg_names else {}


# diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py
# new file mode 100644
# index 0000000000..3ea57bd7a7
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_params_and_memory.py
# @@ -0,0 +1,339 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

"""Calculate memory usage and parameter counts for neural network subblocks.

This module provides utilities to compute memory footprints and parameter counts
for different subblock types (FFN, Attention, Mamba, MoE) in large language models,
considering various data types, batch sizes, and sequence lengths.
"""

import copy
import json
import math
from pathlib import Path
from typing import Type

import numpy as np
import torch
from transformers import PretrainedConfig

from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import (
    AttentionConfig,
    BlockConfig,
    FFNConfig,
    MambaConfig,
    maybe_cast_block_configs,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import init_model_from_config
from modelopt.torch.puzzletron.utils.utils import (
    EmptyInitOnDevice,
    calculate_kv_dim,
    raise_unknown_subblock_config_error,
    sizeof_dtype,
)


def calculate_subblock_memory(
    subblock_config: FFNConfig | AttentionConfig,
    batch_size: int,
    prefill_seq_len: int,
    generation_seq_len: int,
    prefill_queue_size: int,
    n_embd: int,
    n_head: int,
    weights_dtype: torch.dtype,
    kv_cache_dtype: torch.dtype,
    allocate_prefill_query: bool,
    model_config: PretrainedConfig,
    descriptor: Type[ModelDescriptor],
) -> float | dict[str, float]:
    """``model_config`` / ``descriptor`` are required (puzzletron-style); FFN uses them for meta init.

    Dispatches on the subblock type: no-op subblocks cost ``0``; FFN and Mamba subblocks
    return a single MiB figure; regular attention returns a dict with ``memory_mib`` and
    ``kv_cache_memory_mib``. Any other config type is reported through
    ``raise_unknown_subblock_config_error``.
    """
    if subblock_config.no_op:
        return 0
    if isinstance(subblock_config, FFNConfig):
        return calculate_ffn_memory(
            subblock_config,
            model_config,
            descriptor,
            weights_dtype,
        )
    if isinstance(subblock_config, AttentionConfig):
        if subblock_config.is_mamba:
            return calculate_mamba_memory(
                subblock_config,
                model_config,
                descriptor,
                batch_size,
                weights_dtype,
                kv_cache_dtype,
            )
        return calculate_attention_memory(
            subblock_config,
            model_config,
            descriptor,
            batch_size,
            prefill_seq_len,
            generation_seq_len,
            prefill_queue_size,
            n_embd,
            n_head,
            weights_dtype,
            kv_cache_dtype,
            allocate_prefill_query,
        )
    raise_unknown_subblock_config_error(subblock_config)


def calculate_subblock_params(
    config: PretrainedConfig,
    layer_config: BlockConfig | FFNConfig | AttentionConfig,
    descriptor: Type[ModelDescriptor],
) -> int:
    """Count parameters on one meta decoder layer.

    Exactly one of FFN / attention may be active in ``layer_config``; the other must be
    missing or no-op. Returns ``0`` when both are no-op.

    Raises:
        AssertionError: both the FFN and the attention subblock are active.
    """
    if isinstance(layer_config, (FFNConfig, AttentionConfig)):
        block_config = layer_config.to_blockconfig()
    else:
        block_config = layer_config

    ffn = block_config.ffn
    attn = block_config.attention
    ffn_no_op = ffn is None or ffn.no_op
    attn_no_op = attn is None or attn.no_op
    if not (ffn_no_op or attn_no_op):
        raise AssertionError(
            "One of ffn or attention must be no-op for sublayer param calculation "
            "(single subblock at a time)."
        )
    if ffn_no_op and attn_no_op:
        return 0

    # Work on a copy so the caller's config is not mutated.
    _config = copy.deepcopy(config)
    lm_config = descriptor.get_language_model_config(_config)
    lm_config.num_hidden_layers = 1

    block_configs = maybe_cast_block_configs([block_config])
    _config.block_configs = block_configs
    if lm_config is not _config:
        lm_config.block_configs = block_configs

    # Replaced earlier pattern:
    #   with EmptyInitOnDevice("meta"), deci_x_patcher(..., block_configs=block_configs):
    #       model = init_model_from_config(_config, ...)
    #
    # That fails on GPT-OSS with recent Transformers: ``deci_x_patcher`` runs
    # ``attn_no_op_post_init`` / ``mlp_no_op_post_init`` inside ``DecoderLayer.__init__``, so norms
    # / attn / mlp are swapped for placeholders before ``GptOssModel.__init__`` finishes. At the end
    # of ``GptOssModel.__init__`` the stack calls ``self.post_init()`` — inherited from
    # ``PreTrainedModel`` — which then raises
    # ``ValueError`` (e.g. ``post_attention_layernorm`` in ``_keep_in_fp32_modules`` no longer matches
    # the tree). Below we merge per-layer fields manually, init without the patcher, then call the
    # same descriptor no-op hooks on the built layer (equivalent param count for
    # ``num_hidden_layers == 1``).

    # ``block_config_to_layer_overrides`` may include keys with value ``None``; we omit those so
    # ``lm_config.update`` does not overwrite existing fields with ``None`` (same rule as
    # ``override_config_with_block_configs`` inside ``deci_x_patcher``).
    layer_overrides = descriptor.block_config_to_layer_overrides(block_configs[0])
    lm_config.update({k: v for k, v in layer_overrides.items() if v is not None})

    with EmptyInitOnDevice("meta"):
        model = init_model_from_config(
            _config,
            trust_remote_code=descriptor.requires_trust_remote_code(),
        )

    decoder_layer = model.get_submodule(descriptor.layer_block_name(index=0))
    if attn_no_op:
        descriptor.attn_no_op_post_init(decoder_layer)
    if ffn_no_op:
        descriptor.mlp_no_op_post_init(decoder_layer)
    return sum(p.numel() for p in decoder_layer.parameters())


def calc_subblock_active_params(
    sublayer_config: FFNConfig | AttentionConfig,
    model_config: PretrainedConfig,
    descriptor: Type[ModelDescriptor],
    n_embd: int,
    moe_stats_file: str,
    batch_size: int,
    block_idx: int,
) -> int:
    """Return the active parameter count of a subblock.

    MoE FFN subblocks use the expert-usage estimate from ``moe_stats_file``; every other
    subblock simply reports its full parameter count.
    """
    if not (isinstance(sublayer_config, FFNConfig) and sublayer_config.is_moe):
        return calculate_subblock_params(model_config, sublayer_config, descriptor)
    return estimate_moe_active_params(
        sublayer_config, n_embd, moe_stats_file, batch_size, block_idx
    )


def load_moe_stats(stats_file: str) -> list:
    """Load per-entry expert statistics from JSON and normalize each entry to sum to 1.

    Empty entries are replaced by ``0``.
    """
    with open(stats_file) as f:
        stats = json.load(f)
    return [np.array(counts) / np.sum(counts) if len(counts) > 0 else 0 for counts in stats]


def estimate_num_active_experts(
    dist_over_experts: np.ndarray, batch_size: int, num_experts: int
) -> float:
    """Return the expected number of distinct experts hit by ``batch_size`` draws."""
    # cut the tail and renormalize
    dist_over_experts = np.sort(dist_over_experts)[::-1][:num_experts]
    dist_over_experts = dist_over_experts / (dist_over_experts.sum())
    # calculate the probability of at least one expert being active
    # (expectation on indicators is the expected number of active experts)
    return (1 - (1 - dist_over_experts) ** batch_size).sum()


def estimate_moe_active_params(
    subblock_config: FFNConfig,
    n_embd: int,
    moe_stats_file: Path | str,
    batch_size: int,
    block_idx: int,
) -> float:
    """Estimate the expected number of parameters touched by an MoE FFN subblock.

    Sums router parameters, the expected active routed-expert parameters and the shared
    expert parameters, assuming three linear layers per expert.

    Raises:
        FileNotFoundError: ``moe_stats_file`` does not exist.
    """
    if not Path(moe_stats_file).exists():
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise FileNotFoundError(f"MoE stats file not found: {moe_stats_file}")

    moe_stats = load_moe_stats(moe_stats_file)
    dist_over_experts = moe_stats[block_idx]
    num_experts = subblock_config.moe.num_local_experts

    expected_num_active_experts = estimate_num_active_experts(
        dist_over_experts, batch_size, num_experts
    )
    expert_dim = subblock_config.moe.expert_intermediate_dim
    shared_expert_dim = subblock_config.moe.shared_expert_intermediate_dim
    num_linear_layers = 3  # all moe experts have 3 linear layers

    router_num_params = n_embd * num_experts
    expected_num_active_experts_params = (
        num_linear_layers * expert_dim * n_embd * expected_num_active_experts
    )
    shared_expert_num_params = num_linear_layers * shared_expert_dim * n_embd

    return router_num_params + expected_num_active_experts_params + shared_expert_num_params


def calculate_attention_memory(
    attention_config: AttentionConfig,
    model_config: PretrainedConfig,
    descriptor: Type[ModelDescriptor],
    batch_size: int,
    prefill_seq_len: int,
    generation_seq_len: int,
    prefill_queue_size: int,
    n_embd: int,
    n_head: int,
    weights_dtype: torch.dtype,
    kv_cache_dtype: torch.dtype,
    allocate_prefill_query: bool,
) -> dict[str, float]:
    """allocate_prefill_query: infery-llm style.
    Infery used a unified Wqkv matrix, so before extracting the kv-cache,
    the query also had to be kept in-memory, once per layer.

    Returns a dict with the total ``memory_mib`` (KV cache + optional prefill query +
    weights) and the ``kv_cache_memory_mib`` share.
    """
    seq_len = prefill_seq_len + generation_seq_len
    if (
        attention_config.is_llama4
        and (attention_chunk_size := attention_config.llama4.attention_chunk_size) is not None
    ):
        # The sequence length used for the cache is capped at the chunk size.
        seq_len = min(seq_len, attention_chunk_size)

    kv_dim = calculate_kv_dim(attention_config.num_key_value_heads, n_head, n_embd)
    total_num_tokens = seq_len * (batch_size + prefill_queue_size)
    kv_cache_size = total_num_tokens * kv_dim
    query_prefill_size = seq_len * n_embd if allocate_prefill_query else 0
    num_params = calculate_subblock_params(model_config, attention_config, descriptor)
    total_memory = (
        kv_cache_size * sizeof_dtype(kv_cache_dtype)
        + query_prefill_size * sizeof_dtype(weights_dtype)
        + num_params * sizeof_dtype(weights_dtype)
    ) / 2**20
    kv_cache_memory = kv_cache_size * sizeof_dtype(kv_cache_dtype) / 2**20
    return {"memory_mib": total_memory, "kv_cache_memory_mib": kv_cache_memory}


def calculate_mamba_memory(
    attention_config: AttentionConfig,
    model_config: PretrainedConfig,
    descriptor: Type[ModelDescriptor],
    batch_size: int,
    weights_dtype: torch.dtype,
    kv_cache_dtype: torch.dtype,
) -> float:
    """Return the MiB used by a Mamba subblock: its weights plus its conv/SSM state."""
    assert attention_config.mamba is not None
    mamba_config = attention_config.mamba
    num_params = calculate_subblock_params(model_config, attention_config, descriptor)
    return (
        num_params * sizeof_dtype(weights_dtype)
        + calculate_mamba_state_size(mamba_config, batch_size) * sizeof_dtype(kv_cache_dtype)
    ) / 2**20


def calculate_mamba_state_size(
    mamba_config: MambaConfig,
    batch_size: int,
) -> int:
    """Return the element count of the conv state plus the SSM state for ``batch_size``."""
    _, _, conv_dim, kernel_size = _calculate_mamba_intermediates(mamba_config)
    conv_state_size = math.prod((batch_size, conv_dim, kernel_size))
    ssm_state_size = math.prod(
        (batch_size, mamba_config.num_heads, mamba_config.head_dim, mamba_config.state_dim)
    )
    return conv_state_size + ssm_state_size


def _calculate_mamba_intermediates(mamba_config: MambaConfig) -> tuple[int, ...]:
    """Return ``(d_inner, in_proj_dim, conv_dim, kernel_size)`` derived from ``mamba_config``."""
    d_inner = mamba_config.num_heads * mamba_config.head_dim
    in_proj_dim = (
        d_inner * 2 + 2 * mamba_config.num_groups * mamba_config.state_dim + mamba_config.num_heads
    )
    conv_dim = d_inner + 2 * mamba_config.num_groups * mamba_config.state_dim
    kernel_size = 4  # hard-coded here; not read from ``mamba_config``
    return d_inner, in_proj_dim, conv_dim, kernel_size


def calculate_ffn_memory(
    ffn_config: FFNConfig,
    model_config: PretrainedConfig,
    descriptor: Type[ModelDescriptor],
    weights_dtype: torch.dtype | str,
    experts_dtype: torch.dtype | str | None = None,
) -> float:
    """Return the MiB used by the FFN weights. ``experts_dtype`` is currently unused."""
    # TODO: How to separate between expert weights and the rest for any model (same as puzzletron).
    num_params = calculate_subblock_params(model_config, ffn_config, descriptor)
    return num_params * sizeof_dtype(weights_dtype) / 2**20


def calculate_non_block_memory(
    n_embd: int,
    vocab_size: int,
    weight_dtype: torch.dtype,
) -> float:
    """Return the MiB used by the parameters counted in ``calculate_non_block_params``."""
    return calculate_non_block_params(n_embd, vocab_size) * sizeof_dtype(weight_dtype) / 2**20


def calculate_non_block_params(
    n_embd: int,
    vocab_size: int,
) -> int:
    """Return ``2 * vocab_size * n_embd + n_embd``.

    NOTE(review): presumably input embedding + LM head + final norm -- confirm.
    """
    return vocab_size * n_embd * 2 + n_embd


# diff --git a/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py
# new file mode 100644
# index 0000000000..549d994f07
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/subblock_stats/calc_subblock_stats.py
# @@ -0,0 +1,559 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Calc subblock stats to compute memory and runtime statistics for subblocks.""" + +import dataclasses +import json +import os +import warnings +from functools import partial +from itertools import product +from pathlib import Path +from typing import Iterable, Optional, Type, TypeVar + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import pandas as pd +import torch +from immutabledict import immutabledict +from omegaconf import DictConfig, ListConfig, OmegaConf +from tqdm import tqdm +from transformers import PretrainedConfig + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, + SubblockConfig, +) +from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch.puzzletron.subblock_stats.calc_subblock_params_and_memory import ( + calc_subblock_active_params, + calculate_non_block_memory, + calculate_non_block_params, + calculate_subblock_memory, + calculate_subblock_params, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.parsing import format_global_config + +# Type variable for dataclasses +T_DataClass = TypeVar("T_DataClass") + +""" +Usage: +python -m 
modelopt.torch.puzzletron.subblock_stats.calc_subblock_stats PUZZLE_DIR [ --benchmark_iterations 1000 ] + +--benchmark_iterations=None (the default) means that the code won't use infery to benchmark runtime, + only memory stats will be calculated. If you want to benchmark runtime, run inside an infery-llm docker. + +""" + + +def calculate_subblock_stats( + calc_subblock_stats_config: DictConfig, + teacher_dir: Path, + model_config: PretrainedConfig, + descriptor: Type[ModelDescriptor], + master_puzzle_dir: Path, + subblock_configs: list[immutabledict[str, AttentionConfig | FFNConfig]], + batch_size: int, + prefill_seq_len: int, + generation_seq_len: int, + prefill_queue_size: int, + n_embd: int, + n_head: int, + vocab_size: int, + benchmark_iterations: Optional[int], + use_cuda_graph: bool, + weights_dtype: torch.dtype, + activations_dtype: torch.dtype, + kv_cache_dtype: torch.dtype, + allocate_prefill_query: bool, + moe_stats_file: str | Path | None = None, +) -> dict: + is_calc_runtime = benchmark_iterations is not None + if is_calc_runtime: + raise NotImplementedError("Runtime stats calculation is not implemented yet") + + gpu = None if not torch.cuda.is_available() else torch.cuda.get_device_name() + subblock_stats = { + "args": dict( + is_calc_runtime=is_calc_runtime, + gpu=gpu, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=n_embd, + n_head=n_head, + vocab_size=vocab_size, + benchmark_iterations=benchmark_iterations, + use_cuda_graph=use_cuda_graph, + weights_dtype=str(weights_dtype), + activations_dtype=str(activations_dtype), + kv_cache_dtype=str(kv_cache_dtype), + ), + "non_block": dict(), + "subblocks": list(), + } + # Compute runtime stats for unique subblocks only + if is_calc_runtime: + raise NotImplementedError("Runtime stats calculation is not implemented yet") + subblock_configs_nolayerindex = set( + [subblock_config["subblock_config"] for 
subblock_config in subblock_configs] + ) + + # dict[SubblockConfig, float], float + # TODO: Manage default values for calc_subblock_stats_config in one place, e.g. within a dataclass for hydra config. + synth_dataset_num_requests = calc_subblock_stats_config.get("runtime_stats", {}).get( + "synth_dataset_num_requests", 200 + ) + backend = calc_subblock_stats_config.get("runtime_stats", {}).get("backend", "trt_torch") + runtime_by_subblock_dict, non_block_runtime_ms = calc_runtime_ms_for_subblocks( + subblock_configs_nolayerindex, + vocab_size, + n_embd, + n_head, + master_puzzle_dir, + teacher_dir, + synth_dataset_num_requests, + backend, + ) + + sorted_subblock_config = sorted( + subblock_configs, key=lambda subblock_config: subblock_config["subblock_config"] + ) + it = ( + tqdm(sorted_subblock_config, desc="Measuring subblock runtimes") + if is_calc_runtime + else sorted_subblock_config + ) + for subblock_config_indexed in it: + subblock_config = subblock_config_indexed["subblock_config"] + parent_layer_indices = subblock_config_indexed["parent_layer_indices"] + + if is_calc_runtime: + total_runtime_ms = runtime_by_subblock_dict[subblock_config] + prefill_runtime_ms = None + decode_runtime_ms = None + else: + total_runtime_ms, prefill_runtime_ms, decode_runtime_ms = None, None, None + + subblock_memory = calculate_subblock_memory( + subblock_config, + batch_size, + prefill_seq_len, + generation_seq_len, + prefill_queue_size, + n_embd, + n_head, + weights_dtype, + kv_cache_dtype, + allocate_prefill_query, + model_config=model_config, + descriptor=descriptor, + ) + if not isinstance(subblock_memory, dict): + subblock_memory = {"memory_mib": subblock_memory, "kv_cache_memory_mib": 0.0} + + subblock_params = calculate_subblock_params(model_config, subblock_config, descriptor) + if moe_stats_file is not None: + subblock_active_params = calc_subblock_active_params( + subblock_config, + model_config, + descriptor, + n_embd, + moe_stats_file, + batch_size, + 
parent_layer_indices[0], + ) + else: + subblock_active_params = subblock_params + subblock_stats["subblocks"].append( + { + "subblock_config": subblock_config, + "subblock_config_class": type(subblock_config).__name__, + "runtime_ms": total_runtime_ms, + "prefill_runtime_ms": prefill_runtime_ms, + "decode_runtime_ms": decode_runtime_ms, + "num_params": subblock_params, + "active_params": subblock_active_params, + "parent_layer_index": parent_layer_indices[0], + **subblock_memory, + } + ) + + if is_calc_runtime: + # TODO: fix + # from puzzle_tools.calc_subblock_runtime import measure_non_block_runtime_ms + # non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = \ + # measure_non_block_runtime_ms(batch_size, prefill_seq_len, generation_seq_len, n_embd, vocab_size, + # benchmark_iterations, use_cuda_graph) + embedding_runtime_ms, lm_head_runtime_ms = None, None + else: + non_block_runtime_ms, embedding_runtime_ms, lm_head_runtime_ms = None, None, None + non_block_memory = calculate_non_block_memory(n_embd, vocab_size, weights_dtype) + non_block_params = calculate_non_block_params(n_embd, vocab_size) + + # TODO + # the semantics here is wrong why do we refer, prefill_runtime_ms as embedding_runtime_ms and lm_head_runtime_ms as decode_runtime_ms ? + # Prefill is the first the user prompt inference, and Decode refer to the next generation process. both processes use all the model layers. + subblock_stats["non_block"] = { + "runtime_ms": non_block_runtime_ms, + "prefill_runtime_ms": embedding_runtime_ms, + "decode_runtime_ms": lm_head_runtime_ms, + "memory_mib": non_block_memory, + "num_params": non_block_params, + } + return subblock_stats + + +def launch_calc_subblock_stats(cfg: DictConfig) -> None: + """ + Launch the calc subblock stats function with Hydra configuration. 
+ """ + mprint(f"Calculating subblock stats for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Calc subblock stats config: {format_global_config(cfg.calc_subblock_stats, title='Calc subblock stats')}" + ) + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + calculate_subblock_stats_for_puzzle_dir( + cfg.calc_subblock_stats, + master_puzzle_dir=cfg.puzzle_dir, + teacher_dir=cfg.teacher_dir, + descriptor=descriptor, + model_hidden_sizes=cfg.calc_subblock_stats.get("model_hidden_sizes", OmegaConf.create([])), + ffn_hidden_sizes=cfg.calc_subblock_stats.get("ffn_hidden_sizes", OmegaConf.create([])), + batch_sizes=cfg.calc_subblock_stats.batch_sizes, + prefill_seq_len=cfg.calc_subblock_stats.prefill_seq_len, + generation_seq_len=cfg.calc_subblock_stats.generation_seq_len, + num_active_tokens_override=cfg.calc_subblock_stats.get("num_active_tokens_override", None), + prefill_queue_size=cfg.calc_subblock_stats.prefill_queue_size, + allocate_prefill_query=cfg.calc_subblock_stats.get("allocate_prefill_query", False), + benchmark_iterations=cfg.calc_subblock_stats.get("benchmark_iterations", None), + merge_with_existing_stats=cfg.calc_subblock_stats.merge_with_existing_stats, + subblock_stats_filename=cfg.calc_subblock_stats.subblock_stats_filename, + moe_stats_filename=cfg.calc_subblock_stats.moe_stats_filename, + ) + + +def calculate_subblock_stats_for_puzzle_dir( + calc_subblock_stats_config: DictConfig, + master_puzzle_dir: Path | str, + teacher_dir: Path | str, + descriptor: Type[ModelDescriptor], + model_hidden_sizes: ListConfig, + ffn_hidden_sizes: ListConfig, + batch_sizes: Iterable[int] = (1, 8, 16, 32, 64, 128, 256), + prefill_seq_len: int = 2048, + generation_seq_len: int = 2048, + num_active_tokens_override: int | None = None, + prefill_queue_size: int = 0, # it's an infery-llm thing + allocate_prefill_query: bool = False, + benchmark_iterations: ( + int | None + ) = None, # If set then compute runtime 
performance statistics. TODO: recommend default value, is 1000 good? + merge_with_existing_stats: bool = False, + subblock_stats_filename: str = "subblock_stats.json", + moe_stats_filename: str = "moe_stats.json", +) -> None: + # ==== START === Setup for attach-helper ==== + # import sys + # import os + # sys.path.insert(0, os.environ["ATTACH_HELPER_INSTALLATION_PATH"]) + # from attach_helper import debugging_setup + # debugging_setup() # You can optionally pass a name to identify the job (e.g. `debugging_setup(name="my_script")`) + # ==== END === Setup for attach-helper ==== + if isinstance(batch_sizes, str): + batch_sizes = [ + int(batch_size) for batch_size in batch_sizes.strip("[]").replace(" ", "").split(",") + ] + + master_puzzle_dir = Path(master_puzzle_dir) + teacher_dir = ( + Path(teacher_dir) if teacher_dir is not None else master_puzzle_dir / "ckpts" / "teacher" + ) + trust_remote_code = descriptor.requires_trust_remote_code() + model_config = load_model_config(teacher_dir, trust_remote_code=trust_remote_code) + # Get language model config for LM-specific attributes (VL models have nested config) + lm_config = descriptor.get_language_model_config(model_config) + subblock_configs = _load_subblock_configs(master_puzzle_dir, ffn_hidden_sizes) + + subblock_stats_file = master_puzzle_dir / subblock_stats_filename + if subblock_stats_file.exists() and not merge_with_existing_stats: + raise ValueError( + f"Subblock stats file {subblock_stats_file} already exists and `merge_with_existing_stats` was set to False." 
+ ) + + if subblock_stats_file.exists(): + with open(subblock_stats_file) as f: + subblock_stats = json.load(f) + else: + subblock_stats = [] + + moe_stats_file = master_puzzle_dir / moe_stats_filename + if not moe_stats_file.exists(): + warnings.warn( + f"MOE stats file {moe_stats_file} does not exist, can't calculate num active params" + ) + moe_stats_file = None + + subblock_stats_args = {immutabledict(x["args"]) for x in subblock_stats} + + data_types = [ + ("nvfp4", "nvfp4", "nvfp4"), + (torch.int8, torch.int8, torch.int8), + (torch.int8, torch.int8, torch.bfloat16), + (torch.bfloat16, torch.bfloat16, torch.bfloat16), + ] + + model_hidden_sizes = model_hidden_sizes + [ + lm_config.hidden_size + ] # add a teacher model hidden size + for batch_size, ( + weights_dtype, + activations_dtype, + kv_cache_dtype, + ), model_hidden_size in product(batch_sizes, data_types, model_hidden_sizes): + if num_active_tokens_override is not None: + prefill_seq_len = generation_seq_len = int(num_active_tokens_override / batch_size / 2) + + curr_benchmark_iterations = ( + benchmark_iterations if weights_dtype == torch.bfloat16 else None + ) + + curr_subblock_stats = calculate_subblock_stats( + calc_subblock_stats_config, + teacher_dir=teacher_dir, + model_config=model_config, + descriptor=descriptor, + master_puzzle_dir=master_puzzle_dir, + subblock_configs=subblock_configs, + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + generation_seq_len=generation_seq_len, + prefill_queue_size=prefill_queue_size, + n_embd=model_hidden_size, + n_head=lm_config.num_attention_heads, + vocab_size=lm_config.vocab_size, + benchmark_iterations=curr_benchmark_iterations, + use_cuda_graph=True, + weights_dtype=weights_dtype, + activations_dtype=activations_dtype, + kv_cache_dtype=kv_cache_dtype, + allocate_prefill_query=allocate_prefill_query, + moe_stats_file=moe_stats_file, + ) + + if immutabledict(curr_subblock_stats["args"]) in subblock_stats_args: + raise ValueError( + f"Failed 
merging subblock_stats. The following arguments already existed in the file: {curr_subblock_stats['args']}" + ) + + subblock_stats.append(curr_subblock_stats) + + # TODO fix: add_int8_runtime_estimates(subblock_stats) + + json_dump(subblock_stats, subblock_stats_file) + + mprint(subblock_stats_file) + + +def _load_subblock_configs( + master_puzzle_dir: Path, ffn_hidden_sizes: ListConfig +) -> list[SubblockConfig]: + try: + subblock_configs = _load_subblock_configs_from_replacement_library(master_puzzle_dir) + except FileNotFoundError: + subblock_configs = _load_subblock_configs_from_subblock_library(master_puzzle_dir) + + # Extend subblock stats calculation space with ffn_hidden_sizes defined in the calc_subblock_stats section of the model config yaml file. + extra_ffn_subblock_configs = [] + for ffn_hidden_size in ffn_hidden_sizes: + # Use FFNConfig defaults (hidden_act will use its default value) + ffn_config = FFNConfig(intermediate_size=ffn_hidden_size) + extra_ffn_subblock_configs.append( + immutabledict({"subblock_config": ffn_config, "parent_layer_indices": tuple([-1])}) + ) # -1 to indicate that this sublock has no parent layer + subblock_configs.extend(extra_ffn_subblock_configs) + + return subblock_configs + + +def _load_subblock_configs_from_subblock_library(master_puzzle_dir: Path) -> list[SubblockConfig]: + subblocks_df = pd.read_json(master_puzzle_dir / "subblock_library.json") + subblocks_df["attention_config"] = subblocks_df["attention_config"].apply( + partial(_dataclass_from_dict, cls=AttentionConfig) + ) + subblocks_df["ffn_config"] = subblocks_df["ffn_config"].apply( + partial(_dataclass_from_dict, cls=FFNConfig) + ) + attention_configs = subblocks_df["attention_config"].dropna().drop_duplicates().tolist() + ffn_configs = subblocks_df["ffn_config"].dropna().drop_duplicates().tolist() + subblock_configs = attention_configs + ffn_configs + return subblock_configs + + +def _load_subblock_configs_from_replacement_library( + master_puzzle_dir: Path, 
+) -> list[SubblockConfig]: + """Load unique subblocks from replacement_library.json, e.g., + 256 = 32*8 unique sublocks will be returned for a model with 32 layers and the search space of + 4 intermediate_size + teacher_intermediate_size + ffn_noop + att_op (teacher) + att_noop. + + Args: + master_puzzle_dir (Path): Directory with "replacement_library.json" file + + Returns: + list[SubblockConfig]: + """ + replacement_library = json.loads((master_puzzle_dir / "replacement_library.json").read_text()) + subblock_configs = set() + for layer_replacement in replacement_library: + layer_replacement = parse_layer_replacement(layer_replacement) + + for block_config in layer_replacement["child_block_configs"]: + block_config: BlockConfig + attention_frozen_dict = immutabledict( + { + "subblock_config": block_config.attention, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": block_config.ffn, + "parent_layer_indices": tuple(layer_replacement["parent_layer_indices"]), + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + if block_config.parallel_blocks is not None: + for block_idx, internal_block_config in enumerate(block_config.parallel_blocks): + attention_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.attention, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + ffn_frozen_dict = immutabledict( + { + "subblock_config": internal_block_config.ffn, + "parent_layer_indices": tuple( + layer_replacement["parent_layer_indices"] + ), + "inner_block_idx": block_idx, + } + ) + subblock_configs.add(attention_frozen_dict) + subblock_configs.add(ffn_frozen_dict) + + subblock_configs = list(subblock_configs) + return subblock_configs + + +T_DataClass: TypeVar = Type[dataclasses.dataclass] + + +def _dataclass_from_dict( + d: dict | T_DataClass | 
None, + cls: T_DataClass, +) -> T_DataClass | None: + if isinstance(d, cls): + return d + if isinstance(d, dict): + return cls(**d) + if pd.isna(d): + return None + raise ValueError(f"_dataclass_from_dict: unrecognized {type(d)=} {d=}") + + +def add_int8_runtime_estimates(subblock_stats: list[dict]) -> None: + for curr_subblock_stats in subblock_stats: + args = curr_subblock_stats["args"] + if args["weights_dtype"] == "torch.int8": + assert args["activations_dtype"] == "torch.int8" + ffn_factor = 0.5 + attention_factor = 0.5 if args["kv_cache_dtype"] == "torch.int8" else 0.8 + + bf16_stats = _find_corresponding_bf16_stats(args, subblock_stats) + if bf16_stats is not None: + curr_subblocks = curr_subblock_stats["subblocks"] + [ + curr_subblock_stats["non_block"] + ] + bf16_subblocks = bf16_stats["subblocks"] + [bf16_stats["non_block"]] + for curr_subblock, bf16_subblock in zip(curr_subblocks, bf16_subblocks): + assert curr_subblock.get("subblock_config", None) == bf16_subblock.get( + "subblock_config", None + ) + is_attention = False + if (subblock_config := curr_subblock.get("subblock_config")) is not None: + if hasattr(subblock_config, "__dataclass_fields__"): + subblock_config = dataclasses.asdict(subblock_config) + is_attention = subblock_config.get("num_key_value_heads", None) is not None + runtime_factor = attention_factor if is_attention else ffn_factor + for stat_name, stat_value in bf16_subblock.items(): + if "runtime" in stat_name: + curr_subblock[stat_name] = stat_value * runtime_factor + + +def _find_corresponding_bf16_stats(args: dict, subblock_stats: list[dict]) -> dict | None: + scenario_keys = [ + "batch_size", + "prefill_seq_len", + "generation_seq_len", + "prefill_queue_size", + "gpu", + "n_embd", + "n_head", + "vocab_size", + ] + corresponding_bf16_args = { + **{k: v for k, v in args.items() if k in scenario_keys}, + "is_calc_runtime": True, + "weights_dtype": "torch.bfloat16", + "activations_dtype": "torch.bfloat16", + "kv_cache_dtype": 
"torch.bfloat16", + } + matching_bf16_stats = [ + stats + for stats in subblock_stats + if all( + [ + stats["args"][key] == corresponding_bf16_args[key] + for key in corresponding_bf16_args.keys() + ] + ) + ] + if len(matching_bf16_stats) == 0: + return None + if len(matching_bf16_stats) == 1: + return matching_bf16_stats[0] + raise ValueError(f"Found more than 1 matching bf16 stats for {args=}") diff --git a/modelopt/torch/puzzletron/tools/__init__.py b/modelopt/torch/puzzletron/tools/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py new file mode 100644 index 0000000000..6b98d36a0e --- /dev/null +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -0,0 +1,1181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Core logic for creating pruned child model state dicts from parent models. Used by init_child_from_parent.""" + +import concurrent.futures +import dataclasses +import json +import os +import re +import time +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import PretrainedConfig +from typeguard import check_type + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import ( + SUBBLOCK_CLS_DICT, + BlockConfig, + _get_dataclass_type, + _is_dataclass_type, +) +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + ACTIVATIONS_LOG, + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + _cache_activations_log, + _init_attention_biases, + _init_attention_weights, + _init_mlp_module, + _init_moe_module, + _load_activations_log, + _load_expert_scores, + _select_expert_indices, +) +from modelopt.torch.puzzletron.tools.logger import aprint, mprint + +IgnoreFn = Callable[[str], bool] + +default_ignore_fn: IgnoreFn = lambda _: False + + +class Printer: + @staticmethod + def print(s: str) -> None: + print(s) + + +def _process_single_layer( + layer_idx: int, + pruning_mixin, + descriptor, + parent_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config: Optional[dict[str, Any]], + linear_init_mode: 
LinearInitMode, + ignored_keys: set, + keys: dict, + is_original_mha: bool, + head_size: int, + hidden_size: int, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]: + """ + Process a single layer in parallel. Returns (layer_state_dict, keys_to_remove). + Thread-safe function for parallel layer processing. + """ + keys_to_remove = {} + layer_out_state_dict = {} + + # Delegate to pruning_mixin if available + if pruning_mixin is not None: + _layer_out = pruning_mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + return layer_out_state_dict, keys_to_remove + + # Legacy inline processing (fallback when no pruning_mixin) + + parent_block_config = original_config.block_configs[layer_idx] + child_block_config = new_config.block_configs[layer_idx] + + # Attention processing + for part in ["weight", "bias"]: + attn_prefix = f"model.layers.{layer_idx}.self_attn" + q_key = f"{attn_prefix}.q_proj.{part}" + k_key = f"{attn_prefix}.k_proj.{part}" + v_key = f"{attn_prefix}.v_proj.{part}" + o_key = f"{attn_prefix}.o_proj.{part}" + attn_keys = [q_key, k_key, v_key, o_key] + # Drop attn keys that don't exist and required to be in the new state_dict + attn_keys = [key for key in attn_keys if key in new_state_dict.keys()] + if len(attn_keys) > 0 and all(key in keys for key in attn_keys): + for key in attn_keys: + keys_to_remove[key] = keys[key] + if all(key not in ignored_keys for key in attn_keys): + is_student_and_teacher_have_same_attention_implementation = all( + key in new_state_dict.keys() for key in attn_keys + ) 
+ if is_student_and_teacher_have_same_attention_implementation: + if part == "weight": + wq, wk, wv, wo = _init_attention_weights( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + layer_out_state_dict[q_key], layer_out_state_dict[k_key] = wq, wk + layer_out_state_dict[v_key], layer_out_state_dict[o_key] = wv, wo + else: + bias_sd = _init_attention_biases( + gqa_init_mode=gqa_init_mode, + layer_idx=layer_idx, + new_state_dict=new_state_dict, + new_config=new_config, + original_state_dict=parent_state_dict, + original_config=original_config, + q_key=q_key, + k_key=k_key, + v_key=v_key, + o_key=o_key, + is_original_mha=is_original_mha, + head_size=head_size, + mlp_init_config=mlp_init_config, + ) + for bias_key, sd_key in zip("qkvo", [q_key, k_key, v_key, o_key]): + if bias_key in bias_sd.keys(): + layer_out_state_dict[sd_key] = bias_sd[bias_key] + + else: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + is_student_attn_replaced_with_linear = linear_attn_key in new_state_dict.keys() + if is_student_attn_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_attn_key] = new_state_dict[linear_attn_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + layer_out_state_dict[linear_attn_key] = _init_linear_attn( + parent_state_dict, original_config, layer_idx, v_key, o_key + ) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + # student attn random init + for new_key in new_state_dict.keys(): + if attn_prefix in new_key: + layer_out_state_dict[new_key] = new_state_dict[new_key] + + # MLP/MoE processing + is_parent_moe = parent_block_config.ffn.is_moe + if not is_parent_moe: # not MoE, init the MLP + 
mlp_prefix = f"model.layers.{layer_idx}.mlp" + linear_mlp_key = f"{mlp_prefix}.linear_mlp.weight" + + is_student_mlp_replaced_with_linear = linear_mlp_key in new_state_dict.keys() + if is_student_mlp_replaced_with_linear: + if linear_init_mode == LinearInitMode.Random: + layer_out_state_dict[linear_mlp_key] = new_state_dict[linear_mlp_key] + elif linear_init_mode == LinearInitMode.FromTeacher: + teacher_mlp_state_dict = { + k.split(mlp_prefix + ".")[1]: v + for k, v in parent_state_dict.items() + if mlp_prefix in k + } + layer_out_state_dict[linear_mlp_key] = _init_linear_mlp(teacher_mlp_state_dict) + else: + raise ValueError(f"Unknown {linear_init_mode=}") + else: + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + ) + ) + else: + is_child_moe = child_block_config.ffn.is_moe + if is_child_moe: + parent_moe_config = original_config.block_configs[layer_idx].ffn.moe + child_moe_config = new_config.block_configs[layer_idx].ffn.moe + if parent_moe_config == child_moe_config: + pass # copy the MoE as is + elif mlp_init_mode == MlpInitMode.MoEChannelPruning: + for expert_idx in range(parent_moe_config.num_local_experts): + layer_out_state_dict.update( + _init_mlp( + mlp_init_mode=mlp_init_mode, + layer_idx=layer_idx, + original_config=original_config, + mlp_init_config=mlp_init_config, + original_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + new_config=new_config, + keys=keys, + ignored_keys=ignored_keys, + expert_idx=expert_idx, + ) + ) + + elif mlp_init_mode == MlpInitMode.ExpertRemoval: # remove some of the routed experts + router_key, new_experts_keys = _generate_moe_keys( + layer_idx, child_block_config.ffn.moe.num_local_experts + ) + _, orig_experts_keys = _generate_moe_keys( + layer_idx, 
parent_block_config.ffn.moe.num_local_experts + ) + keys_to_remove[router_key] = keys.get(router_key) + for key in sum(orig_experts_keys.values(), []): + keys_to_remove[key] = keys.get(key) + + orig_experts_weights = { + name: [parent_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + new_experts_weights = { + name: [new_state_dict[key] for key in new_experts_module_keys] + for name, new_experts_module_keys in new_experts_keys.items() + } + out_router_weights, out_experts_weights = _init_moe_module( + layer_idx=layer_idx, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + orig_router_weight=parent_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_router_weight=new_state_dict[router_key], + new_experts_weights=new_experts_weights, + ) + layer_out_state_dict[router_key] = out_router_weights + for name in new_experts_keys.keys(): + layer_out_state_dict.update( + zip(new_experts_keys[name], out_experts_weights[name]) + ) + elif child_block_config.ffn.no_op: # no-op, drop this layer + parent_mlp_prefix = f"model.layers.{layer_idx}.mlp" + for key in list(keys.keys()): + if key.startswith(parent_mlp_prefix): + keys_to_remove[key] = keys[key] + else: + assert mlp_init_mode == MlpInitMode.ConcatExpertsIntoDenseFFN, ( + "The parent layer is MoE and the child layer is a normal FFN. The only supported mode is ConcatExpertsAsMLP." 
+ ) + + child_ffn_state_dict = _concatenate_experts_into_dense_ffn( + parent_state_dict, + mlp_init_config, + hidden_size, + layer_idx, + child_block_config, + parent_block_config, + ) + layer_out_state_dict.update(child_ffn_state_dict) + + for key in list(keys.keys()): + if key.startswith(f"model.layers.{layer_idx}.mlp"): + keys_to_remove[key] = keys[key] + + # Handle missing keys + for key_possibly_missing_in_student in [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.up_proj", + "mlp.down_proj", + "input_layernorm", + "post_attention_layernorm", + ]: + key_possibly_missing_in_student = f".{layer_idx}.{key_possibly_missing_in_student}" + is_key_missing_from_student = ( + len([k for k in new_state_dict.keys() if key_possibly_missing_in_student in k]) == 0 + ) + if is_key_missing_from_student: + for k in list(keys.keys()): + if key_possibly_missing_in_student in k: + keys_to_remove[k] = keys[k] + + return layer_out_state_dict, keys_to_remove + + +@torch.no_grad() +def create_child_state_dict( + pruning_mixin, + descriptor, + original_state_dict: dict, + new_state_dict: dict, + original_config: PretrainedConfig, + new_config: PretrainedConfig, + gqa_init_mode: GQAInitMode, + ignore_fn: IgnoreFn = default_ignore_fn, + mlp_init_mode: MlpInitMode = MlpInitMode.CopyAsIs, + mlp_init_config: Optional[dict[str, Any]] = None, + owned_block_indexes: Optional[set[int]] = None, + linear_init_mode: LinearInitMode = LinearInitMode.Random, + hidden_size_init_mode: HiddenSizeInitMode = HiddenSizeInitMode.CopyAsIs, + channel_importance_path: Optional[str] = None, + max_layer_workers: Optional[int] = None, # Now optional - will auto-calculate if None +): + mprint("=== Starting create_child_state_dict with optimizations ===") + total_start_time = time.time() + + # Phase 1: Initial setup and validation + setup_start_time = time.time() + if owned_block_indexes is None: + owned_block_indexes = 
set(range(new_config.num_hidden_layers)) + + # Auto-calculate optimal layer workers: min(cpu_count, num_layers) + if max_layer_workers is None: + cpu_count = os.cpu_count() or 1 + num_layers = len(owned_block_indexes) + max_layer_workers = min(cpu_count, num_layers) + mprint( + f"Auto-calculated layer workers: min({cpu_count} CPUs, {num_layers} layers) = {max_layer_workers}" + ) + else: + mprint(f"Using specified layer workers: {max_layer_workers}") + + # Memory optimization: Pre-allocate output state dict with known shapes + expected_keys_and_shapes = {k: v.shape for k, v in new_state_dict.items()} + out_state_dict = {} + + # Pre-allocate tensors where possible to reduce memory fragmentation + for key, shape in expected_keys_and_shapes.items(): + if key in new_state_dict: + tensor = new_state_dict[key] + # Only make contiguous if necessary (memory optimization) + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + + # Get language model config for LM-specific attributes (VL models have nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + # Check if original model is MHA (all layers have num_key_value_heads == num_attention_heads) + original_num_kv_heads_per_layer = [ + b.attention.num_key_value_heads for b in original_config.block_configs + ] + num_attention_heads = original_lm_config.num_attention_heads + is_original_mha = all(kv == num_attention_heads for kv in original_num_kv_heads_per_layer) + is_same_hidden_size = original_lm_config.hidden_size == new_lm_config.hidden_size + head_size = _get_head_dim(new_lm_config) + orig_head_size = _get_head_dim(original_lm_config) + assert head_size == orig_head_size, f"head_size {head_size} != orig_head_size {orig_head_size}" + + # Allow different hidden sizes for pruning + if not is_same_hidden_size: + assert new_lm_config.hidden_size <= 
original_lm_config.hidden_size, ( + f"New hidden size ({new_lm_config.hidden_size}) must be <= original ({original_lm_config.hidden_size})" + ) + assert hidden_size_init_mode != HiddenSizeInitMode.CopyAsIs, ( + "Cannot copy as is when hidden sizes differ" + ) + + hidden_size = original_lm_config.hidden_size + + ignored_keys = set([key for key in original_state_dict.keys() if ignore_fn(key)]) + for key in ignored_keys: + aprint(f"Ignoring key {key} and taking its init from new_state_dict") + out_state_dict[key] = new_state_dict[key] + + keys = { + match.group(1) if (match := re.search(r"(h\.\d+\..*)", key)) is not None else key: key + for key in original_state_dict.keys() + } + setup_time = time.time() - setup_start_time + mprint(f"Phase 1 - Setup and memory pre-allocation: {setup_time:.2f}s") + + # Phase 2: Parallel layer processing + layer_processing_start_time = time.time() + + # Prepare arguments for parallel processing + process_layer_partial = partial( + _process_single_layer, + pruning_mixin=pruning_mixin, + descriptor=descriptor, + parent_state_dict=original_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + ) + + # Process layers in parallel with optimal worker count + mprint( + f"Processing {len(owned_block_indexes)} layers in parallel with {max_layer_workers} workers..." 
+ ) + layer_results = [] + all_keys_to_remove = {} + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_layer_workers) as executor: + future_to_layer = { + executor.submit(process_layer_partial, layer_idx): layer_idx + for layer_idx in owned_block_indexes + } + + completed = 0 + for future in concurrent.futures.as_completed(future_to_layer): + layer_idx = future_to_layer[future] + try: + layer_state_dict, keys_to_remove = future.result() + layer_results.append((layer_idx, layer_state_dict)) + all_keys_to_remove.update(keys_to_remove) + + completed += 1 + if completed % 20 == 0 or completed == len( + owned_block_indexes + ): # More frequent progress updates + mprint(f"Completed {completed}/{len(owned_block_indexes)} layers") + except Exception as exc: + mprint(f"Layer {layer_idx} generated an exception: {exc}") + raise exc + + # Merge layer results into main state dict (memory efficient) + for layer_idx, layer_state_dict in layer_results: + out_state_dict.update(layer_state_dict) + + # Remove processed keys from the keys dict + for key_to_remove in all_keys_to_remove: + keys.pop(key_to_remove, None) + + layer_processing_time = time.time() - layer_processing_start_time + mprint( + f"Phase 2 - Parallel layer processing: {layer_processing_time:.2f}s ({max_layer_workers} workers)" + ) + + # Phase 3: Copy remaining keys from original model + copy_start_time = time.time() + keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys + for key in keys_to_copy_from_orig_model: + # Memory optimization: avoid unnecessary copies + tensor = original_state_dict[key] + if not tensor.is_contiguous(): + out_state_dict[key] = tensor.contiguous() + else: + out_state_dict[key] = tensor + copy_time = time.time() - copy_start_time + mprint( + f"Phase 3 - Copy remaining keys: {copy_time:.2f}s ({len(keys_to_copy_from_orig_model)} keys)" + ) + + # Handle hidden size pruning for remaining keys + if not is_same_hidden_size: + out_state_dict = _apply_hidden_size_pruning( + 
out_state_dict, + original_state_dict, + new_config, + original_config, + descriptor, + hidden_size_init_mode, + channel_importance_path, + owned_block_indexes, + ) + + # Phase 4: Verification + verify_start_time = time.time() + _verify_state_dicts_match(out_state_dict, expected_keys_and_shapes) + verify_time = time.time() - verify_start_time + mprint(f"Phase 4 - Verification: {verify_time:.2f}s") + + total_time = time.time() - total_start_time + mprint(f"=== create_child_state_dict completed in {total_time:.2f}s ===") + mprint( + f"Breakdown: Setup {setup_time:.1f}s + ParallelProcessing {layer_processing_time:.1f}s + Copy {copy_time:.1f}s + Verify {verify_time:.1f}s" + ) + mprint( + f"Speedup: Used {max_layer_workers} workers for {len(owned_block_indexes)} layers (CPU utilization: {max_layer_workers}/{os.cpu_count() or 1})" + ) + + return out_state_dict + + +def _generate_moe_keys(layer_idx: int, num_experts: int) -> tuple[str, dict[str, list[str]]]: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + router_key = f"{mlp_prefix}.router.weight" + names = ["gate_proj", "up_proj", "down_proj"] + experts_module_names = { + name: f"{mlp_prefix}.experts.{{expert_idx}}.{name}.weight" for name in names + } + return router_key, { + name: [module_name.format(expert_idx=expert_idx) for expert_idx in range(num_experts)] + for name, module_name in experts_module_names.items() + } + + +def _concatenate_experts_into_dense_ffn( + original_state_dict: dict[str, torch.Tensor], + mlp_init_config: Optional[dict], + hidden_size: int, + layer_idx: int, + child_block_config: BlockConfig, + parent_block_config: BlockConfig, +) -> dict[str, torch.Tensor]: + assert child_block_config.ffn.gated and child_block_config.ffn.hidden_act == "silu", ( + "Llama4 experts use SwiGLU." 
+ ) + + # verify sizes + child_intermediate_size = child_block_config.ffn.intermediate_size + parent_moe_config = parent_block_config.ffn.moe + shared_expert_intermediate_dim = parent_moe_config.shared_expert_intermediate_dim + routed_expert_intermediate_dim = parent_moe_config.expert_intermediate_dim + total_concatenated_routed_experts_size = ( + child_intermediate_size - shared_expert_intermediate_dim + ) + assert total_concatenated_routed_experts_size % routed_expert_intermediate_dim == 0, ( + f"{child_intermediate_size=} " + f"{shared_expert_intermediate_dim=} " + f"{routed_expert_intermediate_dim=} " + f"{total_concatenated_routed_experts_size=} " + f"{total_concatenated_routed_experts_size % routed_expert_intermediate_dim=} != 0" + ) + num_concatenated_routed_experts = ( + total_concatenated_routed_experts_size // routed_expert_intermediate_dim + ) + + # if needed, concatenate some of the routed experts + if num_concatenated_routed_experts == 0: + print( + f"Removing all routed experts from layer {layer_idx}, turning the shared expert into a dense FFN." 
+ ) + concat_routed_state_dict = dict() + else: + print( + f"Concatenating {num_concatenated_routed_experts} routed experts to the shared expert in layer {layer_idx}" + ) + router_key, orig_experts_keys = _generate_moe_keys( + layer_idx, parent_moe_config.num_local_experts + ) + orig_experts_weights = { + name: [original_state_dict[key] for key in orig_experts_module_keys] + for name, orig_experts_module_keys in orig_experts_keys.items() + } + _, experts_weights = _prune_experts_by_score( + mlp_init_config=mlp_init_config, + layer_idx=layer_idx, + orig_router_weight=original_state_dict[router_key], + orig_experts_weights=orig_experts_weights, + new_num_experts=num_concatenated_routed_experts, + ) + concat_dims = {"gate_proj": 0, "up_proj": 0, "down_proj": 1} + assert list(concat_dims) == list(experts_weights), ( + "concat_dims and experts_weights must have the same keys" + ) + concat_routed_state_dict = { + name: torch.cat(experts_weights[name], dim=concat_dims[name]) + for name in concat_dims.keys() + } + + # turn the shared expert into a normal FFN. concatenate the pruned routed experts if needed. 
+ parent_shared_expert_prefix = f"model.layers.{layer_idx}.mlp.shared_expert" + child_ffn_prefix = f"model.layers.{layer_idx}.mlp" + child_ffn_state_dict = dict() + + for module_name in [ + "gate_proj", + "up_proj", + "down_proj", + ]: + shared_expert_key = f"{parent_shared_expert_prefix}.{module_name}.weight" + child_ffn_key = f"{child_ffn_prefix}.{module_name}.weight" + shared_expert_weight = original_state_dict[shared_expert_key] + concat_routed_weight = concat_routed_state_dict.get(module_name) + + if concat_routed_weight is None: + child_weight = shared_expert_weight + else: + child_weight = torch.cat( + [shared_expert_weight, concat_routed_weight], + dim=1 if module_name == "down_proj" else 0, + ) + child_ffn_state_dict[child_ffn_key] = child_weight + + return child_ffn_state_dict + + +def _verify_state_dicts_match( + state_dict: dict[str, torch.Tensor], + expected_keys_and_shapes: dict[str, torch.Size], +) -> None: + # Verify keys match + expected_keys = expected_keys_and_shapes.keys() + missing_keys = set(expected_keys) - set(state_dict.keys()) + unexpected_keys = set(state_dict.keys()) - set(expected_keys) + assert len(missing_keys) == 0 and len(unexpected_keys) == 0, ( + f"Missing keys: {missing_keys}\nUnexpected keys: {unexpected_keys}" + ) + + # Verify shapes match + shape_mismatches = [] + for key in expected_keys: + expected_shape = expected_keys_and_shapes[key] + actual_shape = state_dict[key].shape + if expected_shape != actual_shape: + shape_mismatches.append(f"{key}: expected {expected_shape}, got {actual_shape}") + + assert len(shape_mismatches) == 0, "Shape mismatches found:\n" + "\n".join(shape_mismatches) + print(""" +############################ +create_child_state_dict: all keys and shapes matched successfully. 
+############################ +""") + + +def _init_mlp( + *, + mlp_init_mode: Union[MlpInitMode, str], + layer_idx: int, + original_config: PretrainedConfig, + mlp_init_config: Optional[dict[str, Any]], + original_state_dict: dict, + new_state_dict: dict, + new_config: PretrainedConfig, + keys: dict[str, str], + ignored_keys: set[str], + expert_idx: Optional[int] = None, +) -> dict[str, torch.Tensor]: + out_state_dict = {} + + if mlp_init_mode == MlpInitMode.MoEChannelPruning: + if expert_idx is None: + return {} + mlp_prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}" + else: + mlp_prefix = f"model.layers.{layer_idx}.mlp" + + key = f"{mlp_prefix}.down_proj.weight" + if key not in keys: + return {} + + mlp_c_proj_key = keys[key] + if mlp_c_proj_key not in ignored_keys: + mlp_keys = [ + keys.pop(f"{mlp_prefix}.{module_name}.weight") + for module_name in ["down_proj", "gate_proj", "up_proj"] + ] + pruned_filters = None + projection_matrix = None + for mlp_key in mlp_keys: + expanded_dim = 1 if "down_proj" in mlp_key else 0 + if mlp_key in new_state_dict.keys(): + mlp_module_weight, pruned_filters, projection_matrix = _init_mlp_module( + mlp_init_mode, + mlp_prefix, + expanded_dim, + layer_idx, + new_state_dict[mlp_key], + new_config, + original_state_dict[mlp_key], + original_config, + mlp_init_config, + pruned_filters, + projection_matrix, + ) + out_state_dict[mlp_key] = mlp_module_weight + else: + mprint(f"mlp_key {mlp_key} not in new_state_dict") + return out_state_dict + + +def _prune_experts_by_score( + *, + mlp_init_config: dict[str, Any], + layer_idx: int, + orig_router_weight: torch.Tensor, + orig_experts_weights: dict[str, list[torch.Tensor]], + new_num_experts: int, +) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]: + orig_num_experts = orig_router_weight.shape[0] + assert all( + len(orig_experts_module_weights) == orig_num_experts + for orig_experts_module_weights in orig_experts_weights.values() + ) + expert_scores = 
_load_expert_scores(mlp_init_config)[layer_idx] + assert len(expert_scores) == orig_num_experts + selected_experts = sorted( + range(orig_num_experts), + key=lambda i: expert_scores[i], + reverse=mlp_init_config.get("higher_is_better", True), + )[:new_num_experts] + result_router_weight = orig_router_weight[selected_experts] + result_experts_weights = { + name: [orig_experts_module_weights[i] for i in selected_experts] + for name, orig_experts_module_weights in orig_experts_weights.items() + } + return result_router_weight, result_experts_weights + + +def _init_linear_attn( + parent_state_dict: dict[str, torch.Tensor], + parent_config: PretrainedConfig, + layer_idx: int, + v_key: str, + o_key: str, +) -> torch.Tensor: + """ + Init a linear layer that operates like an attention layer that assigns score 1 to the current token + and score 0 to all others: out = (Wo @ Wv) @ x + """ + n_embd = parent_config.hidden_size + head_size = _get_head_dim(parent_config) + # Get num_kv_heads from config, compute n_heads_in_group + n_kv_heads = parent_config.block_configs[layer_idx].attention.num_key_value_heads + n_heads_in_group = parent_config.num_attention_heads // n_kv_heads + + wv = parent_state_dict[v_key] + wv = wv.view(n_kv_heads, head_size, n_embd) + wv_expanded = torch.repeat_interleave(wv, n_heads_in_group, dim=0).reshape(n_embd, n_embd) + + wo = parent_state_dict[o_key] + + w_linear = wo @ wv_expanded + return w_linear + + +def _init_linear_mlp(teacher_mlp_state_dict: dict[str, torch.Tensor]) -> torch.Tensor: + """ + A linear layer that does (W_down @ W_up) @ x, ignoring W_gate. 
+ """ + if "linear_mlp.weight" in teacher_mlp_state_dict: # if the teacher itself is a linear layer + return teacher_mlp_state_dict["linear_mlp.weight"] + + w_up = teacher_mlp_state_dict["up_proj.weight"] + w_down = teacher_mlp_state_dict["down_proj.weight"] + w_linear = w_down @ w_up + return w_linear + + +def update_model_config( + model_config: PretrainedConfig, + model_config_overrides: None | list[dict[str, Any]] | str | dict | Path = None, +) -> PretrainedConfig: + new_model_config = deepcopy(model_config) + if model_config_overrides is None: + return new_model_config + + model_config_overrides = _parse_model_config_overrides( + model_config_overrides, model_config.num_hidden_layers + ) + + def override(item, item_overrides): + if item_overrides is None: + return item_overrides + if dataclasses.is_dataclass(item): + assert isinstance(item_overrides, dict) + return dataclass_override(item, item_overrides) + if isinstance(item, list): + assert isinstance(item_overrides, list) + return list_override(item, item_overrides) + return item_overrides + + def list_override(ls, ls_overrides: list): + assert len(ls) == len(ls_overrides) + return [override(item, item_overrides) for item, item_overrides in zip(ls, ls_overrides)] + + def dataclass_override(dc, dc_overrides: dict): + if not set(dc_overrides.keys()).issubset(dataclasses.asdict(dc).keys()): + raise ValueError( + f"Uknown overrides for dataclass {type(dc)}: {', '.join(set(dc_overrides.keys()) - dataclasses.asdict(dc).keys())}" + ) + field_types = {field.name: field.type for field in dataclasses.fields(dc)} + dc_changes = {} + for key, item_overrides in dc_overrides.items(): + previous_value, item_type = getattr(dc, key), field_types[key] + # if original block was no_op, we should not override it + if getattr(dc, "no_op", False): + return dc + + if previous_value is None and _is_dataclass_type(item_type): + new_value = _get_dataclass_type(item_type)(**item_overrides) + else: + new_value = 
override(previous_value, item_overrides) + check_type(new_value, item_type) + dc_changes[key] = new_value + return dataclasses.replace(dc, **dc_changes) + + new_model_config.block_configs = list_override( + new_model_config.block_configs, model_config_overrides + ) + + return new_model_config + + +def _parse_model_config_overrides( + model_config_overrides_json: str | dict | Path | list[dict], + n_layer: int, +) -> list[dict[str, Any]]: + """ + example model_config_overrides_dict: + { + "attention": [{"num_key_value_heads": 4}], + "ffn": [{"intermediate_size": 14336}] + } + """ + if isinstance(model_config_overrides_json, list) and isinstance( + model_config_overrides_json[0], dict + ): + return model_config_overrides_json + + if isinstance(model_config_overrides_json, dict): + model_config_overrides_dict = model_config_overrides_json + else: + if os.path.exists( + model_config_overrides_json + ): # using os.path.exists, because Path.exists throws an exception on long strings + model_config_overrides_json = Path(model_config_overrides_json).read_text() + print(f"I'm json loadsing over here. 
{model_config_overrides_json=}") + model_config_overrides_dict = json.loads(model_config_overrides_json) + + # Sanity checks and conversion to list of dictionaries + layer_wise_overrides = [{} for _ in range(n_layer)] + for config_key, config_value in model_config_overrides_dict.items(): + assert config_key in SUBBLOCK_CLS_DICT, f"Unknown config key: {config_key}" + assert isinstance(config_value, list), ( + f"Expected a list for {config_key}, got {config_value}" + ) + assert len(config_value) == n_layer or len(config_value) == 1, ( + f"Number of elements in {config_key} must be 1 or equal to the number of layers in the model" + ) + + if len(config_value) == 1: + model_config_overrides_dict[config_key] = config_value * n_layer + + for layer_idx in range(n_layer): + layer_wise_overrides[layer_idx][config_key] = model_config_overrides_dict[config_key][ + layer_idx + ] + + return layer_wise_overrides + + +def _apply_hidden_size_pruning( + out_state_dict: dict[str, torch.Tensor], + original_state_dict: dict[str, torch.Tensor], + new_config: PretrainedConfig, + original_config: PretrainedConfig, + descriptor, + hidden_size_init_mode: HiddenSizeInitMode, + channel_importance_path: Optional[str] = None, + owned_block_indexes: Optional[list[int]] = None, +) -> dict[str, torch.Tensor]: + """ + Apply hidden size pruning to all layers that depend on hidden_size. + This includes embeddings, layer norms, and any linear layers that haven't been handled yet. 
+ """ + if isinstance(hidden_size_init_mode, str): + hidden_size_init_mode = HiddenSizeInitMode(hidden_size_init_mode) + + # Get language model config (for VL models this extracts the nested config) + original_lm_config = descriptor.get_language_model_config(original_config) + new_lm_config = descriptor.get_language_model_config(new_config) + + original_hidden_size = original_lm_config.hidden_size + new_hidden_size = new_lm_config.hidden_size + + if hidden_size_init_mode == HiddenSizeInitMode.CopyAsIs: + return out_state_dict + + # Load channel ranking if needed + if hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_importance_path is not None: + with open(channel_importance_path, "r") as f: + channel_ranking = json.load(f)["channel_importance_ranking"] + else: + raise ValueError( + "channel_ranking_path must be provided in hidden_size_init_config for PruneByChannelRanking mode" + ) + + # Handle embedding layer + embed_key = "model.embed_tokens.weight" + if embed_key in out_state_dict and embed_key in original_state_dict: + out_state_dict[embed_key] = _prune_hidden_size_dimension( + original_state_dict[embed_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + else: + raise ValueError( + f"Embed key {embed_key} not found in out_state_dict or original_state_dict" + ) + + # Handle final layer norm + norm_key = "model.norm.weight" + if norm_key in out_state_dict and norm_key in original_state_dict: + out_state_dict[norm_key] = _prune_hidden_size_dimension( + original_state_dict[norm_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + + # Handle LM head + lm_head_key = "lm_head.weight" + if lm_head_key in out_state_dict and lm_head_key in original_state_dict: + if out_state_dict[lm_head_key].shape[1] != new_hidden_size: + out_state_dict[lm_head_key] = _prune_hidden_size_dimension( + original_state_dict[lm_head_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, 
+ dim=1, + ) + + for block_idx in owned_block_indexes: + if new_config.block_configs[block_idx].parallel_blocks is None: + key_prefix = f"model.layers.{block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + new_config.block_configs[block_idx], + key_prefix, + ) + else: + for internal_block_idx in range( + len(new_config.block_configs[block_idx].parallel_blocks) + ): + block_config = new_config.block_configs[block_idx].parallel_blocks[ + internal_block_idx + ] + key_prefix = f"model.layers.{block_idx}.parallel_blocks.{internal_block_idx}" + out_state_dict = _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, + ) + return out_state_dict + + +def _prune_hidden_size_dimension_block( + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + block_config, + key_prefix, +): + for layer_norm in ["input_layernorm", "post_attention_layernorm"]: + for part in ["weight", "bias"]: + key = f"{key_prefix}.{layer_norm}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=0, + ) + attn_prefix = f"{key_prefix}.self_attn" + if block_config.attention.replace_with_linear: + linear_attn_key = f"{attn_prefix}.linear_attn.weight" + for dim in [0, 1]: + out_state_dict[linear_attn_key] = _prune_hidden_size_dimension( + out_state_dict[linear_attn_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.attention.is_mamba: + for proj in ["in", "out"]: + mamba_key = f"{attn_prefix}.mamba_mixer.{proj}_proj.weight" + out_state_dict[mamba_key] = _prune_hidden_size_dimension( + out_state_dict[mamba_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if proj == "in" else 0, + ) + else: + 
for k in "qkvo": + for part in ["weight", "bias"]: + if k in "qkv" and part == "bias": + continue + key = f"{attn_prefix}.{k}_proj.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and k in "qkv" else 0, + ) + ffn_prefix = f"{key_prefix}.mlp" + if block_config.ffn.replace_with_linear: + linear_mlp_key = f"{ffn_prefix}.linear_mlp.weight" + for dim in [0, 1]: + out_state_dict[linear_mlp_key] = _prune_hidden_size_dimension( + out_state_dict[linear_mlp_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=dim, + ) + elif block_config.ffn.moe is not None: + router_key = f"{ffn_prefix}.router.weight" + out_state_dict[router_key] = _prune_hidden_size_dimension( + out_state_dict[router_key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1, + ) + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.shared_expert", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + for expert_idx in range(block_config.ffn.moe.num_local_experts): + _prune_hidden_size_dimension_mlp( + f"{ffn_prefix}.experts.{expert_idx}", + out_state_dict, + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + ) + else: + _prune_hidden_size_dimension_mlp( + ffn_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking + ) + return out_state_dict + + +def _prune_hidden_size_dimension_mlp( + name_prefix, out_state_dict, new_hidden_size, hidden_size_init_mode, channel_ranking +): + for proj in ["gate_proj", "up_proj", "down_proj"]: + for part in ["weight", "bias"]: + if proj != "down_proj" and part == "bias": + continue + key = f"{name_prefix}.{proj}.{part}" + if key in out_state_dict: + out_state_dict[key] = _prune_hidden_size_dimension( + out_state_dict[key], + new_hidden_size, + hidden_size_init_mode, + channel_ranking, + dim=1 if part == "weight" and 
proj != "down_proj" else 0, + ) + + +def _prune_hidden_size_dimension( + original_tensor: torch.Tensor, + new_hidden_size: int, + hidden_size_init_mode: HiddenSizeInitMode, + channel_ranking: Optional[list[int]] = None, + dim: int = -1, +) -> torch.Tensor: + """ + Prune a tensor along the specified dimension to match the new hidden size. + """ + original_size = original_tensor.shape[dim] + + if hidden_size_init_mode == HiddenSizeInitMode.Random: + # Initialize with random weights + new_shape = list(original_tensor.shape) + new_shape[dim] = new_hidden_size + return torch.randn(new_shape, dtype=original_tensor.dtype, device=original_tensor.device) + + elif hidden_size_init_mode == HiddenSizeInitMode.Truncate: + # Simple truncation - take the first new_hidden_size elements + if dim == -1: + return original_tensor[..., :new_hidden_size] + elif dim == 0: + return original_tensor[:new_hidden_size, ...] + elif dim == 1: + return original_tensor[:, :new_hidden_size, ...] + else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = slice(new_hidden_size) + return original_tensor[tuple(slices)] + + elif hidden_size_init_mode == HiddenSizeInitMode.PruneByChannelRanking: + if channel_ranking is None: + raise ValueError("Channel ranking must be provided for PruneByChannelRanking mode") + + # Use channel ranking to select the most important channels + if len(channel_ranking) < new_hidden_size: + raise ValueError( + f"Channel ranking has {len(channel_ranking)} channels but need {new_hidden_size}" + ) + + # Take the top new_hidden_size channels according to ranking + selected_channels = channel_ranking[:new_hidden_size] + + if dim == -1: + return original_tensor[..., selected_channels] + elif dim == 0: + return original_tensor[selected_channels, ...] + elif dim == 1: + return original_tensor[:, selected_channels, ...] 
+ else: + # Handle other dimensions + slices = [slice(None)] * original_tensor.ndim + slices[dim] = selected_channels + return original_tensor[tuple(slices)] + + else: + raise ValueError(f"Unsupported hidden_size_init_mode: {hidden_size_init_mode}") + + +def _get_head_dim(config) -> int: + """Get head dimension from config in a model-agnostic way. + + Some models like Llama have `head_dim` as a direct attribute, while others + like Qwen2 don't. This helper computes it from hidden_size and num_attention_heads. + """ + if hasattr(config, "head_dim") and config.head_dim is not None: + return config.head_dim + return config.hidden_size // config.num_attention_heads diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py new file mode 100644 index 0000000000..783d233c3e --- /dev/null +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# mypy: ignore-errors + +"""Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" + +import json +import time +from pathlib import Path +from typing import Optional + +import torch +import yaml +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + GQAInitMode, + HiddenSizeInitMode, + LinearInitMode, + MlpInitMode, + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import _save_checkpoint, load_model_config +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config + + +def init_child_from_parent( + descriptor: ModelDescriptor, + pruning_mixin, + parent_checkpoint_dir: str, + model_config_overrides_dict: dict | str, + output_checkpoint_dir: str, + gqa_init_mode: GQAInitMode, + mlp_init_mode: MlpInitMode, + mlp_init_config_yaml: Optional[str], + linear_init_mode: LinearInitMode, + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None +) -> None: + """ + Init child models from parent models in the style of bypass training, + but without having to run the entire bypass pipeline. + + Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. 
+ + I/O Optimization Parameters: + - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) + - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) + """ + assert ( + gqa_init_mode not in [GQAInitMode.RandomKV, GQAInitMode.RandomBlock] + and mlp_init_mode != MlpInitMode.Random + and linear_init_mode != LinearInitMode.Random + ), ( + "We do not support random init of any subblock in this script to avoid initializing the student model" + ) + + descriptor = ModelDescriptorFactory.get(descriptor) + + copy_tokenizer( + parent_checkpoint_dir, + output_checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + parent_model_config = load_model_config( + parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code() + ) + parent_state_dict = load_state_dict(parent_checkpoint_dir) + + # Parse JSON if string + if isinstance(model_config_overrides_dict, str): + model_config_overrides_dict = json.loads(model_config_overrides_dict) + + # Separate global config overrides from block-level overrides + global_config_overrides = {} + block_config_overrides = {} + + for key, value in model_config_overrides_dict.items(): + if key in ["hidden_size"]: + global_config_overrides[key] = value + else: + block_config_overrides[key] = value + + # Load child model config with global overrides + child_model_config = load_model_config( + parent_checkpoint_dir, + model_config_overrides=global_config_overrides, + ignore_unexpected_config_keys=True, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + # Apply block-level overrides if any + if block_config_overrides: + child_model_config = update_model_config( + model_config=child_model_config, + model_config_overrides=block_config_overrides, + ) + + with torch.device("meta"): + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive 
nested config (e.g., text_config) without block_configs + with deci_x_patcher( + model_descriptor=descriptor, block_configs=child_model_config.block_configs + ): + model_class = _get_model_class_from_config(child_model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + trust_remote_code = descriptor.requires_trust_remote_code() + child_model = model_class.from_config( + child_model_config, trust_remote_code=trust_remote_code + ) + else: + child_model = model_class._from_config(child_model_config) + + child_state_dict_with_meta_tensors = child_model.state_dict() + + mlp_init_config = ( + yaml.safe_load(mlp_init_config_yaml) + if isinstance(mlp_init_config_yaml, str) + else mlp_init_config_yaml + ) + + # Profile create_child_state_dict with automatic layer parallelization + mprint("Starting create_child_state_dict...") + start_time = time.time() + child_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=parent_state_dict, + new_state_dict=child_state_dict_with_meta_tensors, + original_config=parent_model_config, + new_config=child_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, + channel_importance_path=channel_importance_path, + max_layer_workers=max_layer_workers, + ) + create_child_state_dict_time = time.time() - start_time + mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") + + # Profile _save_checkpoint with automatic I/O worker calculation + mprint("Starting _save_checkpoint...") + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"I/O Settings: max_workers={actual_io_workers}") + start_time = time.time() + _save_checkpoint( + child_model_config, + child_state_dict, + 
output_checkpoint_dir, + descriptor, + max_workers=max_workers, + ) + save_checkpoint_time = time.time() - start_time + mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") + + # Print profiling summary with actual worker counts used + total_core_time = create_child_state_dict_time + save_checkpoint_time + actual_layer_workers = max_layer_workers if max_layer_workers else "auto" + actual_io_workers = max_workers if max_workers else "auto" + mprint(f"\n=== PROFILING SUMMARY ===") + mprint( + f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" + ) + mprint( + f"_save_checkpoint: {save_checkpoint_time:.2f}s ({save_checkpoint_time / total_core_time * 100:.1f}%)" + ) + mprint(f"Total core processing: {total_core_time:.2f}s") + mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") + mprint(f"=========================\n") diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils.py b/modelopt/torch/puzzletron/tools/checkpoint_utils.py new file mode 100644 index 0000000000..4488898e33 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# mypy: ignore-errors + +"""Utilities for loading and initializing PyTorch model checkpoints (AnyModel / HF layouts).""" + +import concurrent.futures +import warnings +from functools import partial +from pathlib import Path +from typing import Literal, TypeVar + +import torch +from safetensors.torch import load_file as safe_load_file +from torch import nn +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.common import infer_weights_dtype + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +STATE_DICT_FILE_NAME = "model.pth" + + +def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: + checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) + + if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists(): + return torch.load(state_dict_path, map_location="cpu", weights_only=True) + + if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(safetensors_subblocks_dir) + + if (pth_subblocks_dir := checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(pth_subblocks_dir) + + if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( + checkpoint_dir / SAFE_WEIGHTS_NAME + ).exists(): + from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( + load_sharded_state_dict, # local import to avoid circular import + ) + + return load_sharded_state_dict(checkpoint_dir) + + raise FileNotFoundError( + f"Couldn't find state dict path or subblocks dir inside {checkpoint_dir}" + ) + + +def _normalize_checkpoint_dir(checkpoint_dir: Path | str) -> Path: + checkpoint_dir = Path(checkpoint_dir) + if checkpoint_dir.is_file(): + checkpoint_dir = checkpoint_dir.parent + return checkpoint_dir + + +def 
_load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tensor]: + torch_paths = list(subblocks_dir.glob("*.pth")) + safetensors_paths = list(subblocks_dir.glob("*.safetensors")) + + if len(torch_paths) != 0: + load_fn = partial(torch.load, map_location="cpu", weights_only=True) + file_paths = torch_paths + elif len(safetensors_paths) != 0: + load_fn = safe_load_file + file_paths = safetensors_paths + else: + raise ValueError(f"No tensor files found in {subblocks_dir=}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + state_dict_shards = list(executor.map(load_fn, file_paths)) + + state_dict = {k: v for shard in state_dict_shards for k, v in shard.items()} + return state_dict + + +NNModule = TypeVar("NNModule", bound=nn.Module) + + +def init_module_with_state_dict( + state_dict: dict[str, torch.Tensor], + module_cls: type[NNModule], + *init_args, + **init_kwargs, +) -> NNModule: + weights_dtype = infer_weights_dtype(state_dict) + module = init_empty_module(module_cls, weights_dtype, *init_args, **init_kwargs) + module.load_state_dict(state_dict) + return module + + +def init_empty_module( + module_cls: type[NNModule], + dtype: torch.dtype, + *init_args, + **init_kwargs, +) -> NNModule: + default_dtype = torch.get_default_dtype() + current_device = torch.ones(1).device + torch.set_default_dtype(dtype) + module = skip_init(module_cls, *init_args, device=current_device, **init_kwargs) + torch.set_default_dtype(default_dtype) + return module + + +def skip_init(module_cls, *args, **kwargs) -> nn.Module: + """Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg.""" + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + + final_device = kwargs.pop("device", "cpu") + with torch.device("meta"): + module = module_cls(*args, **kwargs) + + module = module.to_empty(device=final_device) + return module + + +def 
def copy_tokenizer(
    source_dir_or_tokenizer_name: Path | str,
    target_dir: Path | str,
    on_failure: Literal["raise", "warn"] = "raise",
    trust_remote_code: bool = False,
) -> None:
    """Load a tokenizer and save a copy of it into ``target_dir``.

    Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available)
    to avoid collision between transformers versions.

    Args:
        source_dir_or_tokenizer_name: Directory holding a tokenizer (optionally with a
            ``tokenizer_name.txt`` file naming the tokenizer to load instead), or a tokenizer name.
        target_dir: Directory the tokenizer is saved to; created if missing.
        on_failure: ``"raise"`` raises when the tokenizer cannot be loaded, ``"warn"`` only
            emits a warning and copies nothing.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Raises:
        FileNotFoundError: If loading fails and ``on_failure == "raise"``; the underlying
            error is attached as ``__cause__``.
    """
    source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt"
    if source_tokenizer_name_path.exists():
        source_dir_or_tokenizer_name = source_tokenizer_name_path.read_text().strip()

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            source_dir_or_tokenizer_name, trust_remote_code=trust_remote_code
        )
    except Exception as err:  # deliberate catch-all: any load failure is handled per on_failure
        message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'"
        if on_failure == "raise":
            raise FileNotFoundError(message) from err
        warnings.warn(message)
        return

    if tokenizer is None:
        return

    target_dir = Path(target_dir)
    target_dir.mkdir(exist_ok=True, parents=True)
    tokenizer.save_pretrained(target_dir)

    # When the source was a name rather than a local path, remember the name so that
    # later copies can reload the tokenizer by name.
    is_given_tokenizer_name_as_argument = not Path(source_dir_or_tokenizer_name).exists()
    if is_given_tokenizer_name_as_argument:
        target_tokenizer_name_path = target_dir / "tokenizer_name.txt"
        target_tokenizer_name_path.write_text(str(source_dir_or_tokenizer_name))
+# mypy: ignore-errors + +""" +Utilities for loading and saving Hugging Face-format checkpoints (``AutoConfig`` + optional ``block_configs``). +""" + +import concurrent.futures +import contextlib +import dataclasses +import fcntl +import os +import time +import warnings +from collections import defaultdict +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import Any, BinaryIO + +import torch +import transformers +from safetensors.torch import save_file as safe_save_file +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import maybe_cast_block_configs +from modelopt.torch.puzzletron.tools.common import infer_weights_dtype +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.post_init_sparse import SparsityMethod +from modelopt.torch.puzzletron.tools.robust_json import json_dumps + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME) + + +# TODO: (esegal) Should ask the model for something like this +NON_LAYER_MODULE_TO_FILE_TYPE = { + "model.embed_tokens": "embeddings", + "model.norm": "lm_head", + "lm_head": "lm_head", +} +MODULE_WITHIN_LAYER_TO_FILE_TYPE = { + "input_layernorm": "attention", + "self_attn": "attention", + "post_attention_layernorm": "ffn", + "mlp": "ffn", + "parallel_blocks": "multi_block", +} +LAYERS_MODULE_NAME = "model.layers" + + +def force_cache_dynamic_modules( + config: PretrainedConfig, checkpoint_dir: Path | str, trust_remote_code: bool = False +): + has_remote_code = ( + hasattr(config, "auto_map") + and isinstance(config.auto_map, dict) + and "AutoConfig" in config.auto_map.keys() + ) + if has_remote_code and 
def load_model_config(
    checkpoint_dir: Path | str,
    model_config_overrides: Mapping | None = None,
    ignore_unexpected_config_keys: bool = False,
    trust_remote_code: bool = False,
):
    """Read the model configuration stored in a checkpoint directory.

    Args:
        checkpoint_dir: Checkpoint directory (e.g. the one containing config.json).
        model_config_overrides: Optional overrides passed on to ``AutoConfig.from_pretrained``.
        ignore_unexpected_config_keys: When False, overrides that the config did not
            consume cause an error.
        trust_remote_code: If True, allows execution of custom code from the model repository.
            This is a security risk if the model source is untrusted. Defaults to False.

    Returns:
        The loaded configuration; when it has ``block_configs`` they are passed through
        ``maybe_cast_block_configs``.

    Raises:
        ValueError: If unconsumed override keys remain and they are not ignored.
    """
    checkpoint_path = checkpoint_dir if isinstance(checkpoint_dir, Path) else Path(checkpoint_dir)
    overrides = {} if model_config_overrides is None else model_config_overrides

    config, unused_kwargs = AutoConfig.from_pretrained(
        checkpoint_path,
        trust_remote_code=trust_remote_code,
        return_unused_kwargs=True,
        **overrides,
    )

    if hasattr(config, "block_configs"):
        config.block_configs = maybe_cast_block_configs(config.block_configs)

    force_cache_dynamic_modules(config, checkpoint_path, trust_remote_code=trust_remote_code)

    if unused_kwargs and not ignore_unexpected_config_keys:
        raise ValueError(f"Unexpected config keys: {unused_kwargs.keys()}")

    return config
def _save_checkpoint(
    model_config: PretrainedConfig,
    state_dict: dict[str, torch.Tensor],
    checkpoint_dir: Path | str,
    descriptor: "ModelDescriptor",
    max_workers: int | None = None,  # optional - save_subblocks auto-calculates when None
) -> None:
    """Write config, safetensors index and per-subblock weight files to ``checkpoint_dir``.

    Args:
        model_config: Config saved via ``save_model_config``.
        state_dict: Weights to save.
        checkpoint_dir: Output directory; created if missing.
        descriptor: Provides ``get_weight_groups`` (weight name -> subblock grouping) and
            ``output_embedding_name``.
        max_workers: Forwarded to ``save_subblocks``.
    """
    # Kept from the original block: imported locally (only referenced in the annotation).
    from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor  # noqa: F401

    if not isinstance(checkpoint_dir, Path):
        checkpoint_dir = Path(checkpoint_dir)

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Phase 1: Save config
    save_model_config(model_config, checkpoint_dir)

    # Phase 2: Build weight map using descriptor and write index
    subblock_keys = descriptor.get_weight_groups(
        layer_names=state_dict.keys(),
        num_hidden_layers=model_config.num_hidden_layers,
    )

    # Use the shared directory constant so the index always points into the directory
    # that save_subblocks creates.
    weight_map = {
        key: f"{SAFETENSORS_SUBBLOCKS_DIR_NAME}/{subblock}.safetensors"
        for subblock, layer_keys in subblock_keys.items()
        for key in layer_keys
    }

    # Handle tie_word_embeddings - remove from state_dict and weight_map BEFORE writing index
    output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight"
    if getattr(model_config, "tie_word_embeddings", False) and output_emb_weight_name in state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k != output_emb_weight_name}
        weight_map.pop(output_emb_weight_name, None)

    # Write index (now without tied embedding)
    index = {"metadata": {"format": "pt"}, "weight_map": weight_map}
    index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME
    _write_file_process_safe(json_dumps(index), index_path)

    # Phase 3: Save subblocks
    save_subblocks(
        state_dict,
        checkpoint_dir,
        weight_map=weight_map,
        multi_threaded=True,
        max_workers=max_workers,
    )
mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 2: Create subblocks directory + dir_create_start_time = time.time() + subblocks_path = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_path.mkdir(parents=True, exist_ok=True) + dir_create_time = time.time() - dir_create_start_time + mprint(f" Step 2 - Create directory: {dir_create_time:.2f}s") + + # Step 3: Organize tensors by file + organize_start_time = time.time() + filename_to_partial_state_dict = defaultdict(dict) + total_tensor_size = 0 + for weight_name, weight in state_dict.items(): + if weight_name in weight_map: + # Ensure tensor is contiguous and on CPU for faster I/O + tensor = ( + weight.contiguous().cpu() if weight.device.type != "cpu" else weight.contiguous() + ) + filename_to_partial_state_dict[weight_name_to_filename[weight_name]][weight_name] = ( + tensor + ) + total_tensor_size += weight.numel() * weight.element_size() + organize_time = time.time() - organize_start_time + mprint( + f" Step 3 - Organize tensors: {organize_time:.2f}s ({total_tensor_size / (1024**3):.2f}GB total)" + ) + + # Step 4: Prepare save arguments and auto-calculate optimal I/O workers + prepare_start_time = time.time() + safe_save_kwargs = [ + {"tensors": partial_state_dict, "filename": filename, "metadata": {"format": "pt"}} + for filename, partial_state_dict in filename_to_partial_state_dict.items() + ] + + # Auto-calculate optimal I/O workers: min(cpu_count, num_files) + if max_workers is None: + cpu_count = os.cpu_count() or 1 + num_files = len(safe_save_kwargs) + max_workers = min(cpu_count, num_files) + mprint( + f" Auto-calculated I/O workers: min({cpu_count} CPUs, {num_files} files) = {max_workers}" + ) + else: + mprint(f" Using specified I/O workers: {max_workers}") + + prepare_time = time.time() - prepare_start_time + mprint(f" Step 4 - Prepare save args: {prepare_time:.2f}s ({len(safe_save_kwargs)} files)") + + # Step 5: Save files with optimal worker 
count + save_start_time = time.time() + if multi_threaded: + mprint(f" Using multi-threaded saving with {max_workers} workers...") + + def optimized_safe_save(kwargs): + try: + safe_save_file(**kwargs) + return True + except Exception as e: + mprint(f" Error saving {kwargs['filename']}: {e}") + return False + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(optimized_safe_save, safe_save_kwargs)) + + # Check for any failures + failed_saves = sum(1 for r in results if not r) + if failed_saves > 0: + mprint(f" Warning: {failed_saves} files failed to save") + else: + mprint(" Using single-threaded saving...") + for kwargs in safe_save_kwargs: + safe_save_file(**kwargs) + + save_time = time.time() - save_start_time + mprint(f" Step 5 - Save files: {save_time:.2f}s ({max_workers} workers)") + + subblocks_total_time = time.time() - subblocks_start_time + mprint(f"=== save_subblocks completed in {subblocks_total_time:.2f}s ===") + mprint( + f" Breakdown: WeightMap {weight_map_time:.1f}s + DirCreate {dir_create_time:.1f}s + " + f"Organize {organize_time:.1f}s + Prepare {prepare_time:.1f}s + Save {save_time:.1f}s" + ) + + # Calculate effective I/O speed + io_speed_gbps = (total_tensor_size / (1024**3)) / save_time if save_time > 0 else 0 + mprint(f" Effective I/O speed: {io_speed_gbps:.2f} GB/s ({max_workers} workers)") + mprint(f" Save operation was {save_time / subblocks_total_time * 100:.1f}% of total time") + + +def _write_text(content: str, f: BinaryIO) -> None: + f.write(content.encode("utf-8")) + + +def _write_file_process_safe( + content: Any, + path: Path | str, + write_fn: Callable[[Any, BinaryIO], None] = _write_text, +) -> None: + """ + Write a file in a multi-process safe way. + If another process tries to write the same file using this method, the current process + "gives up" and assumes that the matter is being taken care of by another process. 
+ + write_fn is a function that receives file contents and a binary file object, + and writes the content to the file. It can be _write_text (defined above), or torch.save, + or a similar function (not safetensors.torch.save_file since it expects a path). + """ + with open(path, "wb") as f: + # Try to acquire an exclusive, non-blocking lock + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + return # Exit immediately if the lock is not acquired + + write_fn(content, f) # Write the content if lock is acquired + f.flush() # Ensure data is written to disk + + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + + +def _build_safetensors_weight_map( + *, + state_dict: dict[str, torch.Tensor], + non_layer_module_to_file_type: dict[str, str], + module_within_layer_to_file_type: dict[str, str], + layers_module_name: str, +) -> dict[str, Path]: + weight_map = {} + unmapped_weight_names = [] + for weight_name in state_dict: + found_match = False + for module_name, file_type in non_layer_module_to_file_type.items(): + if weight_name.startswith(f"{module_name}."): + weight_map[weight_name] = str(RELATIVE_SUBBLOCKS_DIR / f"{file_type}.safetensors") + found_match = True + if not found_match: + if weight_name.startswith(f"{layers_module_name}."): + name_parts = weight_name[len(layers_module_name) + 1 :].split(".") + layer_index = name_parts[0] + name_within_layer = ".".join(name_parts[1:]) + + for module_name, file_type in module_within_layer_to_file_type.items(): + if name_within_layer.startswith(f"{module_name}."): + weight_map[weight_name] = str( + RELATIVE_SUBBLOCKS_DIR / f"block_{layer_index}_{file_type}.safetensors" + ) + found_match = True + + if not found_match: + unmapped_weight_names.append(weight_name) + + if len(unmapped_weight_names) > 0: + raise ValueError( + f"Unmapped weight names: {unmapped_weight_names}\n" + f"Add them to the `non_layer_module_to_file_type` or " + f"`module_within_layer_to_file_type` dictionaries." 
+ ) + + return weight_map + + +def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: + if hasattr(model_config, "block_configs"): + model_config.block_configs = [ + dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf + for conf in model_config.block_configs + ] + model_config.save_pretrained(checkpoint_dir) diff --git a/modelopt/torch/puzzletron/tools/common.py b/modelopt/torch/puzzletron/tools/common.py new file mode 100644 index 0000000000..96db572802 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/common.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def infer_weights_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype: + weights_dtype = [p.dtype for p in state_dict.values() if torch.is_floating_point(p)] + weights_dtype = weights_dtype[0] if len(weights_dtype) > 0 else torch.get_default_dtype() + return weights_dtype diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py new file mode 100644 index 0000000000..64c4035656 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/hydra_utils.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int:
    """
    Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage.
    Used as a resolver in hydra configs.

    The total step count is ``(tokens // block) // mbs`` (arguments are cast with ``int``);
    the result is ``pct`` of that, rounded, and never less than 1.
    """
    steps = (int(tokens) // int(block)) // int(mbs)
    return max(1, round(pct * steps))


def register_hydra_resolvers():
    """Register the custom OmegaConf resolvers used by the hydra configs.

    Registered names: ``to_path``, ``random_int``, ``timedelta_minutes``, ``warmup_steps``
    and ``get_object``. Registration uses ``replace=True`` so calling this function more
    than once in the same process does not fail on already-registered names.
    """
    resolvers = {
        "to_path": lambda x: Path(x),
        "random_int": lambda low, high: random.randint(int(low), int(high)),
        "timedelta_minutes": lambda x: datetime.timedelta(minutes=x) if x is not None else None,
        # Registered directly so the optional ``pct`` argument may be omitted in configs.
        "warmup_steps": warmup_steps,
        "get_object": lambda x: get_object(x),
    }
    for name, resolver in resolvers.items():
        OmegaConf.register_new_resolver(name, resolver, replace=True)
def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor:
    """Return one cosine-embedding loss value per batch element.

    Each sample of ``input`` and ``target`` is flattened to a single vector (everything
    after the first dimension) and compared with ``F.cosine_embedding_loss`` using a
    target label of 1 and ``reduction="none"``, so the result has shape ``(batch,)``.

    ``reshape`` is used for the flattening so that non-contiguous tensors (e.g. transposed
    or sliced activations) are accepted as well.
    """
    batch_size = input.size(0)
    flat_input = input.reshape(batch_size, -1)
    flat_target = target.reshape(batch_size, -1)
    # Label 1 for every sample: the two vectors are expected to be similar.
    similarity_labels = flat_input.new_ones(batch_size)
    return F.cosine_embedding_loss(
        input1=flat_input, input2=flat_target, target=similarity_labels, reduction="none"
    )
class DistributedLogger(logging.Logger):
    """Logger that emits a message only on selected ranks of a distributed job.

    Rank information is read once, at construction, from the ``LOCAL_RANK``, ``RANK`` and
    ``WORLD_SIZE`` environment variables (defaults 0, 0 and 1).
    """

    verbosity = logging.ERROR

    def __init__(self, name, level=logging.DEBUG):
        super().__init__(name, level)
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.global_rank = int(os.environ.get("RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))

    def dist_log(self, msg: str, ranks: str = "main"):
        """Log parameter msg with the given ranks.

        Args:
            msg: The message to log.
            ranks: The ranks to log the message to. Choices are:
                "all": log with all ranks
                "main": log only when the global rank is 0
                "last": log only when the global rank is ``world_size - 1``
                "local_main": log only when the local rank is 0

        Raises:
            NotImplementedError: If ``ranks`` is not one of the choices above.
        """
        if ranks not in ["all", "main", "local_main", "last"]:
            raise NotImplementedError(
                f"Could not broadcast msg {msg} - "
                f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
            )

        # "all" never filters; the other modes return early on non-selected ranks.
        is_filtered_out = (
            (ranks == "main" and self.global_rank != 0)
            or (ranks == "last" and self.global_rank != self.world_size - 1)
            or (ranks == "local_main" and self.local_rank != 0)
        )
        if is_filtered_out:
            return

        message_source = self.get_caller_location()

        self.info(
            f"{LogColors.GREEN}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}"
        )

    @staticmethod
    def get_caller_location() -> str:
        """Return ``"<file name>:<line>"`` of the code that triggered the log call.

        The location is taken three frames up the stack:
        f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source.
        When the stack is shallower than that (e.g. ``dist_log`` called directly from
        module level), the outermost available frame is used instead of failing.
        """
        caller_frame = inspect.currentframe()
        if caller_frame is None:  # interpreter without stack-frame support
            return "<unknown>:0"

        for _ in range(3):
            if caller_frame.f_back is None:
                break
            caller_frame = caller_frame.f_back

        # Get the filename and line number from the caller's stack frame
        filename = os.path.basename(caller_frame.f_code.co_filename)
        lineno = caller_frame.f_lineno
        return f"{filename}:{lineno}"
can't easily install deepspeed ;) +# This is allowing running tests on Mac & Windows and train in non-DDP +try: + from deepspeed.utils import logger as deepspeed_logger + + deepspeed_logger.handlers = logger.handlers + deepspeed_logger.propagate = False +except ImportError: + # If deepspeed is not installed - no op + pass + +# Define a custom function to redirect warnings to logger +# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): +# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') + + +# Use the custom warning handler +# warnings.showwarning = custom_warning_handler + +logger: DistributedLogger + + +def aprint(msg: str | None): + """ + All ranks from all nodes prints + """ + return logger.dist_log(msg=msg, ranks="all") + + +def lmprint(msg: str | None): + """ + All local main ranks prints (rank 0 in each node) + """ + return logger.dist_log(msg=msg, ranks="local_main") + + +def mprint(msg: str | None): + """ + Master prints only (rank 0 in node 0) + """ + return logger.dist_log(msg=msg, ranks="main") + + +def lprint(msg: str | None): + """ + Last rank prints only (rank -1 in node 0) + """ + return logger.dist_log(msg=msg, ranks="last") diff --git a/modelopt/torch/puzzletron/tools/post_init_sparse.py b/modelopt/torch/puzzletron/tools/post_init_sparse.py new file mode 100644 index 0000000000..eb20250e68 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/post_init_sparse.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class SparsityMethod:
    """Base class for weight-sparsification methods.

    Also provides ``fix_state_dict_inplace``, which converts a state dictionary from
    PyTorch's pruning format (with _orig and _mask suffixes) into a standard format with
    sparsified weights.
    """

    def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Gets a model state_dict, returns a state_dict-like mask_dict with masks.

        The base implementation computes nothing (returns ``None``); subclasses override it.
        """

    @staticmethod
    def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False):
        """Fold ``<name>_orig`` / ``<name>_mask`` pairs of ``state_dict`` into plain ``<name>`` entries.

        For every key ending in ``_orig`` that has a matching ``_mask`` key, the masked-out
        positions are zeroed in place, both pruning entries are removed and the result is
        stored under the un-suffixed key.

        Args:
            state_dict: State dict to modify in place.
            verbose: Print the resulting sparsity of every converted tensor.
            change_dtype: Cast every tensor of the state dict to ``torch.bfloat16`` afterwards.

        Returns:
            ``(state_dict, sparsity_masks)``; the masks are keyed by the owning module name
            (the parameter key without its last dotted component).
        """
        sparsity_masks = {}
        for name in list(state_dict.keys()):
            if not name.endswith("_orig"):
                continue
            # Strip only the trailing suffix; "_orig" elsewhere in the name is left alone.
            original_name = name.removesuffix("_orig")
            mask_name = original_name + "_mask"
            if mask_name not in state_dict:
                continue

            val = state_dict[name]
            mask = state_dict[mask_name]
            val[mask == 0] = 0
            sparsity_masks[original_name.rpartition(".")[0]] = mask
            if verbose:
                sparsity = (val == 0).sum() / mask.numel()
                print(f"fix_state_dict_inplace: {name} {sparsity=}")
            del state_dict[mask_name]
            del state_dict[name]
            state_dict[original_name] = val

        if change_dtype:
            for name in state_dict:
                state_dict[name] = state_dict[name].to(torch.bfloat16)
        return state_dict, sparsity_masks

    def filter_function(self, name=None, modules_to_sparsify_in_block=None):
        """Return whether module ``name`` should be sparsified; the base class selects nothing."""
        return False

    def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None:
        """Attach each mask of ``mask_dict`` to the ``weight`` of the module with the same name.

        Prints the module name and the fraction of kept weights for every masked module.
        """
        for name, module in model.named_modules():
            if name in mask_dict:
                custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device))
                print(name)
                print(torch.sum(mask_dict[name]) / mask_dict[name].numel())

    def do_sparsity(self, model: nn.Module, mask_dict=None):
        """Sparsify the linear layers selected by ``model.config.block_configs``.

        For every block, the ``nn.Linear`` sub-modules of ``block.mlp`` / ``block.self_attn``
        accepted by ``filter_function`` (against ``ffn.sparsify`` / ``attention.sparsify``)
        are collected. When ``mask_dict`` is not given, masks are computed from the
        weights of exactly those layers; the masks are then applied to the model.
        """
        full_name_layers = []
        for block_idx, block_config in enumerate(model.config.block_configs):
            ffn_names = block_config.ffn.sparsify
            att_name = block_config.attention.sparsify
            block = model.model.layers[block_idx]
            for attr_name, wanted in (("mlp", ffn_names), ("self_attn", att_name)):
                if not hasattr(block, attr_name):
                    continue
                for name, m in getattr(block, attr_name).named_modules():
                    if isinstance(m, torch.nn.Linear) and self.filter_function(name, wanted):
                        full_name_layers.append(f"model.layers.{block_idx}.{attr_name}.{name}")

        if mask_dict is None:
            selected_layers = set(full_name_layers)
            # removesuffix, not rstrip: rstrip(".weight") strips a *character set* and
            # mangles module names that end in any of those characters (e.g. "gate").
            state_dict_for_sparsifying = {
                k.removesuffix(".weight"): v
                for k, v in model.state_dict().items()
                if k.endswith(".weight") and k.removesuffix(".weight") in selected_layers
            }
            mask_dict = self.calculate_masks(state_dict_for_sparsifying)

        self.apply_masks(model, mask_dict)


class SparsityMethod2o4(SparsityMethod):
    """2:4 sparsity: in every consecutive group of four weights the two largest (by square) are kept."""

    def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Gets a model state_dict, returns a state_dict-like mask_dict with masks"""
        mask_dict = {}
        for key, val in state_dict.items():
            scores = val.flatten() ** 2
            mask_dict[key] = self.create_mask(scores).reshape(val.shape)
        return mask_dict

    def create_mask(self, score, value=0):
        """Return a 0/1 mask shaped like ``score`` with the top-2 entries of each group of 4 set to 1.

        The mask is created on the device of ``score`` so that the index tensors and the
        mask always live on the same device. ``value`` is unused and kept for
        interface compatibility.
        """
        orig_size = score.shape
        score = score.view(-1, 4)
        mask = torch.zeros(score.shape, device=score.device)
        _, indices = torch.topk(score, 2, dim=1)
        rows = torch.arange(mask.size(0), device=score.device).unsqueeze(-1)
        mask[rows, indices] = 1
        return mask.view(orig_size)

    @staticmethod
    def filter_function(name, modules_to_sparsify_in_block):
        """Return True when ``name`` is listed in ``modules_to_sparsify_in_block`` (None selects nothing)."""
        if modules_to_sparsify_in_block is None:
            return False
        return name in modules_to_sparsify_in_block
+""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + # Fallback for arbitrary objects: return their class path + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" + return super().default(o) + + +def json_dumps(obj: Any) -> str: + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + path = Path(path) + text = path.read_text() + return json.loads(text) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py new file mode 100644 index 0000000000..c18867a576 --- /dev/null +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides utilities for distributed loading, saving, and manipulation of +large language model checkpoints across multiple GPUs/processes. + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. +""" + +import json +from collections.abc import Iterable, Mapping +from pathlib import Path +from types import SimpleNamespace +from typing import Literal, Type, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import transformers +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file +from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.dummy_modules import ( + DummyBlock, + DummyLMHead, + DummyModule, + DummyWTE, +) +from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice + + 
def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None:
    """Set a submodule on a model by dotted path."""
    parent_path, _, attr = module_name.rpartition(".")
    parent_module = model.get_submodule(parent_path) if parent_path else model
    setattr(parent_module, attr, new_submodule)


def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime):
    """Replace, in place, the parts of ``model`` this rank does not own with dummy modules.

    Decoder blocks outside ``owned_block_indexes`` become the descriptor's dummy blocks.
    The input embedding, final norm and output embedding are replaced unless this rank
    owns the first / last block or needs them for tied embeddings.

    Args:
        model: Full model to turn into a local shard.
        owned_block_indexes: Indexes of the decoder blocks kept on this rank.
        descriptor: Model descriptor providing module names and dummy-block creation.
        runtime: Object whose ``dtype`` is used for the dummy input embedding.

    Returns:
        The same ``model`` object.
    """
    # Get language model config (handles nested configs like Qwen3-VL's text_config)
    lm_config = descriptor.get_language_model_config(model.config)
    all_block_indexes = set(range(lm_config.num_hidden_layers))
    has_first_block = 0 in owned_block_indexes
    has_last_block = (lm_config.num_hidden_layers - 1) in owned_block_indexes

    unowned_block_indexes = all_block_indexes - owned_block_indexes
    for block_index in unowned_block_indexes:
        decoder_layer_name = descriptor.layer_block_name(block_index)
        decoder_layer = model.get_submodule(decoder_layer_name)
        set_submodule(
            model,
            decoder_layer_name,
            descriptor.create_dummy_block(decoder_layer, block_index=block_index),
        )

    # If we have the last block with tied embeddings, keep embed_tokens so lm_head works.
    # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file,
    # and since they're tied, lm_head.weight gets populated too.
    if not has_first_block and not (has_last_block and model.config.tie_word_embeddings):
        set_submodule(
            model,
            descriptor.input_embedding_name(),
            DummyWTE(lm_config.hidden_size, dtype=runtime.dtype),
        )

    if not has_last_block:
        set_submodule(model, descriptor.final_norm_name(), nn.Identity())
        if not (model.config.tie_word_embeddings and has_first_block):
            set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(lm_config))

    return model


def _get_model_class_from_config(config: PretrainedConfig):
    """
    Get the model class from config.architectures field.
    Works for any model registered in transformers (CausalLM, VL models, etc.).
    Falls back to AutoModelForCausalLM if architectures is not available.
    """
    architectures = getattr(config, "architectures", None)
    if architectures:
        model_class_name = architectures[0]
        if hasattr(transformers, model_class_name):
            return getattr(transformers, model_class_name)
        mprint(
            f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM"
        )
    return AutoModelForCausalLM


def load_and_shard_model(
    descriptor,
    checkpoint_path: str | Path,
    owned_block_indexes: set[int] | Literal["auto"] = "auto",
    model_config: PretrainedConfig | None = None,
):
    """Build this rank's model shard and load its weights from ``checkpoint_path``.

    Args:
        descriptor: Model descriptor (module names, config access, autocast policy).
        checkpoint_path: Checkpoint directory.
        owned_block_indexes: Decoder blocks owned by this rank; ``"auto"`` splits the
            blocks evenly across ranks in rank order.
        model_config: Optional preloaded config; loaded from the checkpoint when ``None``.

    Returns:
        The model shard, cast to bfloat16.

    Raises:
        AssertionError: If any parameter is still on the meta device after loading.
    """
    # NOTE(review): the nesting below was reconstructed from a whitespace-mangled source;
    # confirm which statements belong under `with runtime.device:`.
    checkpoint_path = Path(checkpoint_path)
    runtime = SimpleNamespace(
        device=torch.device(dist.local_rank()),
        dtype=torch.bfloat16,
        global_rank=dist.rank(),
        world_size=dist.size(),
        is_main_process=dist.is_master(),
        is_last_process=dist.is_last_process(),
        use_autocast=True,  # Default: use autocast; descriptor can override
    )

    with runtime.device:
        if model_config is None:
            trust_remote_code = descriptor.requires_trust_remote_code()
            model_config = load_model_config(checkpoint_path, trust_remote_code=trust_remote_code)

        lm_config = descriptor.get_language_model_config(model_config)
        num_hidden_layers = lm_config.num_hidden_layers
        if owned_block_indexes == "auto":
            # .tolist() yields plain Python ints, matching the declared set[int].
            owned_block_indexes = set(
                np.array_split(np.arange(num_hidden_layers), runtime.world_size)[
                    runtime.global_rank
                ].tolist()
            )

        mprint("Initializing model shards")
        # Pass block_configs explicitly so patcher works for VL models where
        # decoder layers receive nested config (e.g., text_config) without block_configs
        from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher

        with deci_x_patcher(
            model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None)
        ):
            model_shard = create_sharded_model(
                runtime=runtime,
                descriptor=descriptor,
                model_config=model_config,
                owned_block_indexes=owned_block_indexes,
            )

        if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or (
            checkpoint_path / SAFE_WEIGHTS_INDEX_NAME
        ).exists():
            mprint("Loading shard state_dict from safetensors")
            shard_keys = [
                *[name for name, _ in model_shard.named_parameters()],
                *[name for name, _ in model_shard.named_buffers()],
            ]
            shard_state_dict = load_sharded_state_dict(
                model_name_or_path=str(checkpoint_path),
                keys_to_load=shard_keys,
                device=runtime.device,
            )

            new_names = set(shard_state_dict.keys())
            mprint(f"{new_names=}")
            # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B)
            model_shard.load_state_dict(shard_state_dict, strict=False, assign=True)

            del shard_state_dict

            # Re-tie weights after load_state_dict with assign=True, which severs the tie.
            # Needed on first rank (owns embed_tokens) and last rank (owns lm_head).
            has_first_block = 0 in owned_block_indexes
            has_last_block = (num_hidden_layers - 1) in owned_block_indexes
            if model_config.tie_word_embeddings and (has_first_block or has_last_block):
                model_shard.tie_weights()

            # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_()
            # just to load the weight and tie it to lm_head. Now replace it with a dummy so it
            # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens).
            if model_config.tie_word_embeddings and has_last_block and not has_first_block:
                set_submodule(
                    model_shard,
                    descriptor.input_embedding_name(),
                    # Use the language-model config, as create_local_shard_ does, so nested
                    # configs (e.g. a text_config) resolve hidden_size the same way.
                    DummyWTE(lm_config.hidden_size, dtype=runtime.dtype),
                )
        else:
            mprint("Loading state_dict in main process")
            state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None

            mprint("Distributing model to shards")
            load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict)
            del state_dict

        descriptor.init_rotary_embedding(model_shard, runtime)

        model_shard.type(runtime.dtype)

        # Configure autocast based on model descriptor (some models like Qwen3-VL MoE
        # have dtype bugs under autocast)
        runtime.use_autocast = descriptor.uses_autocast()

    params_on_meta_device = [
        param_name
        for param_name, param in model_shard.named_parameters()
        if param.device == torch.device("meta")
    ]
    assert len(params_on_meta_device) == 0, (
        f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}"
    )

    return model_shard


def create_sharded_model(
    runtime,
    descriptor,
    model_config: PretrainedConfig,
    owned_block_indexes: set[int],
    device: str | torch.device | None = "meta",
    dtype: torch.dtype | None = torch.float32,
):
    """Instantiate the model on the meta device and reduce it to this rank's shard.

    The model is always constructed under ``EmptyInitOnDevice(device="meta")``. When
    ``device`` is not the meta device, its state dict is then replaced by uninitialized
    tensors allocated on ``device``.

    Returns:
        The local model shard.
    """
    if isinstance(device, str):
        device = torch.device(device)

    dist.barrier()

    with EmptyInitOnDevice(device="meta", dtype=dtype):
        # Get model class from config.architectures (works for CausalLM, VL models, etc.)
        model_class = _get_model_class_from_config(model_config)
        # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
        if model_class is AutoModelForCausalLM:
            trust_remote_code = descriptor.requires_trust_remote_code()
            model = model_class.from_config(model_config, trust_remote_code=trust_remote_code)
        else:
            model = model_class._from_config(model_config)
        create_local_shard_(
            model=model,
            owned_block_indexes=owned_block_indexes,
            descriptor=descriptor,
            runtime=runtime,
        )

    if device != torch.device("meta"):
        local_shard_state_dict = {
            k: torch.empty_like(v, device=device) for k, v in model.state_dict().items()
        }
        model.load_state_dict(local_shard_state_dict, assign=True)

    return model


def load_state_dict_to_shards(
    model_shard: torch.nn.Module, loaded_state_dict: dict | None = None
) -> None:
    """Distribute a state dict held by the master rank to every rank's shard.

    Every rank reports its state-dict keys to the master, which sends each rank the
    matching subset of ``loaded_state_dict``. Each rank then loads its subset and moves
    the shard to its local CUDA device.

    Args:
        model_shard: This rank's shard; it is moved to the meta device before loading.
        loaded_state_dict: Full state dict; required on the master rank only.
    """
    from modelopt.torch.puzzletron.sewing_kit.utils import (
        distributed_isend_obj,
        distributed_recv_obj,
    )

    model_shard.to("meta")
    local_state_dict_keys = list(model_shard.state_dict().keys())

    if dist.is_master():
        gathered_state_dict_keys = [None] * dist.size()
        torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys)

        assert loaded_state_dict is not None
        loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()}

        works: list[torch.distributed.Work] = []
        for process_id, shard_keys in enumerate(gathered_state_dict_keys[1:], start=1):
            wanted_keys = set(shard_keys)  # set membership instead of a list scan per key
            shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in wanted_keys}
            process_works = distributed_isend_obj(shard_state_dict, process_id)
            works.extend(process_works)

        for work in works:
            work.wait()

        local_keys = set(local_state_dict_keys)
        shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in local_keys}
    else:
        torch.distributed.gather_object(local_state_dict_keys)
        shard_state_dict = distributed_recv_obj()

    print(f"{dist.rank()} loaded state_dict shard")

    missing_keys, unexpected_keys = model_shard.load_state_dict(
        shard_state_dict, strict=False, assign=True
    )
    assert len(unexpected_keys) == 0
    # Only the placeholder parameters of dummy modules may be absent from the checkpoint.
    assert all("dummy_param" in key for key in missing_keys)

    model_shard.cuda(dist.local_rank())

    dist.barrier()


def save_sharded_model(
    model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path
):
    """
    Save this rank's shard next to ``out_path`` and, on the master rank, the weight index.

    out_path is usually output_checkpoint_path / "model.safetensors"

    Each rank writes ``<stem>-<rank+1>-of-<world size><suffix>``; the master rank also
    writes ``<out_path>.index.json`` mapping every key to its shard file name.

    Raises:
        ValueError: If ``model_shard`` is neither a module nor a dict.
    """
    dist.barrier()

    if isinstance(model_shard, torch.nn.Module):
        shard_state_dict = model_shard.state_dict()
    elif isinstance(model_shard, dict):
        shard_state_dict = model_shard
    else:
        raise ValueError(f"Unrecognized model shard type: {type(model_shard)}")

    shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()}
    total_shard_size = sum(
        weight.numel() * weight.element_size() for weight in shard_state_dict.values()
    )

    num_shards = dist.size()
    idx = dist.rank()

    out_path = Path(out_path)
    shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}")

    shard_metadata = {
        "total_shard_size": total_shard_size,
        "shard_keys": list(shard_state_dict.keys()),
        "shard_file": str(shard_file),
    }

    if dist.is_master():
        shard_metadatas = [{} for _ in range(dist.size())]
        torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0)
        total_size = sum(x["total_shard_size"] for x in shard_metadatas)
        metadata = {"total_size": total_size}
        weight_map: dict[str, str] = {}
        # Distinct loop name so this rank's own `shard_metadata` is not overwritten.
        for rank_metadata in shard_metadatas:
            shard_file_name = Path(rank_metadata["shard_file"]).name
            weight_map.update(dict.fromkeys(rank_metadata["shard_keys"], shard_file_name))

        index = {"metadata": metadata, "weight_map": weight_map}
        index_path = Path(str(out_path) + ".index.json")
        index_path.write_text(json.dumps(index, indent=2))

    else:
        torch.distributed.gather_object(shard_metadata, dst=0)

    if out_path.suffix == ".safetensors":
        safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"})
    else:
        torch.save(shard_state_dict, shard_file)

    dist.barrier()


def load_sharded_state_dict(
    model_name_or_path: str | Path,
    keys_to_load: Iterable[str] | None = None,
    device: torch.device | str = "cpu",
) -> dict[str, torch.Tensor]:
    """
    Load tensors from the safetensors shard(s) of a checkpoint onto ``device``.

    keys_to_load: entire state_dict if None, else partial state_dict containing only these keys
    """
    shard_paths = _resolve_shard_paths(model_name_or_path)
    # Materialize once: a one-shot iterable would be exhausted after the first shard, and
    # a set gives constant-time membership tests.
    wanted_keys = None if keys_to_load is None else set(keys_to_load)
    partial_state_dict = {}
    for safetensors_path in shard_paths:
        if wanted_keys is None:
            # Honor `device` here as well, as the partial-load branch already does.
            partial_state_dict.update(safe_load_file(safetensors_path, device=str(device)))
        else:
            with safe_open(safetensors_path, framework="pt", device=str(device)) as f:
                for key in f.keys():  # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
                    if key in wanted_keys:
                        partial_state_dict[key] = f.get_tensor(key)
    return partial_state_dict


def _resolve_shard_paths(model_name_or_path: str | Path) -> list[str]:
    """Return the safetensors file(s) of a checkpoint: the single file, else the indexed shards."""
    try:
        unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME)
        return [unsharded_path]
    except OSError:
        # No single-file checkpoint: resolve the shards listed in the index file.
        index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
        shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path)
        return shard_paths


def is_in_safetensors_format(checkpoint_dir: Path) -> bool:
    """Return True when ``checkpoint_dir`` directly contains at least one ``*.safetensors`` file."""
    return any(checkpoint_dir.glob("*.safetensors"))


# diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py
# new file mode 100644
# index 0000000000..4a300fcd0b
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/tools/validate_model.py
# @@ -0,0 +1,262 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors
"""
Provides a function to validate a model. Runs a model forward pass on a dataset and calculates
the loss, and optionally registers hooks to capture the inputs and the outputs
of pytorch modules that are used for activation scoring for pruning.

TODO: Consider moving this to a separate module dedicated for scoring

Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.
"""

import textwrap
from pathlib import Path
from typing import Type

import torch
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase

import modelopt.torch.utils.distributed as dist
from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import (
    register_activation_hooks,
)
from modelopt.torch.puzzletron.anymodel.model_descriptor import (
    ModelDescriptor,
    ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import Same
from modelopt.torch.puzzletron.tools.logger import aprint, mprint
from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import (
    load_and_shard_model,
    set_submodule,
)
from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader
from modelopt.torch.puzzletron.utils.parsing import (
    simple_parse_args_string,  # noqa: F401 (kept for backwards compat)
)
from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
    HiddenStatesAndLMHead,
    calculate_losses_pipeline,
)

"""
Two goals:
1) Calculate lm loss and token accuracy for a model.
May raise lots of NCCL warnings when it finishes, don't be alarmed.
Can be used to validate a HuggingFace model.
Automatically uses pipeline parallelism via device_map="auto".

2) Register hooks to capture the inputs and the outputs of pytorch modules.
For example, to collect activations scores for various layers (ffn, layer_norm, etc.)
that are used for pruning (ffn_hidden_size, embedding_pruning, etc).
See activations_log_dir and activation_hooks_kwargs arguments.
"""


def _resolve_autocast_dtype(value) -> torch.dtype:
    """Turn ``"torch.bfloat16"``, ``"bfloat16"`` or ``torch.bfloat16`` into a ``torch.dtype``.

    ``None`` selects ``torch.bfloat16``. Only a leading ``"torch."`` prefix is removed;
    ``str.strip("torch.")`` strips a character set from both ends and turns e.g.
    ``"torch.float"`` into ``"floa"``.
    """
    if value is None:
        return torch.bfloat16
    if isinstance(value, torch.dtype):
        return value
    return getattr(torch, str(value).removeprefix("torch."))


@torch.no_grad()
def validate_model(
    args: DictConfig,
    model: PreTrainedModel | None = None,
    tokenizer: PreTrainedTokenizerBase | None = None,
    target_hidden_states_per_batch: list[torch.Tensor] | None = None,
    return_hidden_states: bool = False,
    calculate_full_score_ablations: bool = False,
    val_dataloader: DataLoader | None = None,
) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]:
    """Validate a language model on a dataset by calculating loss and optionally capturing activations.

    Args:
        args: Configuration object containing the following attributes:

            Model Configuration:
            - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. Required unless model is passed directly.
            - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16).
            - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision.

            Dataset Configuration:
            - dataset_path (str): Path to the validation dataset.
            - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified.
            - data_column (str): Column name in dataset containing text data.
            - block_size (int): Maximum sequence length for tokenization.
            - eval_samples (int, optional): Number of samples to evaluate. Uses all if None.
            - val_dataset_name (str): Name of validation dataset split.
            - source_datasets_to_discard (list[str], optional): List of source datasets to exclude.
            - load_dataset_fn (callable, optional): Custom function to load the dataset.

            Data Processing:
            - micro_batch_size (int): Batch size for evaluation.
            - seed (int): Random seed for reproducibility.
            - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None.
            - varlen (bool): Enable variable-length sequences.
            - bos_rate (float): Rate of adding BOS token.
            - fim_rate (float): Fill-in-the-middle rate for code completion tasks.
            - fim_spm_rate (float): SPM-based fill-in-the-middle rate.

            Activation Hooks:
            - activations_log_dir (str, optional): Directory to log activation scores. If provided, hooks will be registered to capture activations.
            - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. If string, comma-separated format: "arg1=val1,arg2=val2".

            Execution Options:
            - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended.
            - write_results (bool): Write validation results to file.

        model: Pre-loaded model. If None, will be loaded from args.model_name_or_path.
        tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args.
        target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation.
        return_hidden_states: Whether to return hidden states from the model.
        calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. False calculates only a small suite for efficiency.
        val_dataloader: Pre-created validation dataloader. If None, will be created from args.

    Returns:
        A tuple containing:
        - losses: Dictionary mapping loss names to loss statistics (avg, per_sample).
        - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None.

        Returns (None, None) if not on master rank.
    """
    # NOTE(review): block nesting was reconstructed from a whitespace-mangled source;
    # confirm the placement of the statements that follow each if/else.
    descriptor = ModelDescriptorFactory.get(args.descriptor)

    if val_dataloader is None:
        val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None

    model = prepare_model(args, descriptor=descriptor, model=model)

    just_model_forward = False
    checkpoint_manager = None
    activation_hooks = None

    if args.activations_log_dir is not None:
        # Computed only where it is consumed, so a plain validation run with
        # eval_samples=None (documented as "use all samples") no longer fails on `//`.
        validation_full_iters = (
            args.eval_samples // args.micro_batch_size
        )  # model pipeline, single data rank
        # NOTE(review): the docstring also allows a comma-separated string here, which
        # the item assignment below does not support — confirm callers always pass a dict.
        activation_hooks_kwargs = args.activation_hooks_kwargs or {}
        activation_hooks_kwargs["validation_full_iters"] = validation_full_iters
        hook_class = args.hook_class

        # Create activation hooks using pruning mixin
        activation_hooks = register_activation_hooks(
            model=model,
            activation_hooks_kwargs=activation_hooks_kwargs,
            hook_class=hook_class,
            pruning_mixin=args.pruning_mixin,
        )

        # Create checkpoint manager with hooks
        from modelopt.torch.puzzletron.utils.checkpoint_manager import ScoringCheckpointManager

        mprint(
            f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}"
        )
        checkpoint_manager = ScoringCheckpointManager(
            checkpoint_dir=args.activations_log_dir,
            activation_hooks=activation_hooks,
            checkpoint_interval=50,  # Save every 50 batches
        )

        # Load existing checkpoint if available
        mprint("Attempting to load existing checkpoint...")
        checkpoint_data = checkpoint_manager.load_checkpoint()
        if checkpoint_data:
            mprint(f"Checkpoint loaded successfully: {checkpoint_data}")
        else:
            mprint("No checkpoint found, starting fresh")
        just_model_forward = True
        set_submodule(model, descriptor.output_embedding_name(), Same())

    losses, hidden_states_per_batch = calculate_losses_pipeline(
        stitched_model=model,
        dataloader=val_dataloader,
        target_hidden_states_per_batch=target_hidden_states_per_batch,
        return_hidden_states=return_hidden_states,
        calculate_full_score_ablations=calculate_full_score_ablations,
        calc_on_cpu=args.calc_losses_on_cpu,
        just_model_forward=just_model_forward,
        checkpoint_manager=checkpoint_manager,
        autocast_dtype=_resolve_autocast_dtype(getattr(args, "autocast_dtype", None)),
        descriptor=descriptor,
        use_autocast=descriptor.uses_autocast(),
    )

    if losses is not None:
        avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()}

        results_str = f"""
        validate_model:
        {args.model_name_or_path=}
        Average losses = {avg_losses}
        Actual num samples = {len(next(iter(losses.values()))["per_sample"])}
        {args=}
        """
        results_str = textwrap.dedent(results_str)
        aprint(results_str)
        if args.write_results:
            Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str)

    if activation_hooks is not None:
        hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args)

    return losses, hidden_states_per_batch


def prepare_model(
    args: DictConfig,
    descriptor: Type[ModelDescriptor],
    model: PreTrainedModel | None = None,
) -> nn.Module:
    """Return ``model`` in eval mode, loading and sharding it from ``args.model_name_or_path`` when it is ``None``."""
    if model is None:
        assert args.model_name_or_path is not None
        model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path)

    model.eval()
    return model


def prepare_dataloader(
    args: DictConfig, tokenizer: PreTrainedTokenizerBase | None = None
) -> DataLoader:
    """Create the validation dataloader described by ``args``.

    When ``tokenizer`` is ``None`` it is loaded from ``args.tokenizer_name``, falling
    back to ``args.model_name_or_path``.
    """
    if tokenizer is None:
        tokenizer_name = getattr(args, "tokenizer_name", None)
        assert (tokenizer_name is not None) or (args.model_name_or_path is not None)
        # NOTE(review): trust_remote_code is hard-coded to True here, which executes code
        # shipped with the tokenizer repository; consider driving it from the descriptor.
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name or args.model_name_or_path, trust_remote_code=True
        )

    val_dataloader = create_validation_dataloader(
        accelerator=None,
        seed=args.seed,
        tokenizer=tokenizer,
        block_size=args.block_size,
        dataset=args.dataset_path,
        content_field=args.data_column,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        micro_batch_size=args.micro_batch_size,
        eval_samples=args.eval_samples,
        dataset_name=args.val_dataset_name,
        source_datasets_to_discard=args.source_datasets_to_discard,
        bos_rate=args.bos_rate,
        varlen=args.varlen,
        shuffle_seed=args.shuffle_seed,
        load_dataset_fn=args.load_dataset_fn,
    )

    return val_dataloader


# diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
# new file mode 100644
# index 0000000000..f647cd3f89
# --- /dev/null
# +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
# @@ -0,0 +1,290 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +TODO: Consider moving this a separate module dedicated for scoring +""" + +# mypy: ignore-errors + +import json +import shutil +import warnings +from functools import partial +from pathlib import Path +from typing import Optional + +import torch +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.converter import Converter +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.replacement_library.replacement_library import ReplacementLibrary +from modelopt.torch.puzzletron.replacement_library.replacement_utils import parse_layer_replacement +from modelopt.torch.puzzletron.tools import validate_model +from modelopt.torch.puzzletron.tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + copy_tokenizer, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.tools.validation_utils import ( + validate_model_and_extract_hidden_states, + validate_model_with_teacher_similarity_metrics, +) +from modelopt.torch.puzzletron.utils.parsing import get_nested_key, parse_path +from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import perform_pipeline_stitches + +""" +Usage Example: +============== + +Validate single_block_replacement_solutions by calling validate_puzzle_solutions() directly +with an args object containing the required attributes. See the function docstring for details. + +""" + + +@torch.no_grad() +def validate_puzzle_solutions(args: DictConfig) -> None: + """Validate puzzle solutions by applying layer replacements and evaluating model performance. 
+ + Args: + args: Configuration object containing the following attributes: + + Puzzle Configuration (Required): + - replacement_library_path (Path): Path to the replacement library JSON file. + - solutions_path (Path): Path to puzzle solutions JSON file or directory containing solution files. + - solutions_to_validate (list[int], optional): Indices of specific solutions to validate. Validates all solutions if None. + - sort_solutions_by (str, optional): JSON field path to sort solutions by before validation. + - bigger_is_better (bool): If True, sort solutions in descending order. Used with sort_solutions_by. + - skip_validation (bool): If True, skip model validation and only save models if requested. + - save_models (bool): If True, save realized model checkpoints for each solution. + + Teacher/Tokenizer Configuration: + - teacher_dir (Path, optional): Path to teacher model directory. Auto-inferred if not provided. + - tokenizer_name (str, optional): Tokenizer name/path. Uses teacher_dir if not specified. + + Model Configuration (Required if skip_validation=False): + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration (Required if skip_validation=False): + - dataset_path (str): Path to the validation dataset. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing (Required if skip_validation=False): + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. 
+ - shuffle_seed (int, optional): Seed for shuffling data. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Output Configuration: + - output_dir (Path, optional): Directory to save validation results. Auto-generated from solutions_path if not provided. + + Execution Options (Optional if skip_validation=False): + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. + - write_results (bool): Write validation results to file. + - activations_log_dir (str, optional): Directory to log activation scores. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + + Returns: + None. Saves validation results and optionally model checkpoints to disk. + """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + + puzzle_solutions = load_puzzle_solutions( + args.solutions_path, args.sort_solutions_by, args.bigger_is_better + ) + if args.solutions_to_validate is None: + args.solutions_to_validate = list(range(len(puzzle_solutions))) + puzzle_solutions = [puzzle_solutions[i] for i in args.solutions_to_validate] + + tokenizer = _load_tokenizer(args, trust_remote_code=descriptor.requires_trust_remote_code()) + if not args.skip_validation: + val_dataloader = ( + validate_model.prepare_dataloader(args, tokenizer) if dist.is_master() else None + ) + + output_dir = ( + args.output_dir + if getattr(args, "output_dir", None) is not None + else args.solutions_path.with_name(f"{args.solutions_path.stem}--validation") + ) + + replacement_library = ReplacementLibrary( + args.replacement_library_path, + descriptor=descriptor, + model_config_overrides={"use_cache": False}, + ) + + teacher_hidden_states = None + if (args.teacher_dir is not None) and (not args.skip_validation): + teacher_model = load_and_shard_model( + checkpoint_path=args.teacher_dir, 
descriptor=descriptor + ) + teacher_model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(teacher_model, descriptor=descriptor) + teacher_hidden_states = validate_model_and_extract_hidden_states( + args, + stitched_model, + tokenizer, + output_dir, + model_name="teacher", + val_dataloader=val_dataloader, + ) + + # Properly release CUDA memory after teacher validation + teacher_model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + dist.barrier() + + for i_solution, puzzle_solution in tqdm( + list(zip(args.solutions_to_validate, puzzle_solutions)), desc="Validating solutions" + ): + layer_replacements = _extract_layer_replacements_from_puzzle_solution(puzzle_solution) + realizable_as_symlinks = can_realize_as_symlinks(layer_replacements) + # realizable_as_symlinks = False + model_config = replacement_library.create_model_config(layer_replacements) + if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation): + model = replacement_library.load_model(layer_replacements) + model_config = model.config + + if args.save_models: + checkpoint_dir = ( + args.solutions_path.with_name(f"{args.solutions_path.stem}--checkpoints") + / f"solution_{i_solution}" + ) + + model_config.dtype = getattr(args, "model_dtype", "torch.bfloat16") + Converter.copy_checkpoint_files(args.teacher_dir, checkpoint_dir) + if realizable_as_symlinks: + if dist.is_master(): + # TODO: Loo into internal Puzzleron code to see how to save as symlinks + # save_checkpoint_as_symlinks is currently not supported + pass + save_checkpoint(model, checkpoint_dir, descriptor) + + copy_tokenizer( + args.tokenizer_name, + checkpoint_dir, + trust_remote_code=descriptor.requires_trust_remote_code(), + ) + + dist.barrier() + + if not args.skip_validation: + model.cuda(dist.local_rank()) + stitched_model = perform_pipeline_stitches(model, descriptor=descriptor) + validate_model_with_teacher_similarity_metrics( + args, + stitched_model, + 
tokenizer, + teacher_hidden_states, + output_dir, + model_name=f"solution_{i_solution}", + extra_payload={"i_solution": i_solution, "puzzle_solution": puzzle_solution}, + val_dataloader=val_dataloader, + ) + + # Properly release CUDA memory after solution validation + model.cpu() + stitched_model.cpu() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + dist.barrier() + + +def can_realize_as_symlinks(layer_replacements: list[dict]) -> bool: + for layer_replacement in layer_replacements: + num_parent_layers = len(layer_replacement["parent_layer_indices"]) + num_child_layers = len(layer_replacement["child_block_configs"]) + if num_parent_layers != num_child_layers or num_parent_layers != 1: + return False + return True + + +def _load_tokenizer(args: DictConfig, trust_remote_code: bool = False) -> PreTrainedTokenizerBase: + tokenizer = None + if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code + ) + elif args.teacher_dir is not None: + try: + tokenizer = AutoTokenizer.from_pretrained( + args.teacher_dir, trust_remote_code=trust_remote_code + ) + except Exception: + pass + if tokenizer is None: + warnings.warn("Couldn't find a tokenizer, trying to continue without one") + return tokenizer + + +def _extract_layer_replacements_from_puzzle_solution( + puzzle_solution: dict, +) -> list[dict]: + puzzle_solution = puzzle_solution.get("puzzle_solution", puzzle_solution) + layer_replacements = [ + parse_layer_replacement(rep) for rep in puzzle_solution["chosen_replacements"] + ] + return layer_replacements + + +def load_puzzle_solutions( + solutions_path: Path, + sort_solutions_by: Optional[str], + bigger_is_better: bool, +) -> list[dict]: + assert solutions_path.exists(), f"{solutions_path=} does not exist" + + if solutions_path.is_file(): + puzzle_solutions = json.loads(solutions_path.read_text()) + if isinstance(puzzle_solutions, dict): + 
puzzle_solutions = [puzzle_solutions] + else: + puzzle_solutions = [ + json.loads(p.read_text()) for p in solutions_path.glob("*solution*.json") + ] + + if len(puzzle_solutions) == 0: + raise ValueError(f"No solutions under {solutions_path=}") + + if sort_solutions_by is not None: + puzzle_solutions = sorted( + puzzle_solutions, key=partial(get_nested_key, field=sort_solutions_by) + ) + if bigger_is_better: + puzzle_solutions = puzzle_solutions[::-1] + vals = [get_nested_key(sol, sort_solutions_by) for sol in puzzle_solutions] + print(f"sorted solutions by {sort_solutions_by}. {vals[:10]=} {vals[-10:]=}") + + return puzzle_solutions diff --git a/modelopt/torch/puzzletron/tools/validation_utils.py b/modelopt/torch/puzzletron/tools/validation_utils.py new file mode 100644 index 0000000000..d7197e8abf --- /dev/null +++ b/modelopt/torch/puzzletron/tools/validation_utils.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for validating models and extracting hidden states and similarity metrics. + +TODO: Consider moving this a separate module dedicated for scoring. 
+""" + +# mypy: ignore-errors + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from torch import nn +from transformers import PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools import validate_model +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.puzzletron.utils.validation import LowMemorySparseTensor + +if TYPE_CHECKING: + from modelopt.torch.puzzletron.sewing_kit import StitchedModule + + +def validate_model_and_extract_hidden_states( + args: DictConfig, + model: "nn.Module | StitchedModule", + tokenizer: PreTrainedTokenizerBase, + output_dir: str | Path, + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + val_dataloader=None, +) -> list[torch.Tensor | LowMemorySparseTensor]: + mprint(f""" + +################################################################ +validate_model_and_extract_token_probs({model_name=}) +################################################################ + +""") + losses, hidden_states_per_batch = validate_model.validate_model( + args, + model, + tokenizer, + return_hidden_states=True, + val_dataloader=val_dataloader, + ) + if dist.is_last_process(): + output_dir = output_dir if (output_dir is not None) else args.bypass_dir + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + return hidden_states_per_batch + + +def validate_model_with_teacher_similarity_metrics( + args: DictConfig, + model: "nn.Module | StitchedModule", + tokenizer: PreTrainedTokenizerBase, + target_hidden_states_per_batch: list[torch.Tensor], + output_dir: str | Path, + model_name: str, + extra_payload: Optional[dict[str, Any]] = None, + calculate_full_score_ablations: bool = False, + val_dataloader=None, 
+) -> None: + is_calc_kl_div = target_hidden_states_per_batch is not None + mprint(f""" + +################################################################ +validate_model_with_kl_div({model_name=}, {is_calc_kl_div=}) +################################################################ + +""") + losses, _ = validate_model.validate_model( + args, + model, + tokenizer, + target_hidden_states_per_batch=target_hidden_states_per_batch, + calculate_full_score_ablations=calculate_full_score_ablations, + val_dataloader=val_dataloader, + ) + if dist.is_last_process(): + extra_payload = extra_payload if (extra_payload is not None) else dict() + write_results(output_dir, model_name, args, {**losses, **extra_payload}) + + +def write_results( + output_dir: str | Path, result_name: str, args: DictConfig, payload: dict[str, Any] +) -> None: + output_path = Path(output_dir) / f"{result_name}.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + results = { + **payload, + "args": OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else args.__dict__, + } + json_dump(results, output_path) diff --git a/modelopt/torch/puzzletron/utils/checkpoint_manager.py b/modelopt/torch/puzzletron/utils/checkpoint_manager.py new file mode 100644 index 0000000000..a1347deaea --- /dev/null +++ b/modelopt/torch/puzzletron/utils/checkpoint_manager.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint manager for activation hook scoring with periodic saves and resume support.""" + +import json +import time +from pathlib import Path +from typing import Any + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.logger import aprint, mprint + + +class ScoringCheckpointManager: + """Manages checkpointing for activation hook scoring with periodic saves.""" + + def __init__(self, checkpoint_dir: str, activation_hooks=None, checkpoint_interval: int = 100): + """Initialize checkpoint manager. + + Args: + checkpoint_dir: Directory to save checkpoints + activation_hooks: Dictionary of activation hooks to manage + checkpoint_interval: Save checkpoint every N batches + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.activation_hooks = activation_hooks + self.checkpoint_interval = checkpoint_interval + self.rank = dist.rank() + self.is_main_process = dist.is_master() + + # Debug: Log checkpoint manager initialization + hook_count = len(activation_hooks) if activation_hooks else 0 + aprint( + f"[Rank {self.rank}] Checkpoint manager initialized: {hook_count} hooks, dir: {checkpoint_dir}" + ) + + # Checkpoint files + self.progress_file = self.checkpoint_dir / "scoring_progress.json" + self.hook_states_file = self.checkpoint_dir / f"hook_states_rank_{self.rank}.pth" + + # Progress tracking + self.current_batch = 0 + self.total_batches = 0 + self.start_time = time.time() + + # Ensure directory exists + if self.is_main_process: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def load_checkpoint(self) -> dict[str, Any] | None: + """Load existing checkpoint if available, including hook states. 
+ + Returns: + Dict with checkpoint info or None if no checkpoint exists + """ + aprint(f"[Rank {self.rank}] Looking for checkpoint at: {self.progress_file}") + if not self.progress_file.exists(): + aprint(f"[Rank {self.rank}] No checkpoint file found at {self.progress_file}") + return None + + try: + with open(self.progress_file) as f: + checkpoint_data = json.load(f) + + # Validate checkpoint + if "current_batch" in checkpoint_data and "total_batches" in checkpoint_data: + self.current_batch = checkpoint_data["current_batch"] + self.total_batches = checkpoint_data["total_batches"] + + mprint( + f"Found checkpoint: batch {self.current_batch}/{self.total_batches} ({checkpoint_data.get('progress', 0.0):.1%})" + ) + mprint( + f"Will resume from batch {self.current_batch}, skipping batches 0-{self.current_batch - 1}" + ) + + # Load hook states if hooks are available + if self.activation_hooks is not None: + success = self.load_hook_states(self.activation_hooks) + if success: + aprint( + f"[Rank {self.rank}] Successfully loaded hook states from checkpoint" + ) + else: + aprint(f"[Rank {self.rank}] Failed to load hook states - starting fresh") + + return checkpoint_data + else: + aprint( + f"[Rank {self.rank}] Invalid checkpoint format (missing current_batch/total_batches): {checkpoint_data}" + ) + return None + + except (json.JSONDecodeError, KeyError) as e: + mprint(f"Error loading checkpoint: {e}") + + return None + + def load_hook_states(self, activation_hooks) -> bool: + """Load hook states from checkpoint files. 
+ + Args: + activation_hooks: Hook objects to load states into + + Returns: + bool: True if hook states were successfully loaded, False otherwise + """ + import os + + # Each rank loads only its own hook states + current_rank = int(os.environ.get("RANK", 0)) + hook_states_path = self.checkpoint_dir / f"hook_states_rank_{current_rank}.pth" + + if hook_states_path.exists(): + aprint(f"[Rank {current_rank}] Loading hook states from {hook_states_path}") + try: + import torch + + hook_states = torch.load(hook_states_path, map_location="cpu") + + # Load states into corresponding hooks + loaded_count = 0 + for module_name, hook in activation_hooks.items(): + if module_name in hook_states: + hook.load_state_dict(hook_states[module_name]) + loaded_count += 1 + + # Log progress info if available (only for a few hooks to avoid spam) + if loaded_count <= 3: # Only log first few hooks + progress_info = hook.get_progress_info() + if progress_info: + aprint(f"[Rank {current_rank}] {module_name}: {progress_info}") + else: + aprint( + f"[Rank {current_rank}] Warning: No saved state found for hook: {module_name}" + ) + + aprint( + f"[Rank {current_rank}] Successfully loaded states for {loaded_count}/{len(activation_hooks)} hooks" + ) + return True + + except Exception as e: + aprint(f"[Rank {current_rank}] Error loading hook states: {e}") + return False + else: + aprint(f"[Rank {current_rank}] No hook states file found at {hook_states_path}") + return False + + def should_skip_batch(self, batch_idx: int) -> bool: + """Check if we should skip this batch (already processed in previous run).""" + should_skip = batch_idx < self.current_batch + if should_skip and batch_idx % 10 == 0: # Log every 10th skipped batch to avoid spam + mprint(f"Skipping batch {batch_idx} (resume from batch {self.current_batch})") + return should_skip + + def update_progress(self, batch_idx: int, total_batches: int): + """Update progress and potentially save checkpoint. 
+ + Args: + batch_idx: Current batch index + total_batches: Total number of batches + """ + self.current_batch = batch_idx + self.total_batches = total_batches + + # Save checkpoint periodically or on completion + should_save = ( + (batch_idx % self.checkpoint_interval == 0) # Periodic save + or (batch_idx == total_batches - 1) # Final batch + ) + + if should_save: + # All ranks save their hook states + if self.activation_hooks is not None: + try: + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook + + ForwardHook.save_hook_states(self.activation_hooks, self.checkpoint_dir) + except Exception as e: + mprint(f"Warning: Failed to save hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.save_checkpoint() + + # Synchronize all ranks after checkpointing + dist.barrier() + + def save_checkpoint(self): + """Save current checkpoint to disk (progress info only). + Hook states are saved separately in update_progress. + """ + try: + # Save progress + progress_data = { + "current_batch": self.current_batch, + "total_batches": self.total_batches, + "progress": self.current_batch / self.total_batches + if self.total_batches > 0 + else 0.0, + "timestamp": time.time(), + "elapsed_time": time.time() - self.start_time, + "rank": self.rank, + } + + # Write progress atomically + temp_file = self.progress_file.with_suffix(".tmp") + with open(temp_file, "w") as f: + json.dump(progress_data, f, indent=2) + temp_file.replace(self.progress_file) + + # Hook states are saved at a higher level to ensure all ranks participate + + if self.current_batch % (self.checkpoint_interval) == 0: + progress_pct = progress_data["progress"] * 100 + elapsed = progress_data["elapsed_time"] + mprint( + f"Checkpoint saved: batch {self.current_batch}/{self.total_batches} ({progress_pct:.1f}%), elapsed: {elapsed:.1f}s" + ) + + except Exception as e: + mprint(f"Error saving checkpoint: {e}") + + def finalize(self): + """Mark scoring as 
completed.""" + # All ranks save their final hook states + if self.activation_hooks is not None: + try: + from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook + + saved_path = ForwardHook.save_hook_states( + self.activation_hooks, self.checkpoint_dir + ) + mprint(f"Final hook states saved to {saved_path}") + except Exception as e: + mprint(f"Warning: Failed to save final hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.current_batch = self.total_batches + self.save_checkpoint() + mprint(f"Scoring completed and finalized: {self.total_batches} batches processed") + + # Synchronize all ranks after finalization + dist.barrier() diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py new file mode 100644 index 0000000000..892d1f3c2c --- /dev/null +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

"""DataLoader utilities for language model training and validation."""

from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import Protocol, TypeVar

import datasets
import torch
import torch.distributed
from accelerate import Accelerator
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data._utils.collate import collate, default_collate_fn_map
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from modelopt.torch.puzzletron.tools.logger import mprint
from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset


def collate_none_fn(
    batch, *, collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None
) -> None:
    """Collate handler registered for ``type(None)``: ignores its arguments and returns None."""
    return None


# The default collate map extended with an entry for ``type(None)``.
collate_fn_map_with_none_support = {**default_collate_fn_map, type(None): collate_none_fn}
# ``collate`` pre-bound to the extended map.
collate_fn_with_none_support = partial(collate, collate_fn_map=collate_fn_map_with_none_support)


class LoadDatasetFn(Protocol):
    """Structural type for dataset loader callables.

    A loader is called with ``dataset_path``, ``content_field`` and an optional
    ``keep_in_memory`` flag and returns a mapping from string keys to datasets.
    """

    def __call__(
        self, dataset_path: str, content_field: str, keep_in_memory: bool = False
    ) -> Mapping[str, Dataset]: ...
+ + +def load_from_disk_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + + +def load_streaming_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + dataset = datasets.load_dataset( + dataset_path, + streaming=True, + features=datasets.Features( + { + content_field: datasets.Value(dtype="string"), + } + ), + keep_in_memory=keep_in_memory, + ) + + return dataset + + +def create_validation_dataloader( + accelerator: Accelerator | None, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "__auto__", + keep_in_memory: bool = False, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, + shuffle_seed: int | None = None, +): + if accelerator is None: + accelerator = Printer() + + if accelerator.is_main_process: + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): + valid_data = dataset + mprint( + "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" + ) + else: + assert isinstance(dataset, datasets.DatasetDict) + if dataset_name == "__auto__": + val_split_options = [] + for val_key_prefix in ("val", "test"): + if len(val_split_options) == 0: + val_split_options = [ + split + for split in dataset # DatasetDict is dict-like and supports direct iteration + if split.lower().startswith(val_key_prefix) + ] + assert len(val_split_options) == 1, ( + f"Expected exactly one validation split, got {val_split_options=} ({dataset.keys()=})" + ) + val_split = 
val_split_options[0] + mprint(f"Inferred validation split automatically: '{val_split}'") + else: + val_split = dataset_name + mprint(f"Validation split explicitly chosen: '{val_split}'") + valid_data = dataset[val_split] + + if shuffle_seed is not None: + mprint(f"Shuffling with {shuffle_seed=}") + valid_data = valid_data.shuffle(seed=shuffle_seed) + + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + infinite=False, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + if varlen and eval_samples is not None: + eval_samples = eval_samples // micro_batch_size + val_offloaded_dataset = realize_dataset_in_memory(valid_dataset, eval_samples) + + valid_data_len = len(val_offloaded_dataset) + mprint(f"num validation examples = {valid_data_len}") + else: + val_offloaded_dataset = None + + if not isinstance(accelerator, Printer): + obj_list = [val_offloaded_dataset] + torch.distributed.broadcast_object_list(obj_list) + val_offloaded_dataset = obj_list[0] + + # let accelerate prepare to handle distributed sampling + val_dataloader = DataLoader( + val_offloaded_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + ) + + return val_dataloader + + +def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None) -> list[dict]: + tqdm_desc = f"realize_dataset_in_memory({eval_samples=})" + if eval_samples is None: + offloaded_dataset = list(tqdm(dataset, desc=tqdm_desc)) + else: + val_iter = iter(dataset) + offloaded_dataset = [next(val_iter) for _ in tqdm(range(eval_samples), desc=tqdm_desc)] + return offloaded_dataset + + +TensorT = TypeVar("TensorT", bound=torch.Tensor) + + +@torch.no_grad() +def 
create_padded_tensor( + tensor: TensorT, desired_shape: Sequence[int], padding_value: float = 0 +) -> TensorT: + if tensor.shape == torch.Size(desired_shape): + return tensor + + padded_tensor = torch.full( + desired_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device + ) + indices = torch.where(torch.ones_like(tensor, dtype=torch.bool)) + padded_tensor[indices] = tensor.view(-1) + return padded_tensor + + +class Printer: + is_main_process = True + process_index = None + + @staticmethod + def print(*args, **kwargs) -> None: + print(*args, **kwargs) diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py new file mode 100644 index 0000000000..fffc2a3a1d --- /dev/null +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import functools +from collections.abc import Sequence + +import numpy as np +import torch +from torch.utils.data import IterableDataset + +FIM_TOKEN_START = "", "middle>", "suffix>", "pad>"] +CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] + + +class ConstantLengthDataset(IterableDataset): + """Iterable dataset that returns constant length chunks of tokens from stream of text files. 
+ + Args: + tokenizer (Tokenizer): The processor used for proccessing the data. + dataset (dataset.Dataset): Dataset with text files. + infinite (bool): If True the iterator is reset after dataset reaches end else stops. + seq_length (int): Length of token sequences to return. + num_of_sequences (int): Number of token sequences to keep in buffer. + chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. + fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM. + fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM. + seed (int): Seed for random number generator. + label_shift (bool): Whether to shift labels by 1 or not. + """ + + def __init__( + self, + tokenizer, + dataset, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + content_field="content", + fim_rate=0.5, + fim_spm_rate=0.5, + seed=0, + label_shift=True, + max_sample_length=200_000, + tokens_field="token_ids", + source_datasets_to_discard: Sequence[str] | None = tuple(), + bos_rate: float = 1.0, + return_cu_seqlens: bool = False, + seqlen_cap: int | None = None, + ): + self.tokenizer = tokenizer + self.concat_token_id = tokenizer.eos_token_id + # self.concat_token_id = tokenizer.eos_id # for lit-lamma tokenizer + self.dataset = dataset + self.is_dataset_already_tokenized = tokens_field in self.dataset.column_names + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + if not self.is_dataset_already_tokenized: + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.max_sample_length = max_sample_length + else: + self.max_buffer_size = seq_length * num_of_sequences + # self.max_sample_length = int(max_sample_length / chars_per_token) + self.max_sample_length = max_sample_length # we don't know the exact chars_per_token + self.content_field = content_field + self.tokens_field = tokens_field + self.fim_rate = fim_rate + 
self.fim_spm_rate = fim_spm_rate + self.seed = seed + self.max_sample_length = max_sample_length + + self.fim_token_ids = get_fim_token_ids(self.tokenizer) + if None in self.fim_token_ids.values() and self.fim_rate > 0: + self.fim_rate = 0 + self.label_shift = label_shift + self.bos_rate = bos_rate + self.source_datasets_to_discard = ( + source_datasets_to_discard if source_datasets_to_discard is not None else tuple() + ) + self.return_cu_seqlens = return_cu_seqlens + self.seqlen_cap = seqlen_cap + self.np_rng = np.random.RandomState(seed=self.seed) + + def __iter__(self) -> dict[str, torch.Tensor]: + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + sample = next(iterator) + if ( + len(self.source_datasets_to_discard) > 0 + and sample["dataset_name"] in self.source_datasets_to_discard + ): + continue + if not self.is_dataset_already_tokenized: + sample = sample[self.content_field] + if ( + isinstance(sample, list) + and isinstance(sample[0], dict) + and {"content", "role"}.issubset(sample[0]) + ): + if len(sample) > 1: + sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + else: + sample = sample[0]["content"] + else: + sample = sample[self.tokens_field] + sample = sample[: self.max_sample_length] + buffer.append(sample) + buffer_len += len(sample) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + else: + more_examples = False + break + + if not self.is_dataset_already_tokenized: + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + else: + tokenized_inputs = buffer + + all_token_ids = [] + + for tokenized_input in tokenized_inputs: + if ( + self.bos_rate < 1.0 + and not self.np_rng.binomial(1, self.bos_rate) + and self.tokenizer.bos_token_id is not None + and tokenized_input[0] == self.tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] + # optionally do 
FIM permutations + if self.fim_rate > 0: + tokenized_input, np_rng = permute( + sample=tokenized_input, + np_rng=self.np_rng, + fim_token_ids=self.fim_token_ids, + fim_rate=self.fim_rate, + fim_spm_rate=self.fim_spm_rate, + truncate_or_pad=False, + ) + + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + + examples = [] + # cuts code snippets in the middle to yield constant length instances + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + labels = all_token_ids[ + i + int(self.label_shift) : i + int(self.label_shift) + self.seq_length + ] + # ignores last short example in the buffer + if len(labels) == self.seq_length: + examples.append((input_ids, labels)) + + shuffling_indices = self.np_rng.permutation(len(examples)) + examples = [examples[i] for i in shuffling_indices] + + for input_ids, labels in examples: + self.current_size += 1 + input_ids = torch.LongTensor(input_ids) + if self.return_cu_seqlens: + cu_seqlens = self.prepare_cu_seqlens(input_ids) + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + "cu_seqlens": cu_seqlens, + } + else: + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + } + + def prepare_cu_seqlens(self, input_ids): + if not self.return_cu_seqlens: + return None + # seqlens is of shape (num_seqs+1,) and with the property that + # the i-th sequnce is input_ids[seqlens[i-1]:seqlens[i]] + cu_seqlens = (input_ids == self.concat_token_id).nonzero().squeeze(-1).int() + 1 + cu_seqlens = torch.cat( + ( + torch.IntTensor([0]), + cu_seqlens, + torch.IntTensor([len(input_ids)]), + ) + ) + if self.seqlen_cap is not None: + i = 1 + while i < len(cu_seqlens): + curr_seqlen = cu_seqlens[i] - cu_seqlens[i - 1] + if curr_seqlen > self.seqlen_cap: + cu_seqlens = torch.cat( + (cu_seqlens[:i], cu_seqlens[[i - 1]] + self.seqlen_cap, cu_seqlens[i:]) + ) + i += 1 + if cu_seqlens[-1] == cu_seqlens[-2]: + cu_seqlens = cu_seqlens[:-1] + 
return cu_seqlens + + +## Adapted from https://github.com/NVIDIA/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py +def permute( + sample, + np_rng, + fim_token_ids, + fim_rate=0.5, + fim_spm_rate=0.5, + truncate_or_pad=False, +): + """Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes: + PSM and SPM (with a probability of fim_spm_rate). + """ + if np_rng.binomial(1, fim_rate): + boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2)) + boundaries.sort() + + prefix = np.array(sample[: boundaries[0]], dtype=np.int64) + middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64) + suffix = np.array(sample[boundaries[1] :], dtype=np.int64) + + if truncate_or_pad: + raise NotImplementedError + + if "" in fim_token_ids: # use codegen FIM pattern + assert fim_spm_rate == 0 + new_sample = np.concatenate( + [ + prefix, + [fim_token_ids[""]], + suffix, + [fim_token_ids["<|endoftext|>"]], + [fim_token_ids[""]], + [fim_token_ids[""]], + middle, + ] + ) + elif np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"], fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + prefix, + middle, + ] + ) + else: + # PSM + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"]], + prefix, + [fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + middle, + ] + ) + else: + # don't do FIM preproc + new_sample = sample + + return list(new_sample), np_rng + + +# this is expensive so we cache it +@functools.lru_cache(maxsize=None) +def get_fim_token_ids(tokenizer): + # ugly fix for Salesforce/codegen25-7b-multi tokenizer + if hasattr(tokenizer, "encoder"): + search_vocab = tokenizer.encoder._special_tokens + fim_token_ids = {tok: search_vocab.get(tok, None) for tok in CODEGEN_FIM_TOKENS} + else: + search_vocab = 
tokenizer.vocab + if (FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + FIM_TOKEN_END_LIST[0]) in search_vocab: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + else: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_SANTA + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + fim_token_ids = { + "suffix_tok_id": suffix_tok_id, + "prefix_tok_id": prefix_tok_id, + "middle_tok_id": middle_tok_id, + "pad_tok_id": pad_tok_id, + } + return fim_token_ids diff --git a/modelopt/torch/puzzletron/utils/dummy_modules.py b/modelopt/torch/puzzletron/utils/dummy_modules.py new file mode 100644 index 0000000000..c9eaa2bc6c --- /dev/null +++ b/modelopt/torch/puzzletron/utils/dummy_modules.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class DummyModule(nn.Module):
    """Base class for parameter-free placeholder modules.

    On construction it registers ``load_state_dict_post_hook`` as a
    load-state-dict post-hook, which empties the missing/unexpected key lists
    that ``load_state_dict`` collects for this module.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        hook = self.load_state_dict_post_hook
        self.register_load_state_dict_post_hook(hook)

    @staticmethod
    def load_state_dict_post_hook(
        module: torch.nn.Module,
        incompatible_keys: torch.nn.modules.module._IncompatibleKeys,
    ) -> None:
        """Clear both key lists of ``incompatible_keys`` in place.

        Args:
            module: The module the state dict was loaded into (unused).
            incompatible_keys: Object exposing ``missing_keys`` and
                ``unexpected_keys`` lists; both are emptied.
        """
        for key_list in (incompatible_keys.missing_keys, incompatible_keys.unexpected_keys):
            key_list.clear()


class DummyBlock(DummyModule):
    """Placeholder layer that returns its first input unchanged."""

    def __init__(self, block_index: int):
        super().__init__()
        # Position of the block this dummy stands in for.
        self.block_index = block_index

    @override
    def forward(
        self,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, None]:
        """Return ``x`` itself; every extra argument is ignored."""
        return x


class DummyWTE(DummyModule):
    """Placeholder embedding that emits an all-ones tensor."""

    def __init__(self, hidden_size: int, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.n_embd = hidden_size
        self.dtype = dtype

    @override
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ones shaped ``input_ids.shape + (n_embd,)``.

        ``input_ids`` must be 2-D (the shape is unpacked into two values).
        The result uses ``self.dtype`` and the device of ``input_ids``.
        """
        batch_size, seq_len = input_ids.shape
        return torch.ones(
            (batch_size, seq_len, self.n_embd),
            dtype=self.dtype,
            device=input_ids.device,
        )


class DummyLMHead(DummyModule):
    """Placeholder LM head that emits an all-ones tensor of vocabulary width."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.vocab_size = config.vocab_size

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ones shaped ``x.shape[:2] + (vocab_size,)``.

        ``x`` must be 3-D (the shape is unpacked into three values). The
        result keeps the dtype and device of ``x``.
        """
        batch_size, seq_len, _hidden = x.shape
        return torch.ones(
            (batch_size, seq_len, self.vocab_size),
            dtype=x.dtype,
            device=x.device,
        )
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parsing and formatting utilities for configuration handling in model compression. + +This module provides utilities for: +- Parsing command-line arguments and configuration strings +- Formatting and displaying model configurations (block configs, attention, FFN) +- Formatting loss metrics for logging and visualization +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any + +import torch +from omegaconf import DictConfig + + +def handle_arg_string(arg): + if arg.lower() == "true": + return True + elif arg.lower() == "false": + return False + elif arg.isnumeric(): + return int(arg) + try: + return float(arg) + except ValueError: + return arg + + +def simple_parse_args_string(args_string): + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = {k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]} + return args_dict + + +def parse_json(s: str | None) -> Any: + if s is None: + return None + return json.loads(s) + + +def parse_path(s: str | None) -> Path | None: + if s is None or s == "": + return None + return Path(s) + + +def parse_dtype(dtype_name: str) -> torch.dtype: + dtype = { + "bf16": torch.bfloat16, + 
"bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + }[dtype_name] + return dtype + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """ + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def format_block_configs(config) -> str: + """ + Formats block_configs from a model configuration into a beautiful, readable string. + + Each line represents a layer with attention and FFN configuration. + + Args: + config: PretrainedConfig object containing block_configs + + Returns: + Formatted string with layer configurations + + Example output: + ╭─────────────────────── Model Architecture ────────────────────────╮ + │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ + │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ + │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + if not hasattr(config, "block_configs") or not config.block_configs: + return "❌ No block configs found" + + lines = [] + + # Header + header = "╭─────────────────────────────────────── Model Architecture ────────────────────────────────────────╮" + lines.append(header) + + # Format each layer + for i, block in enumerate(config.block_configs, 1): + attention_info = _format_attention_config(block.attention) + ffn_info = _format_ffn_config(block.ffn) + + # Create formatted line with proper padding + layer_str = f"Layer {i:2d}" + attention_str = f"Attention: {attention_info}" + ffn_str = f"FFN: {ffn_info}" + + line = f"│ {layer_str:8s} │ {attention_str:30s} │ {ffn_str:18s} │" + lines.append(line) + + # Footer + footer = "╰────────────────────────────────────────────────────────────────────────────────────────────────────╯" + lines.append(footer) + + return "\n".join(lines) + + +def 
_format_attention_config(attention_config) -> str: + """Format attention configuration for display with visual indicators.""" + if not attention_config: + return "default" + + if attention_config.no_op: + return "❌ no_op" + + num_kv_heads = attention_config.num_key_value_heads + if num_kv_heads is not None: + return f"{num_kv_heads} kv heads" + + if attention_config.replace_with_linear: + return "linear replacement" + + # Check for other attention types + if attention_config.mamba: + return "🐍 mamba" + if attention_config.llama4: + return "🦙 llama4" + + window_length = attention_config.window_length + if window_length is not None: + return f"windowed ({window_length})" + + if attention_config.sparsify: + return "sparse" + + return "default" + + +def _format_ffn_config(ffn_config) -> str: + """Format FFN configuration for display with visual indicators.""" + if not ffn_config: + return "default" + + if ffn_config.no_op: + return "❌ no_op" + + if ffn_config.replace_with_linear: + return "linear" + + ffn_intermediate = ffn_config.intermediate_size + if ffn_intermediate is not None: + return f"ffn_intermediate = {ffn_intermediate}" + + # Check for MoE configuration + moe_config = ffn_config.moe + if moe_config: + return "MoE" + + if ffn_config.sparsify: + return "sparse" + + return "default" + + +def format_global_config(config: DictConfig, title: str = "Global Configuration") -> str: + """ + Pretty prints a global DictConfig with nice formatting and visual indicators. 
+ + Args: + config: DictConfig object to format + title: Title to display at the top of the formatted output + + Returns: + Formatted string with configuration details + + Example output: + ╭─────────────────── Global Configuration ────────────────────╮ + │ Training │ + │ • learning_rate: 1e-4 │ + │ • batch_size: 32 │ + │ • epochs: 100 │ + │ Model │ + │ • hidden_dim: 512 │ + │ • num_layers: 6 │ + │ Data │ + │ • dataset_path: /path/to/data │ + │ • block_size: 2048 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not config: + return "❌ No configuration found" + + lines = [] + + # Calculate box width based on title + box_width = max(60, len(title) + 10) + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"\n╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + lines.extend([header, title_line]) + + def _format_value(value: Any, indent: int = 0) -> str: + """Format a value with appropriate type indicators.""" + prefix = " " * indent + + if isinstance(value, (bool, int, float)): + return f"{prefix} {value}" + elif isinstance(value, str): + # Show truncated long strings + if len(value) > 50: + return f"{prefix} {value[:47]}..." 
+ return f"{prefix} {value}" + elif isinstance(value, (list, tuple)): + if not value: + return f"{prefix} []" + elif len(value) <= 3: + return f"{prefix} {list(value)}" + else: + return f"{prefix} [{len(value)} items]" + elif value is None: + return f"{prefix} None" + else: + return f"{prefix} {value!s}" + + def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0): + """Recursively add configuration sections.""" + if section_name: + indent_str = " " * indent + section_line = f"│ {indent_str}{section_name}" + # Pad to box width + padding_needed = box_width - len(section_line) - 1 + section_line += " " * padding_needed + "│" + lines.append(section_line) + + for key, value in cfg.items(): + if isinstance(value, DictConfig): + # Nested configuration section + _add_config_section(value, f"{key}", indent + 1) + else: + # Regular key-value pair + indent_str = " " * (indent + 1) + value_str = _format_value(value).replace(" " * 0, "").strip() + line = f"│ {indent_str} {key}: {value_str}" + # Pad to box width + if len(line) >= box_width - 1: + # Truncate long lines + line = line[: box_width - 4] + "..." + padding_needed = box_width - len(line) - 1 + line += " " * padding_needed + "│" + lines.append(line) + + # Add configuration sections + _add_config_section(config) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) + + +def format_stitched_losses( + losses_dict: dict[str, float], + best_steps_dict: dict[str, int] | None = None, + best_values_dict: dict[str, float] | None = None, + step_number: int | None = None, + title: str = "Stitched Module Losses", +) -> str: + """ + Pretty prints stitched module losses with comprehensive tracking and visual indicators. 
+ + Args: + losses_dict: Dictionary with block names as keys and current loss values as floats + best_steps_dict: Optional dictionary with block names as keys and best step numbers as values + best_values_dict: Optional dictionary with block names as keys and best loss values as floats + step_number: Optional current step number to include in summary + title: Title to display at the top of the formatted output + + Returns: + Formatted string with loss values in a comprehensive table format + + Example output: + ╭─────────────────── Stitched Module Losses ──────────────────╮ + │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ + │───────┼────────────┼───────────┼────────────┼──────────────────│ + │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ + │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ + │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not losses_dict: + return "❌ No losses found" + + lines = [] + + # Calculate statistics + loss_values = list(losses_dict.values()) + max_loss = max(loss_values) + min_loss = min(loss_values) + avg_loss = sum(loss_values) / len(loss_values) + + # Calculate box width for new layout (removed Bar column) + box_width = 74 + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + separator = ( + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " + f"{'Best Value':<12} │ {'Change from avg':<18} │" + ) + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + + lines.extend([header, title_line, separator, divider]) + + # Format each loss + for block_name, loss_value in losses_dict.items(): + # Format current loss value + loss_str = f"{loss_value:.2e}" + + # Format best step + if best_steps_dict and block_name in best_steps_dict: + best_step_str = f"Step 
{best_steps_dict[block_name]}" + else: + best_step_str = " --" + + # Format best value + if best_values_dict and block_name in best_values_dict: + best_value = best_values_dict[block_name] + best_value_str = f"{best_value:.2e}" + else: + best_value = loss_value # Assume current is best if no history + best_value_str = f"{best_value:.2e}" + + # Calculate change from average + change_from_avg = loss_value - avg_loss + if abs(change_from_avg) > 1e-8: # Only show if meaningful + change_str = f"{abs(change_from_avg):.1e}" + if change_from_avg > 0: + # Current is above average (worse for loss) + change_display = f"↑ +{change_str}" + else: + # Current is below average (better for loss) + change_display = f"↓ -{change_str}" + else: + # At average value + change_display = "↔ 0.0e+00" + + # Format the line + block_display = block_name.replace("block_", "").zfill(2) + + line = ( + f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " + f"{best_value_str:<12} │ {change_display:<18} │" + ) + lines.append(line) + + # Add summary statistics + lines.append(divider) + + # Build summary string with optional step number + summary_parts = [] + if step_number is not None: + summary_parts.append(f"Step {step_number}") + summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + + summary_text = ", ".join(summary_parts) + summary = f"│ Summary: {summary_text}" + + # Pad summary to box width + padding_needed = box_width - len(summary) - 1 + summary += " " * padding_needed + "│" + lines.append(summary) + + # Add best step summary if we have best step data + if best_steps_dict and best_values_dict: + # Find the most common best step (modal step) + step_counts = {} + for step in best_steps_dict.values(): + step_counts[step] = step_counts.get(step, 0) + 1 + + if step_counts: + modal_best_step = max(step_counts, key=step_counts.get) + + # Get values at the modal best step for blocks that have it as their best + best_step_values = [] + for 
block_name, best_step in best_steps_dict.items(): + if best_step == modal_best_step and block_name in best_values_dict: + best_step_values.append(best_values_dict[block_name]) + + if best_step_values: + best_step_avg = sum(best_step_values) / len(best_step_values) + best_step_max = max(best_step_values) + best_step_min = min(best_step_values) + + best_step_summary_text = ( + f"Best: Step {modal_best_step}, Avg={best_step_avg:.2e}, " + f"Max={best_step_max:.2e}, Min={best_step_min:.2e}" + ) + best_step_summary = f"│ {best_step_summary_text}" + + # Pad best step summary to box width + padding_needed = box_width - len(best_step_summary) - 1 + best_step_summary += " " * padding_needed + "│" + lines.append(best_step_summary) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) diff --git a/modelopt/torch/puzzletron/utils/utils.py b/modelopt/torch/puzzletron/utils/utils.py new file mode 100644 index 0000000000..77a13609aa --- /dev/null +++ b/modelopt/torch/puzzletron/utils/utils.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def calculate_kv_dim(num_key_value_heads: int | None, n_head: int, n_embd: int) -> int:
    """Calculate the combined key/value dimension for grouped-query attention.

    Args:
        num_key_value_heads: Number of key-value heads, or ``None``.
        n_head: Total number of attention heads.
        n_embd: Embedding dimension.

    Returns:
        ``2 * num_key_value_heads * (n_embd // n_head)``, or ``0`` when
        ``num_key_value_heads`` is ``None``.
    """
    if num_key_value_heads is None:
        return 0
    head_size = n_embd // n_head
    return 2 * num_key_value_heads * head_size


def raise_unknown_subblock_config_error(subblock_config: Any) -> None:
    """Raise an error for an invalid subblock configuration type.

    TODO: Consider a better place for this function.

    Args:
        subblock_config: The invalid subblock configuration object.

    Raises:
        ValueError: Always, naming the type that was received.
    """
    raise ValueError(
        f"subblock_config should be an instance of FFNConfig or AttentionConfig, instead got {type(subblock_config)}"
    )


def sizeof_dtype(dtype: torch.dtype | str) -> int | float:
    """Return the size in bytes of one element of ``dtype``.

    TODO: Consider a better place for this function.

    Args:
        dtype: A ``torch.dtype``, or the string ``"nvfp4"``.

    Returns:
        ``1 / 1.7`` for ``"nvfp4"``; otherwise the element size reported by
        an empty tensor of that dtype.
    """
    if dtype == "nvfp4":
        return 1 / 1.7
    return torch.tensor([], dtype=dtype).element_size()


def load_json(file_path: str):
    """Load and parse a JSON file.

    TODO: Consider a better place for this function.

    Args:
        file_path: Path to the JSON file to load.

    Returns:
        The parsed JSON data, or ``None`` (after printing a notice) if the
        file does not exist.
    """
    if not os.path.exists(file_path):
        # The message was missing its f-prefix and printed "{file_path}" literally.
        print(f"file does not exist {file_path}")
        return None

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file=file_path, encoding="utf-8") as f:
        return json.load(f)


def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str:
    """Convert a list of block configurations to a multi-line string.

    TODO: Consider a better place for this function, e.g. ``__repr__`` of
    ``BlockConfig`` so that printing a config yields this format directly.

    Args:
        block_configs: ``BlockConfig`` dataclasses or dicts, one per block.

    Returns:
        One line per block, ``block_<idx>:`` followed by the text from
        ``block_config_to_str``, each line ending with a newline.
    """
    block_configs = deepcopy(block_configs)
    reps = []
    for block_idx, block_config in enumerate(block_configs):
        rep = f"block_{block_idx}:".ljust(9)
        rep += block_config_to_str(block_config)
        reps.append(rep)
    return "\n".join(reps) + "\n"


def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None:
    """Convert a block config to a one-line string.

    TODO: Consider a better place for this function.

    Args:
        block_config: ``BlockConfig`` dataclass or dict with ``"attention"``
            and ``"ffn"`` entries.

    Returns:
        The attention description followed by the FFN description, or
        ``None`` if ``block_config`` is ``None``. A sub-block that is
        ``None`` contributes nothing.
    """
    if block_config is None:
        return None
    if dataclasses.is_dataclass(block_config):
        block_config = dataclasses.asdict(block_config)

    rep = ""
    for subblock_name in ["attention", "ffn"]:
        # subblock_config_to_str returns None for a None sub-block, and
        # concatenating that used to raise TypeError.
        rep += subblock_config_to_str(block_config[subblock_name], subblock_name) or ""
    return rep


def subblock_config_to_str(
    subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None,
    subblock_name: None | str = None,
) -> str | None:
    """Convert a subblock config (FFN, attention, Mamba or MoE) to a string.

    TODO: Consider a better place for this function.

    Args:
        subblock_config: ``FFNConfig``/``AttentionConfig`` dataclass, or a dict.
        subblock_name: ``"ffn"`` or ``"attention"``; required for dicts and
            overridden by the detected type for dataclasses. A dict with a
            non-``None`` ``"mamba"``/``"moe"`` entry is reported as
            ``mamba``/``moe``.

    Returns:
        The subblock kind followed by its key parameters, or ``None`` if
        ``subblock_config`` is ``None``.

    Raises:
        ValueError: If the resolved subblock name is not recognized.
    """
    if subblock_config is None:
        return None

    if isinstance(subblock_config, FFNConfig):
        subblock_name = "ffn"
    elif isinstance(subblock_config, AttentionConfig):
        subblock_name = "mamba" if subblock_config.is_mamba else "attention"
    assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict."

    if dataclasses.is_dataclass(subblock_config):
        subblock_config = dataclasses.asdict(subblock_config)

    # Dict inputs carry no type, so specialize by the keys that are set.
    if subblock_name == "attention" and subblock_config.get("mamba") is not None:
        subblock_name = "mamba"

    if subblock_name == "ffn" and subblock_config.get("moe") is not None:
        subblock_name = "moe"

    rep = f" {subblock_name}"
    if subblock_config.get("no_op"):
        rep += " no_op".ljust(8)
    elif subblock_config.get("replace_with_linear"):
        rep += " linear".ljust(8)
    elif subblock_name == "ffn":
        intermediate_size = subblock_config["intermediate_size"]
        rep += f" intermediate_{intermediate_size}".ljust(8)
    elif subblock_name == "attention":
        num_key_value_heads = subblock_config["num_key_value_heads"]
        rep += f" kv_heads_{num_key_value_heads}".ljust(8)
    elif subblock_name == "mamba":
        mamba_config = subblock_config["mamba"]
        mamba_num_heads = mamba_config["num_heads"]
        mamba_head_dim = mamba_config["head_dim"]
        rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8)
    elif subblock_name == "moe":
        moe_config = subblock_config["moe"]
        moe_num_local_experts = moe_config["num_local_experts"]
        moe_expert_intermediate_dim = moe_config["expert_intermediate_dim"]
        shared_expert_intermediate_dim = moe_config["shared_expert_intermediate_dim"]
        num_experts_per_tok = moe_config["num_experts_per_tok"]
        rep += (
            f" num_experts_{moe_num_local_experts}"
            f" expert_intermediate_dim_{moe_expert_intermediate_dim}"
            f" shared_expert_intermediate_dim_{shared_expert_intermediate_dim}"
            f" num_experts_per_tok_{num_experts_per_tok}"
        ).ljust(8)
    else:
        raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.")

    return rep


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Torch-function mode that skips weight init and sets a default device/dtype.

    While the mode is active, calls to functions from ``torch.nn.init``
    return their tensor argument untouched, and tensor constructors receive
    the configured ``device``/``dtype`` unless the caller passed one.
    """

    def __init__(self, device=None, dtype=None):
        """Store the device and dtype to inject into tensor constructors.

        Args:
            device: ``torch.device`` to work with, or ``None`` to leave as is.
            dtype: ``torch.dtype`` to work with, or ``None`` to leave as is.

        Example::

            with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
                model = LLaMA(model_config)
            model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))
        """
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Intercept torch calls made while the mode is active."""
        kwargs = kwargs or {}

        # Initializers become no-ops: hand back the tensor they would fill.
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]

        if self.device is not None or self.dtype is not None:
            if func in torch.utils._device._device_constructors():
                # An explicit (non-None) caller choice always wins.
                if self.device is not None and kwargs.get("device") is None:
                    kwargs["device"] = self.device
                if self.dtype is not None and kwargs.get("dtype") is None:
                    kwargs["dtype"] = self.dtype

        return func(*args, **kwargs)
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. + +Coordinates forward passes and loss computation through model shards distributed across GPUs +using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. + +Used by validate_model.py during activation scoring for sharded models. +""" +# mypy: ignore-errors + +import traceback +from contextlib import nullcontext +from typing import Type + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import LMHead +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + distributed_recv_obj, + distributed_send_obj, + fake_tensor, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils import init_module_with_state_dict +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import DummyBlock +from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs + + +def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None: + """Log detailed error info for distributed forward pass failures. + + When one rank crashes during distributed forward, others may hang waiting for communication. + This logging helps diagnose which rank failed and why. 
class HiddenStatesAndLMHead(list):
    """List of per-batch hidden-state tensors that also carries LM-head weights.

    The list items are the hidden states passed to the constructor; the extra
    ``lm_head_weights`` attribute lets a consumer rebuild an LM head and turn
    those hidden states back into logits (see ``calculate_losses_pipeline``).
    """

    def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor):
        super().__init__(hidden_states)
        # Kept as a plain attribute next to the list contents.
        self.lm_head_weights = lm_head_weights


@torch.no_grad()
def calculate_losses_pipeline(
    stitched_model: StitchedModule,
    dataloader: DataLoader | None,
    target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None,
    return_hidden_states: bool = False,
    calculate_full_score_ablations: bool = False,
    calc_on_cpu: bool = False,
    just_model_forward: bool = False,
    checkpoint_manager=None,
    autocast_dtype: torch.dtype = torch.bfloat16,
    descriptor: Type[ModelDescriptor] = None,
    use_autocast: bool = True,
) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]:
    """Run a pipeline-parallel forward pass per batch and compute LM-loss metrics.

    The master process reads every batch from ``dataloader`` up front and, when
    more than one process participates, ships the targets to the last process.
    The last process is the only one that collects outputs and computes scores.
    Does not support data-parallel.

    Args:
        stitched_model: Model to evaluate. Anything that is not already a
            ``StitchedModule`` is first wrapped by ``perform_pipeline_stitches``.
        dataloader: Source of batches with ``"input_ids"`` and ``"targets"``
            entries; only iterated on the master process.
        target_hidden_states_per_batch: Optional teacher hidden states (one item
            per batch) plus LM-head weights. When given, teacher logits are
            rebuilt from them and passed to ``calculate_batch_outputs``.
        return_hidden_states: Forwarded to ``calculate_batch_outputs``; when the
            collected outputs contain hidden states they are returned wrapped in
            ``HiddenStatesAndLMHead``.
        calculate_full_score_ablations: Forwarded to ``calculate_batch_outputs``.
        calc_on_cpu: Forwarded to ``calculate_batch_outputs``.
        just_model_forward: Skip score calculation and record ``-1.0`` as
            ``lm_loss`` for every sample. Useful for activation hooks.
        checkpoint_manager: Optional object providing ``current_batch`` (first
            batch to process) and ``update_progress(done, total)``.
        autocast_dtype: dtype used for ``torch.autocast`` when enabled.
        descriptor: Model descriptor, only used when the model still has to be
            stitched.
        use_autocast: Disable to run without autocast.

    Returns:
        ``(losses, hidden_states_per_batch)`` on the last process, where
        ``losses`` maps each metric name to ``{"avg": ..., "per_sample": [...]}``.
        Processes that collect no outputs return ``(None, None)``.
    """
    if not isinstance(stitched_model, StitchedModule):
        stitched_model = perform_pipeline_stitches(stitched_model, descriptor)

    # Fall back to CPU when this pipeline stage owns no parameters.
    params = list(stitched_model.parameters())
    model_device = params[0].device if params else "cpu"

    # Pre-populate outputs with dummy values for skipped batches
    # NOTE(review): the placeholders only carry "lm_loss"; confirm how downstream
    # aggregation treats the other metrics after a resume.
    start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0
    if dist.is_last_process():
        outputs = [{"lm_loss": [0.0]}] * start_batch
    else:
        outputs = None

    if dist.is_master():
        # Materialize the whole dataloader so batches can be indexed by position.
        all_input_ids, all_targets = zip(
            *[(batch["input_ids"], batch["targets"]) for batch in dataloader]
        )
        if dist.size() > 1:
            # Targets are only needed where the scores are computed.
            distributed_send_obj(all_targets, dst=dist.size() - 1)

    if dist.is_last_process():
        if dist.size() > 1:
            all_targets = distributed_recv_obj(src=0)

        # First submodule whose qualified name contains "lm_head".
        lm_head: LMHead = next(
            module
            for module_name, module in stitched_model.named_modules()
            if "lm_head" in module_name
        )

        if target_hidden_states_per_batch is not None:
            # Rebuild the teacher LM head so teacher logits can be recomputed
            # from the stored teacher hidden states.
            lm_head_weights = target_hidden_states_per_batch.lm_head_weights
            with torch.device(model_device):
                target_lm_head = init_module_with_state_dict(
                    {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False
                )

    if dist.is_master():
        num_batches = len(all_input_ids)
        seq_len = all_input_ids[0].shape[1]
        if dist.size() > 1:
            # Share loop bounds and sequence length with the other processes.
            torch.distributed.broadcast_object_list([num_batches, seq_len])

        # Create progress bar with sliced range starting from checkpoint position
        desc = (
            f"[rank {dist.rank()}] calculate_losses_pipeline("
            f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})"
        )
        progress_bar = tqdm(range(start_batch, num_batches), desc=desc)
    else:
        obj_list = [None, None]
        if dist.size() > 1:
            torch.distributed.broadcast_object_list(obj_list)
        num_batches, seq_len = obj_list
        progress_bar = range(start_batch, num_batches)

    stitched_model.eval()

    # Use autocast for mixed precision, or nullcontext if disabled
    # (some models like Qwen3-VL MoE have dtype bugs under autocast)
    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext()
    )
    with autocast_ctx:
        # Non-master processes feed a placeholder of matching shape instead of
        # real input ids.
        fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device)
        for i_batch in progress_bar:
            if dist.is_master():
                input_ids = all_input_ids[i_batch].to(model_device)
            else:
                input_ids = fake_input_ids

            try:
                output = stitched_model({}, {}, input_ids)
            except Exception as e:
                # Log with rank/batch context, then propagate unchanged.
                _log_forward_error(e, dist.rank(), i_batch, num_batches)
                raise

            if dist.is_last_process():
                logits = output.captured_outputs.get("model_output")
                # Unwrap model-output objects that expose `.logits`.
                logits = getattr(logits, "logits", logits)
                hidden_states = output.captured_outputs.get("hidden_states")
                targets = all_targets[i_batch].to(model_device)

                target_hidden_states = None
                target_logits = None
                if target_hidden_states_per_batch is not None:
                    target_hidden_states = target_hidden_states_per_batch[i_batch]
                    target_hidden_states = target_hidden_states.to(hidden_states.device)
                    target_logits = target_lm_head(target_hidden_states)

                if just_model_forward:
                    # Placeholder loss, one entry per sample in the batch.
                    batch_outputs = {"lm_loss": [-1.0] * len(targets)}
                else:
                    batch_outputs = calculate_batch_outputs(
                        hidden_states,
                        target_hidden_states,
                        logits,
                        target_logits,
                        targets,
                        return_hidden_states,
                        calculate_full_score_ablations,
                        calc_on_cpu,
                    )

                outputs.append(batch_outputs)

                # Free GPU memory after processing each batch
                del logits, hidden_states, targets
                if target_hidden_states is not None:
                    del target_hidden_states
                if target_logits is not None:
                    del target_logits

            # Free output tensor memory on all ranks
            del output

            # Update checkpoint progress periodically
            if checkpoint_manager:
                checkpoint_manager.update_progress(i_batch + 1, num_batches)

    losses, hidden_states_per_batch = (
        _organize_outputs(outputs) if outputs is not None else (None, None)
    )

    if hidden_states_per_batch is not None:
        # Bundle this model's LM-head weights (moved to CPU) with the hidden states.
        hidden_states_per_batch = HiddenStatesAndLMHead(
            hidden_states_per_batch, lm_head.weight.cpu()
        )

    dist.barrier()
    return losses, hidden_states_per_batch
def perform_pipeline_stitches(
    model,
    descriptor: Type[ModelDescriptor],
) -> StitchedModule:
    """Create pipeline stitches for distributed model evaluation.

    Every process receives activations from the previous rank (except rank 0)
    at its first non-dummy block. It then either sends the output of its last
    non-dummy block to the next rank, or, on the last process, exposes the
    output embedding and final norm outputs as ``"model_output"`` and
    ``"hidden_states"``.

    Args:
        model: The model to stitch (any HuggingFace model with AnyModel descriptor).
        descriptor: ModelDescriptor for layer naming.

    Returns:
        The stitched module produced by ``Needle.knot``.
    """
    target = ModuleTarget("module", model)
    stitcher = Needle()

    num_layers = model.config.num_hidden_layers

    # Indices of the layers that are not DummyBlock placeholders on this process.
    is_real_block = np.flatnonzero(
        [
            not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock)
            for i in range(num_layers)
        ]
    )

    # NOTE(review): min()/max() raise on an empty array, i.e. when every block on
    # this process is a DummyBlock; confirm callers never hit that case.
    first_block, last_block = is_real_block.min(), is_real_block.max()

    if dist.rank() != 0:
        # receive activations from previous rank
        stitcher.stitch(
            RemoteTarget(peer_rank=dist.rank() - 1).value(
                name="activations", adapter=lambda x: InputArgs(x)
            ),
            target.input(
                name=descriptor.layer_block_name(first_block),
                # Received activations stand in for the block's first positional
                # argument; the remaining original arguments are kept.
                reducer=InputReducer(
                    lambda acc, override, orig, *args: override + orig.drop_args(0)
                ),
            ),
        )

    if not dist.is_last_process():
        # send activations to next rank
        stitcher.stitch(
            target.output(descriptor.layer_block_name(last_block)),
            RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"),
        )
    else:
        # register model output
        stitcher.stitch(
            target.output(name=descriptor.output_embedding_name()),
            ExternalTarget().output("model_output"),
        )
        stitcher.stitch(
            target.output(name=descriptor.final_norm_name()),
            ExternalTarget().output("hidden_states"),
        )

    stitched_module = stitcher.knot(ignore_extra_overrides=True)
    return stitched_module
stitched_module diff --git a/modelopt/torch/puzzletron/utils/validation.py b/modelopt/torch/puzzletron/utils/validation.py new file mode 100644 index 0000000000..0fff907549 --- /dev/null +++ b/modelopt/torch/puzzletron/utils/validation.py @@ -0,0 +1,550 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model validation and loss calculation utilities for single-GPU and multi-GPU setups. + +Also provides helper functions for loss metrics, KL divergence, JS divergence, +and similarity losses for knowledge distillation. 
+""" + +# mypy: ignore-errors +import functools +import math +from enum import Enum + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from typing_extensions import Self + +from modelopt.torch.puzzletron.tools import kd_model + + +class UnshardedLowMemorySparseTensor: + def __init__(self, x: torch.Tensor): + inds_dtype = self._infer_inds_dtype(x) + x_sparse = x.to_sparse_coo() + self._values = x_sparse.values() + self._indices = x_sparse.indices().to(inds_dtype) + self._size = x_sparse.size() + + @staticmethod + def _infer_inds_dtype(x: torch.Tensor) -> torch.dtype: + max_dim = max(x.shape) + for inds_dtype in [torch.int16, torch.int32, torch.int64]: + if torch.iinfo(inds_dtype).max >= max_dim: + return inds_dtype + + def to_sparse_coo(self) -> torch.Tensor: + return torch.sparse_coo_tensor(values=self._values, indices=self._indices, size=self._size) + + def to_dense(self) -> torch.Tensor: + return self.to_sparse_coo().to_dense() + + def to(self, *args) -> Self: + self._values = self._values.to(*args) + for arg in args: + if isinstance(arg, torch.device) or isinstance(arg, str): + self._indices = self._indices.to(arg) + return self + + +class LowMemorySparseTensor: + _max_sparse_size = torch.iinfo(torch.int32).max + + def __init__(self, x: torch.Tensor): + num_chunks = math.ceil(x.numel() / self._max_sparse_size) + self._chunk_dim = np.argmax(x.shape) + self._chunks = [ + UnshardedLowMemorySparseTensor(chunk) + for chunk in torch.chunk(x, num_chunks, dim=self._chunk_dim) + ] + + def to(self, *args) -> Self: + for chunk in self._chunks: + chunk.to(*args) + return self + + def to_dense(self) -> torch.Tensor: + return torch.concat([chunk.to_dense() for chunk in self._chunks], dim=self._chunk_dim) + + +@torch.no_grad() +def calculate_losses( + model: nn.Module, + dataloader: 
DataLoader, + target_probs: None = None, + return_probs: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], None] | tuple[None, None]: + """Do model forward on each batch and calculate LM loss. + Works on lit-llama models (single gpu) and huggingface models (can be multi gpu). + Does not support data-parallel. + + ### Anything related to probs and hidden states is not supported currently! ### + calculate_losses() isn't updated according to the major refactor in + calculate_losses_pipeline() regarding hidden states. + + Returns: + outputs = { + "lm_loss": list[float], + "token_accuracy_top_1": list[float], + "token_accuracy_top_5": list[float], + "token_accuracy_top_10": list[float], + } + """ + if (target_probs is not None) or return_probs: + raise NotImplementedError( + "calculate_losses() isn't updated according to the major refactor in " + "calculate_losses_pipeline() regarding hidden states." + ) + + model_device = next(model.parameters()).device + outputs = [] + + try: + num_batches = len(dataloader) + except: + num_batches = None + + # Adjust progress bar for resume + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + progress_bar = tqdm( + enumerate(dataloader), + total=num_batches, + desc=f"calculate_losses({(target_probs is None)=}, {return_probs=})", + ) + if start_batch > 0: + progress_bar.update(start_batch) + + for i_batch, batch in progress_bar: + # Skip batch if resuming from checkpoint + if checkpoint_manager and checkpoint_manager.should_skip_batch(i_batch): + continue + + input_ids = batch["input_ids"].to(model_device) + logits = model(input_ids) + if hasattr(logits, "logits"): + logits = logits.logits + # logits = logits.float() + + targets = batch["targets"].to(model_device) + + batch_outputs = calculate_batch_outputs( + hidden_states=None, + target_hidden_states=None, + logits=logits, + target_logits=None, + targets=targets, + return_hidden_states=False, + calculate_full_score_ablations=False, + 
calc_on_cpu=False, + ) + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, _ = _organize_outputs(outputs) + return losses, None + + +def calculate_batch_outputs( + hidden_states: torch.Tensor | None, + target_hidden_states: torch.Tensor | None, + logits: torch.Tensor, + target_logits: torch.Tensor | None, + targets: torch.Tensor, + return_hidden_states: bool, + calculate_full_score_ablations: bool, + calc_on_cpu: bool, +) -> dict: + if calc_on_cpu: + if hidden_states is not None: + hidden_states = hidden_states.cpu() + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.cpu() + if logits is not None: + logits = logits.cpu() + if target_logits is not None: + target_logits = target_logits.cpu() + if targets is not None: + targets = targets.cpu() + + batch_outputs = _calculate_ground_truth_based_scores(logits, targets) + + if (target_hidden_states is not None) or (target_logits is not None): + batch_outputs.update( + _calculate_teacher_similarity_scores( + hidden_states, + target_hidden_states, + logits, + target_logits, + calculate_full_score_ablations, + ) + ) + + if return_hidden_states: + batch_outputs["hidden_states_per_batch"] = hidden_states.cpu() + + return batch_outputs + + +def _organize_outputs( + outputs_per_batch: list[dict], +) -> tuple[dict[str, dict], list[torch.Tensor] | None]: + outputs = _concatenate_batch_outputs(outputs_per_batch) + hidden_states_per_batch = outputs.pop("hidden_states_per_batch", None) + losses = { + loss_name: { + "avg": sum(loss_per_sample) / len(loss_per_sample), + "per_sample": loss_per_sample, + } + for loss_name, loss_per_sample in outputs.items() + } + return losses, hidden_states_per_batch + + +def _concatenate_batch_outputs(outputs_per_batch: list[dict]) -> dict[str, list]: + outputs = {} + for output_name in outputs_per_batch[0]: # Regular dict is directly iterable + 
def _calculate_per_sample_lm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
) -> list[float]:
    """Return the mean token cross-entropy of every sample.

    Targets equal to ``-1`` are ignored by the cross-entropy (they contribute
    zero), but the mean is still taken over the full last dimension.
    """
    # cross_entropy expects the class dimension second, hence the transpose.
    token_losses = torch.nn.functional.cross_entropy(
        logits.transpose(1, 2), targets, ignore_index=-1, reduction="none"
    )
    return token_losses.mean(dim=-1).tolist()


def _calculate_ground_truth_based_scores(
    logits: torch.Tensor,
    targets: torch.Tensor,
) -> dict[str, list[float]]:
    """Return LM loss and top-1/5/10 token accuracy, one value per sample."""
    scores = {"lm_loss": _calculate_per_sample_lm_loss(logits, targets)}

    expanded_targets = targets.unsqueeze(-1)  # [b, t, 1]
    for k in (1, 5, 10):
        candidates = logits.topk(k, dim=-1).indices  # [b, t, k]
        hit = (expanded_targets == candidates).any(dim=-1)  # [b, t]
        scores[f"token_accuracy_top_{k}"] = hit.float().mean(dim=-1).tolist()  # [b]

    return scores


def cosine_embedding_loss(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
) -> list[float]:
    """Per-sample cosine embedding loss, delegated to ``kd_model``."""
    batched_loss = kd_model.cosine_embedding_loss_batched(hidden_states, target_hidden_states)
    return batched_loss.tolist()


def normalized_mse_loss(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
) -> list[float]:
    """Per-sample normalized MSE, computed by ``kd_model`` one sample at a time."""
    per_sample = []
    for i_sample in range(hidden_states.shape[0]):
        sample_loss = kd_model.normalized_mse_loss(
            hidden_states[i_sample], target_hidden_states[i_sample]
        )
        per_sample.append(sample_loss.item())
    return per_sample


def mse_loss(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
) -> list[float]:
    """Per-sample mean squared error."""
    per_sample = []
    for i_sample in range(hidden_states.shape[0]):
        sample_loss = F.mse_loss(hidden_states[i_sample], target_hidden_states[i_sample])
        per_sample.append(sample_loss.item())
    return per_sample


def mae_loss(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
) -> list[float]:
    """Per-sample mean absolute error."""
    per_sample = []
    for i_sample in range(hidden_states.shape[0]):
        sample_loss = F.l1_loss(hidden_states[i_sample], target_hidden_states[i_sample])
        per_sample.append(sample_loss.item())
    return per_sample
def _calculate_teacher_similarity_scores(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
    logits: torch.Tensor,
    target_logits: torch.Tensor,
    calculate_full_score_ablations: bool,
) -> dict[str, list[float]]:
    """Compute per-sample student-vs-teacher similarity scores.

    Args:
        hidden_states: [batch, tokens, n_embd], may be ``None``.
        target_hidden_states: [batch, tokens, n_embd], may be ``None``.
        logits: [batch, tokens, vocab]
        target_logits: [batch, tokens, vocab], may be ``None``.
        calculate_full_score_ablations: Also compute the top-p / clipping
            ablation variants of the divergence scores.

    Returns:
        Mapping of score name to one value per sample. Hidden-state scores are
        only present when both hidden-state tensors are given with identical
        shapes; all logit-based scores require ``target_logits``.
    """
    # Value top_p_top_k writes into filtered-out positions; the mask below must
    # use the very same value, so it is passed explicitly instead of relying on
    # the default.
    filter_value = -1000
    similarity_funcs = (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss)

    def calc_per_sample(func, logits, target_probs):
        return [
            func(logits=logits[i_sample], target_probs=target_probs[i_sample])
            for i_sample in range(logits.shape[0])
        ]

    score_ablations = {}

    if (
        (hidden_states is not None)
        and (target_hidden_states is not None)
        and (hidden_states.shape == target_hidden_states.shape)
    ):
        for func in similarity_funcs:
            score_name = f"{func.__name__}_hidden_states"
            score_ablations[score_name] = func(hidden_states, target_hidden_states)

    if target_logits is None:
        return score_ablations

    for func in similarity_funcs:
        score_name = f"{func.__name__}_logits"
        score_ablations[score_name] = func(logits, target_logits)

    # Divergences between student and teacher distributions, optionally with
    # top-p filtering applied to both sides.
    for top_p in (0.99, 0.95, None) if calculate_full_score_ablations else (None,):
        transformed_logits = (
            logits if (top_p is None) else top_p_top_k(logits, top_p=top_p, top_k=None)
        )
        transformed_target_logits = (
            target_logits
            if (top_p is None)
            else top_p_top_k(target_logits, top_p=top_p, top_k=None)
        )
        target_probs = transformed_target_logits.softmax(-1)

        for func in (kl_div, js_div, tv_dist):
            for clip_epsilon in (
                (
                    ClipEpsilon.NO_CLIP,
                    ClipEpsilon.CLIP_NO_RENORMALIZE,
                    ClipEpsilon.CLIP_RENORMALIZE,
                )
                if calculate_full_score_ablations
                else (ClipEpsilon.NO_CLIP,)
            ):
                # An epsilon factor is meaningless without clipping.
                epsilon_factors = (
                    (1.0, 0.1, 0.01) if not clip_epsilon == ClipEpsilon.NO_CLIP else (None,)
                )

                for epsilon_factor in epsilon_factors:
                    score_name = (
                        f"{func.__name__}--top_p_{top_p}--clip_epsilon_{clip_epsilon.name}"
                        f"--epsilon_factor_{epsilon_factor}"
                    )
                    func_with_args = functools.partial(
                        func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor
                    )
                    score_ablations[score_name] = calc_per_sample(
                        func_with_args, transformed_logits, target_probs
                    )
                    if (top_p is None) and (clip_epsilon == ClipEpsilon.NO_CLIP):
                        # The plain variant is also exposed under the bare function name.
                        short_score_name = func.__name__
                        score_ablations[short_score_name] = score_ablations[score_name]

    # How often the teacher's greedy token is among the student's top-k tokens.
    teacher_greedy_prediction = target_logits.argmax(dim=-1, keepdim=True)  # [b,t,1]
    for top_k in (1, 5, 10):
        student_top_k_predictions = logits.topk(top_k, dim=-1).indices  # [b,t,k]
        is_teacher_prediction_in_student_predictions = (
            teacher_greedy_prediction == student_top_k_predictions
        ).any(dim=-1)  # [b,t]
        fraction_student_predicted_teacher = (
            is_teacher_prediction_in_student_predictions.float().mean(dim=-1)
        )  # [b]
        score_ablations[f"greedy_teacher_prediction_in_student_top_{top_k}"] = (
            fraction_student_predicted_teacher.tolist()
        )

    if calculate_full_score_ablations:
        for top_p in (0.99, 0.95, 0.50, None):
            # student
            transformed_logits = logits.clone()

            # teacher
            transformed_target_logits = (
                target_logits.clone()
                if (top_p is None)
                else top_p_top_k(
                    target_logits, top_p=top_p, top_k=None, filter_value=filter_value
                )
            )

            target_probs = transformed_target_logits.softmax(-1)
            # Zero out positions the teacher filtered away so they do not
            # contribute to the logit-space errors below.
            mask = transformed_target_logits == filter_value
            if torch.any(mask):
                transformed_logits[mask] = 0
                transformed_target_logits[mask] = 0
                target_probs[mask] = 0

            for func in (mse_loss, mae_loss):
                score_name = f"{func.__name__}_logits_top_p_{top_p}"
                score_ablations[score_name] = func(
                    transformed_logits, transformed_target_logits
                )

            if top_p is not None and top_p > 0.9:
                # Unfiltered student vs. top-p filtered teacher, without clipping.
                # The arguments are spelled out here; the previous code reused an
                # `epsilon_factor` left over from the nested loops above.
                func_with_args = functools.partial(kl_div, clip_epsilon=ClipEpsilon.NO_CLIP)
                score_name = (
                    f"{kl_div.__name__}--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered"
                )
                score_ablations[score_name] = calc_per_sample(
                    func_with_args, logits, target_probs
                )

    return score_ablations
func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, logits, target_probs + ) + # score_name = f"{func.__name__}_abs--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + # score_ablations[score_name] = [s.abs() for s in score_ablations[score_name]] + + return score_ablations + + +class ClipEpsilon(Enum): + NO_CLIP = "NO_CLIP" + CLIP_RENORMALIZE = "CLIP_RENORMALIZE" + CLIP_NO_RENORMALIZE = "CLIP_NO_RENORMALIZE" + + +def _logits_to_logprobs( + logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float +) -> torch.Tensor: + """logits: [tokens, vocab]""" + logprobs = logits.log_softmax( + -1 + ) # must normalize logits before clipping otherwise log(1/voacb) means nothing + if clip_epsilon == ClipEpsilon.NO_CLIP: + return logprobs + vocab_size = logprobs.shape[-1] + epsilon = math.log(epsilon_factor * 1 / vocab_size) + logprobs = torch.clip(logprobs, min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + logprobs = logprobs.log_softmax( + -1 + ) # we do log_softmax again to retain legitimate distributions + return logprobs + + +def kl_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Kullback-Leibler Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens = logits.shape[0] + logprobs = _logits_to_logprobs(logits, clip_epsilon, epsilon_factor) + + _kl_div = ( + F.kl_div(logprobs, target_probs, reduction="sum", log_target=False).item() / num_tokens + ) + return _kl_div + + +def js_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Jensen-Shannon Divergence for a single sample. 
+ logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + probs = logits.softmax(-1) + mixture_probs = (probs + target_probs) / 2 + mixture_logprobs = mixture_probs.log().clip(min=-1000) + + pred_kl_div = kl_div(mixture_logprobs, probs, clip_epsilon, epsilon_factor) + target_kl_div = kl_div(mixture_logprobs, target_probs, clip_epsilon, epsilon_factor) + _js_div = 0.5 * (pred_kl_div + target_kl_div) + return _js_div + + +def tv_dist( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """Total Variation Distance (L1-loss) for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens, vocab_size = logits.shape + probs = logits.softmax(-1) + + if clip_epsilon != ClipEpsilon.NO_CLIP: + epsilon = epsilon_factor * 1 / vocab_size + probs = probs.clip(min=epsilon) + target_probs = target_probs.clip(min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + probs = probs / probs.sum(-1, keepdim=True) + target_probs = target_probs / target_probs.sum(-1, keepdim=True) + + _tv_dist = 0.5 * (probs - target_probs).abs().sum().item() / num_tokens + return _tv_dist + + +DEFAULT_TOP_P = 0.999 +# WestLake model: +# 700 = percentile 0.9 for top_p=0.99 +# 1700 = percentile 0.95 for top_p=0.99 and percentile 0.75 for top_p=0.999 +# For top_p=0.999 and top_k=1700 you take about 75 GB for 2048*8192 tokens +DEFAULT_TOP_K = 1000 + + +def top_p_top_k( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + filter_value=-1000, +) -> torch.Tensor: + logit_warpers = [] + if top_p is not None: + logit_warpers.append(TopPLogitsWarper(top_p=top_p, filter_value=filter_value)) + if top_k is not None: + logit_warpers.append(TopKLogitsWarper(top_k=top_k, filter_value=filter_value)) + + warped_logits = [] + for sample_logits in logits: + for warper in logit_warpers: + sample_logits = warper(input_ids=None, 
scores=sample_logits) + warped_logits.append(sample_logits) + warped_logits = torch.stack(warped_logits) + + return warped_logits diff --git a/modelopt/torch/utils/robust_json.py b/modelopt/torch/utils/robust_json.py index c4a72fde83..23a3091637 100644 --- a/modelopt/torch/utils/robust_json.py +++ b/modelopt/torch/utils/robust_json.py @@ -55,8 +55,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects (e.g. mixins injected into Hydra configs) + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/pyproject.toml b/pyproject.toml index 96490dff0a..a03e2029dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,16 @@ hf = [ "transformers>=4.56,<5.0", # Should match modelopt/torch/__init__.py and tox.ini "wonderwords", ] + +puzzletron = [ # Dependedencies for modelopt.torch.puzzletron subpackage + "fire", + "hydra-core==1.3.2", + "immutabledict", + "lru-dict", + "mip", + "pandas", + "typeguard", +] dev-lint = [ "bandit[toml]==1.7.9", # security/compliance checks "mypy==1.17.1", @@ -113,7 +123,7 @@ dev-test = [ "tox-current-env>=0.0.12", ] # Compound extras via self-references -all = ["nvidia-modelopt[onnx,hf]"] +all = ["nvidia-modelopt[onnx,hf,puzzletron]"] dev = ["nvidia-modelopt[all,dev-docs,dev-lint,dev-test]"] [project.urls] @@ -203,6 +213,17 @@ extend-ignore = [ "D", "E501", ] # Ignore missing docstrings or line length for Jupyter notebooks +"modelopt/torch/puzzletron/*" = [ + "C4", + "D", + "E", + "F", + "N", + "PERF", + "RUF", + "SIM", + "UP", +] # TODO: Disabled for now, will enable later, once all puzzletron code is migrated "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # 
triton style "modelopt/torch/sparsity/attention_sparsity/kernels/*" = [ "N803", @@ -271,7 +292,7 @@ markers = [ [tool.coverage.run] branch = false include = ["modelopt/*"] -omit = ["*/plugins/*", "*/export/*"] +omit = ["*/plugins/*", "*/export/*", "modelopt/torch/puzzletron/*"] [tool.coverage.report] diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py new file mode 100644 index 0000000000..fc6d6d5c16 --- /dev/null +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path + +import torch +from _test_utils.torch.transformers_models import get_tiny_tokenizer +from datasets import Dataset, DatasetDict +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers + + +def setup_test_model_and_data( + tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None +) -> tuple[Path, Path, Path]: + """ + Setup the test model and data for the puzzletron NAS search. 
+ + Args: + tmp_path: the temporary path to use for the test + rank: the rank of the process + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") + hybrid_override_pattern: For NemotronH models, the layer type pattern + + Returns: + tuple[Path, Path, Path]: the puzzle_dir, hf_checkpoint_path, dataset_path + """ + # Register Hydra custom resolvers (needed for config resolution) + register_hydra_resolvers() + + puzzle_dir = tmp_path / hf_model_name + hf_checkpoint_path = puzzle_dir / f"hf_models/{hf_model_name}" + dataset_path = puzzle_dir / "dummy_dataset" + + if rank == 0: + save_dummy_dataset(dataset_path) + + # Create a small HF model + tokenizer = get_tiny_tokenizer() + create_and_save_small_hf_model( + output_path=str(hf_checkpoint_path), + tokenizer=tokenizer, + hf_model_name=hf_model_name, + hybrid_override_pattern=hybrid_override_pattern, + ) + dist.barrier() + + return puzzle_dir, hf_checkpoint_path, dataset_path + + +def create_and_save_small_hf_model( + output_path: str, + tokenizer: PreTrainedTokenizerBase, + hf_model_name: str, + hybrid_override_pattern: str | None = None, +): + """ + Create and save a small HuggingFace model for testing the conversion pipeline. + Uses real HuggingFace config to preserve model-specific settings (like tie_word_embeddings), + but shrinks size parameters for fast testing. + + Args: + output_path: Where to save the model + tokenizer: Tokenizer to save alongside the model + hf_model_name: HuggingFace model card name (e.g., "meta-llama/Llama-3.1-8B-Instruct") + hybrid_override_pattern: For NemotronH models, the layer type pattern (e.g., "*-" for Attention+MLP, + "M-" for Mamba+MLP). Must match num_hidden_layers. None for non-NemotronH models. + """ + # Load real HuggingFace config (preserves tie_word_embeddings, rope_scaling, etc.) 
+ config = AutoConfig.from_pretrained(hf_model_name, trust_remote_code=True) + + # Override size-related params to make it small for testing + # Note: intermediate_size must be divisible by 256 per DeciLM config requirements + # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility + + # VL models have nested configs (text_config, vision_config) + if hasattr(config, "text_config") and hasattr(config, "vision_config"): + config.text_config.vocab_size = tokenizer.vocab_size + config.text_config.hidden_size = 256 + config.text_config.intermediate_size = 512 + config.text_config.num_hidden_layers = 2 + config.text_config.num_attention_heads = 32 + config.text_config.num_key_value_heads = 8 + config.text_config.num_experts = 16 # Reduce from 128 + config.text_config.moe_intermediate_size = 256 + config.text_config.max_position_embeddings = 512 + config.vision_config.depth = 2 # Reduce from 27 + config.vision_config.hidden_size = 256 + config.vision_config.intermediate_size = 512 + config.vision_config.out_hidden_size = 256 + # TODO: this is hack, redesign converter to not read config.num_hidden_layers directly. 
+ # set top-level num_hidden_layers for converter compatibility + config.num_hidden_layers = config.text_config.num_hidden_layers + else: + # Regular models have flat config + config.vocab_size = tokenizer.vocab_size + config.hidden_size = 256 + config.intermediate_size = 512 + config.num_hidden_layers = max(2, dist.size()) + config.num_attention_heads = 32 + config.num_key_value_heads = 8 + config.max_position_embeddings = 512 + + # Fix layer_types to match num_hidden_layers (newer transformers validates this) + if hasattr(config, "layer_types") and config.layer_types is not None: + config.layer_types = config.layer_types[: config.num_hidden_layers] + + # Fix rope_scaling to be consistent with max_position_embeddings + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + config.rope_scaling["original_max_position_embeddings"] = 256 + + # NemotronH requires hybrid_override_pattern to match num_hidden_layers + if hasattr(config, "hybrid_override_pattern") and hybrid_override_pattern is not None: + config.hybrid_override_pattern = hybrid_override_pattern + + # Ensure pad_token_id is within vocab_size (nn.Embedding requires padding_idx < num_embeddings) + if ( + getattr(config, "pad_token_id", None) is not None + and config.pad_token_id >= tokenizer.vocab_size + ): + config.pad_token_id = 0 + + # Set seed for reproducible weight initialization + torch.manual_seed(42) + + # Create and save the model + # Force CPU initialization for deterministic behavior (prevents NaN on RTX GPUs) + original_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = "" + # TODO: Consider using AutoModel.from_config instead. 
+ if hasattr(config, "text_config") and hasattr(config, "vision_config"): + from transformers import Qwen3VLMoeForConditionalGeneration + + model = Qwen3VLMoeForConditionalGeneration._from_config(config) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + # Initialize weights to ensure all parameters are properly initialized + # This prevents NaN values in uninitialized parameters (e.g., backbone.layers.1.mixer.gate.weight + # in nemotron-3-nano-30b-a3b-base-bf16) that can occur with from_config on RTX GPU cards (not on H100) + model.initialize_weights() + + # Fix any remaining NaN/Inf values that initialize_weights() might have missed + for param in model.parameters(): + if torch.isnan(param).any() or torch.isinf(param).any(): + nan_inf_mask = torch.isnan(param) | torch.isinf(param) + param.data = torch.where(nan_inf_mask, torch.zeros_like(param), param) + + # Restore CUDA_VISIBLE_DEVICES after model creation and initialization + if original_cuda_visible is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible + else: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + model.to(dtype=torch.bfloat16).save_pretrained(output_path) + + # Save tokenizer + tokenizer.save_pretrained(output_path) + + # Save config + config.save_pretrained(output_path) + + +def save_dummy_dataset(dataset_path: Path | str): + """ + Save a dummy dataset for testing purposes. + """ + # dummy sample + sample = [ + {"role": "user", "content": "please cite Lorem Ipsum?"}, + { + "role": "assistant", + "content": ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. " + "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, " + "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, " + "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, " + "pretium in purus. 
Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. " + "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, " + "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. " + "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, " + "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. " + "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, " + "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. " + "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. " + "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. " + "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. " + "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. " + "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. " + "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. " + "Donec mollis convallis massa quis iaculis." 
+ ), + }, + ] + + # Prepare train and val splits with sample repeated, 2500 samples are for + # 128 samples with block-size 8192 and LLama3 tokenizer + data = [{"conversation": sample}] * 2500 + + # For train-val splits + data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)}) + data_dict.save_to_disk(str(dataset_path)) diff --git a/tests/_test_utils/torch/tokenizer/special_tokens_map.json b/tests/_test_utils/torch/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..02ee80b619 --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/_test_utils/torch/tokenizer/tokenizer.json b/tests/_test_utils/torch/tokenizer/tokenizer.json new file mode 100644 index 0000000000..83592e2494 --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/tokenizer.json @@ -0,0 +1,212 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": 
{ + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101 + }, + "merges": [] + } +} diff 
--git a/tests/_test_utils/torch/tokenizer/tokenizer_config.json b/tests/_test_utils/torch/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..754d9e8db5 --- /dev/null +++ b/tests/_test_utils/torch/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + 
"extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index 01e8fa4d38..54bc10a562 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -39,6 +39,12 @@ SEED = 1234 +TINY_TOKENIZER_PATH = Path(__file__).parent / "tokenizer" + + +def get_tiny_tokenizer() -> "transformers.PreTrainedTokenizerBase": + return AutoTokenizer.from_pretrained(TINY_TOKENIZER_PATH) + ##### Qwen3 ##### def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel: @@ -66,9 +72,7 @@ def create_tiny_qwen3_dir( ) -> Path | tuple[Path, PreTrainedModel]: qwen3_dir = Path(tmp_path) / "tiny_qwen3" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(qwen3_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size tiny_qwen3 = get_tiny_qwen3(**config_kwargs) @@ -149,9 +153,7 @@ def create_tiny_llama_dir( ) -> Path: llama_dir = Path(tmp_path) / "tiny_llama" if with_tokenizer: - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM" - ) + tokenizer = get_tiny_tokenizer() tokenizer.save_pretrained(llama_dir) config_kwargs["vocab_size"] = tokenizer.vocab_size diff --git a/tests/conftest.py b/tests/conftest.py index 53a2330c22..a4e65ff2ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,7 +115,7 @@ def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() -@pytest.fixture +@pytest.fixture(scope="session") def project_root_path(request: pytest.FixtureRequest) -> Path: """Fixture providing the project root path for tests.""" return Path(request.config.rootpath) diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py 
b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py new file mode 100644 index 0000000000..6ca0ac0dd9 --- /dev/null +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for distill_hf.py script.""" + +from pathlib import Path + +import torch +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.distributed.utils import get_free_port +from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model +from _test_utils.torch.transformers_models import get_tiny_tokenizer +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel import convert_model + + +def test_distill_hf(project_root_path: Path, tmp_path: Path): + """Integration test for distill_hf.py. + + Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel), + and runs mbridge distillation. The models are created with reduced size for faster testing. + Models are converted to include block_configs. 
+ """ + # Prepare student and teacher models + student_hf_dir, student_anymodel_dir, _, teacher_anymodel_dir = ( + _prepare_student_and_teacher_models(project_root_path, tmp_path) + ) + + output_dir = tmp_path / "distill_output" + hf_export_dir = tmp_path / "hf_export" + + # Build command-line arguments for distill_hf.py + nproc_per_node = torch.cuda.device_count() + tp_size = nproc_per_node + train_iters = 5 + + cmd_parts = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--master-addr", + "127.0.0.1", + "--master-port", + str(get_free_port()), + "distill_hf.py", + "--use_mock_data", + ] + extend_cmd_parts( + cmd_parts, + student_hf_path=student_anymodel_dir, + teacher_hf_path=teacher_anymodel_dir, + output_dir=output_dir, + tp_size=tp_size, + pp_size=1, + seq_length=128, + split="99,1,0", + mbs=1, + gbs=4, + train_iters=train_iters, + lr=0.0001, + min_lr=1e-5, + lr_warmup_iters=2, + eval_interval=100, + eval_iters=0, + log_interval=5, + hf_export_path=hf_export_dir, + hf_model=student_hf_dir, + ) + + run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation") + + # Check that distillation checkpoint contains run_config.yaml + run_config_path = output_dir / "checkpoints" / f"iter_{train_iters:07d}" / "run_config.yaml" + assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}" + + # Verify that the distilled model can be loaded in HuggingFace format + model = AutoModelForCausalLM.from_pretrained(hf_export_dir) + assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" + + +def _prepare_student_and_teacher_models( + project_root_path: Path, tmp_path: Path +) -> tuple[Path, Path, Path, Path]: + """Prepare student and teacher models for distillation. + + Creates Qwen3 models programmatically, converts them to heterogeneous format (AnyModel), + and returns the paths to the converted checkpoints. 
+ + """ + + # Create temporary directories for models + student_hf_dir = tmp_path / "student_hf" + teacher_hf_dir = tmp_path / "teacher_hf" + + # Create tokenizer (uses local tokenizer from test resources) + tokenizer = get_tiny_tokenizer() + + # Create student model using utility function (loads config from Hub). + # TODO: Make the student model using different ffn sizes across layers. + create_and_save_small_hf_model( + output_path=str(student_hf_dir), + tokenizer=tokenizer, + hf_model_name="Qwen/Qwen3-0.6B", + hybrid_override_pattern=None, + ) + + # Create teacher model (same as student for testing) + create_and_save_small_hf_model( + output_path=str(teacher_hf_dir), + tokenizer=tokenizer, + hf_model_name="Qwen/Qwen3-0.6B", + hybrid_override_pattern=None, + ) + + # Convert models to AnyModel format BEFORE distillation + # This is needed as converted checkpoints will be used as input for distillation later + student_anymodel_dir = tmp_path / "student_anymodel" + teacher_anymodel_dir = tmp_path / "teacher_anymodel" + + convert_model( + input_dir=str(student_hf_dir), output_dir=str(student_anymodel_dir), converter="qwen3" + ) + + convert_model( + input_dir=str(teacher_hf_dir), output_dir=str(teacher_anymodel_dir), converter="qwen3" + ) + print("Models converted to AnyModel format:") + print(f" Student AnyModel: {student_anymodel_dir}") + print(f" Teacher AnyModel: {teacher_anymodel_dir}") + + return student_hf_dir, student_anymodel_dir, teacher_hf_dir, teacher_anymodel_dir diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py new file mode 100644 index 0000000000..25991f1c74 --- /dev/null +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import timedelta +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data + +import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel + + +def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_ffn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_ffn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. 
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + + # + # Run the mtn.convert() step + # + input_model = PuzzletronModel() + mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + dist.cleanup() + + +def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_attn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_attn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. 
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning" + + # + # Run the mtn.convert() step + # + input_model = PuzzletronModel() + mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/attn_independent_kv_head_contribution/" + f"100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() + + dist.cleanup() diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py new file mode 100644 index 0000000000..aede36bded --- /dev/null +++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data + +import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel + + +def test_nas_search(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_search_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_search_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + dist.setup(timeout=timedelta(10)) + # Setup the test model and data. 
+ puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + hydra_config_name = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + + # + # Run the mtn.convert() step + # + input_model = PuzzletronModel() + converted_model = mtn.convert( + input_model, + mode=[ + ( + "puzzletron", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Run the mtn.search() step + # + mtn.search( + converted_model, + constraints={}, # this is not used as the search space is defined in the hydra config + dummy_input=None, # Not used + config={}, # this is not used as the search space is defined in the hydra config + ) + + # + # Check assertions for mtn.search() step + # + if rank == 0: + # assertions for the build_library_and_stats step + assert (puzzle_dir / "replacement_library.json").is_file() + assert (puzzle_dir / "subblock_stats.json").is_file() + + # assertions for the scoring step + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + + assert solution_0_filepath.exists() + + # assertions for the mip_and_realize_models step + solution_0_ckpt_config_path = ( + puzzle_dir + / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json" + ) + + assert solution_0_ckpt_config_path.exists() + assert (puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json").exists() + + dist.cleanup() diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml new file mode 100644 index 
0000000000..2843f0b97a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/Qwen2.5-7B-Instruct.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /Qwen/Qwen2.5-7B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen2 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + 
subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..cf6201080c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen2.5-7B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen2.qwen2_model_descriptor.Qwen2FFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml 
b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml new file mode 100644 index 0000000000..cd82a47271 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/Qwen3-8B.yaml @@ -0,0 +1,112 @@ +# @package _global_ +defaults: + - /Qwen/Qwen3-8B/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen3 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 
1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..6bfeec715c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-8B/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml 
b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml new file mode 100644 index 0000000000..00b21ea979 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/Qwen3-VL-30B-A3B-Instruct.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /Qwen/Qwen3-VL-30B-A3B-Instruct/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: qwen3_vl + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + 
objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + + mip_constraints: + - stats.num_local_experts: 1472 # same constraint as nemotron-3-nano for test consistency + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..4e0786dc7a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -0,0 +1,20 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: 
${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor.Qwen3VLExpertRemovalLayerDescriptor + target_name: "mlp" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.Qwen3VLRemoveExpertsIndependentHook} +activation_hooks_kwargs: + +# num_experts_to_keep must be >= num_experts_per_tok (can't route to more experts than exist) +num_experts_to_keep_list: [8] # num_experts in test model is 16, num_experts_per_tok is 8 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" + layer_prefix_template: "model.language_model.layers.{layer_idx}.mlp" diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml new file mode 100644 index 0000000000..57051431a1 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct-attn-pruning.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: attn_pruning + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +dataset_path: ??? 
diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml new file mode 100644 index 0000000000..8e2e0786b3 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct.yaml @@ -0,0 +1,106 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.1-8B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: 
${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..6e8af1f651 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/attn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/attn_pruning@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: 
modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaKVHeadsLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml new file mode 100644 index 0000000000..78cb6bd73c --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/Llama-3.2-3B-Instruct.yaml @@ -0,0 +1,106 @@ +# @package _global_ +defaults: + - /meta-llama/Llama-3.2-3B-Instruct/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +descriptor: llama + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: 
+ descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..b30f4a17d9 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.2-3B-Instruct/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml new file mode 100644 index 0000000000..e042c4bb62 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/Mistral-Small-24B-Instruct-2501.yaml @@ -0,0 +1,112 @@ +# @package _global_ +defaults: + - /mistralai/Mistral-Small-24B-Instruct-2501/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? 
+teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: mistral_small + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - 
stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..37c21fd638 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/mistralai/Mistral-Small-24B-Instruct-2501/pruning/ffn_pruning.yaml @@ -0,0 +1,7 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.mistral_small.mistral_small_model_descriptor.MistralFFNIntermediateLayerDescriptor diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml new file mode 100644 index 0000000000..ab2b09e679 --- /dev/null +++ 
b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml @@ -0,0 +1,115 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning@pruning: expert_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + 
bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + - stats.num_local_experts + + human_constraints: + mip_constraints: + - stats.num_local_experts: 1472 # teacher has: 23 moe-blocks * 128 experts = 2944 total experts use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path}/valid + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml new file mode 100644 index 0000000000..ae20b6d7d2 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/expert_pruning.yaml @@ -0,0 +1,18 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: 
${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHExpertRemovalLayerDescriptor + target_name: "mixer" + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.NemotronHRemoveExpertsIndependentHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [96, 64, 32, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks_mse" diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..abc501287d --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/pruning/ffn_pruning.yaml @@ -0,0 +1,14 @@ +defaults: + - /pruning/pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn/${pruning.experiment_id} +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor.LlamaFFNIntermediateLayerDescriptor + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git 
a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml new file mode 100644 index 0000000000..906b7338d8 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/NVIDIA-Nemotron-Nano-12B-v2.yaml @@ -0,0 +1,113 @@ +# @package _global_ +defaults: + - /nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning@pruning: ffn_pruning + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +descriptor: nemotron_h_v2 + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: 
${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..f68068c3ac --- /dev/null +++ 
b/tests/gpu/torch/puzzletron/resources/configs/nvidia/NVIDIA-Nemotron-Nano-12B-v2/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - /pruning/ffn_pruning_base@_here_ + - _self_ + +pruning_mixin: + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor.NemotronHV2FFNIntermediateLayerDescriptor + +activation_hooks_kwargs: + method: iterative + target_layer: "mixer.down_proj" + layer_input_descriptors_path: diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml new file mode 100644 index 0000000000..2b77516174 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/gpt-oss-20b.yaml @@ -0,0 +1,109 @@ +# @package _global_ +defaults: + - /openai/gpt-oss-20b/pruning@pruning: expert_removal # TODO: Note: Works for unquantized test models, not MXFP4 quantized production models + - /validate_solutions_defaults@scoring + - /validate_solutions_defaults@realize_model + - bypass: + - override /hydra/hydra_logging: disabled + - _self_ + +descriptor: gpt_oss + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true # TODO: Works for unquantized test models + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + - 
stats.num_local_experts: 48 # teacher has: 2 layers * 32 experts = 64 total experts + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml new file mode 100644 index 0000000000..5a4761886f --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/openai/gpt-oss-20b/pruning/expert_removal.yaml @@ -0,0 +1,19 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +eval_samples: 10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor.GptOssExpertRemovalLayerDescriptor + target_name: "mlp.router" +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.expert_removal_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 32 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template:
"model.layers.{layer_idx}.mlp.router" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..0dadc20134 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/attn_pruning.yaml @@ -0,0 +1,23 @@ +defaults: + - /pruning/pruning_defaults@_here_ + - _self_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn + layer_descriptor: + _target_: ??? + +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook} +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml new file mode 100644 index 0000000000..c1c951984f --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/ffn_pruning_base.yaml @@ -0,0 +1,19 @@ +defaults: + - /pruning/pruning_defaults@_here_ + - _self_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn + layer_descriptor: + _target_: ??? 
+ +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook} +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..4033fedf3a --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..f00a86da66 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults@_here_ + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml new file mode 100644 index 0000000000..9dabef7413 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py new file mode 100644 index 0000000000..45c438ec0d --- /dev/null +++ 
b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import warnings +from datetime import timedelta +from functools import partial +from pathlib import Path + +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron import puzzletron +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.mip.sweep import ( + get_teacher_memory_from_subblock_stats, + get_teacher_num_params_from_subblock_stats, +) + +# The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programming NAS search) +# using a one-click command. +# +# Note: Bypass is disabled now in the test. 
+# + +SEED = 1234 + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [ + ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False), + ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False), + ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), + ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), + ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), + ("openai/gpt-oss-20b", "gpt_oss", None, True), + ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), + ("Qwen/Qwen3-8B", "qwen3", None, False), + ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), + ], +) +def test_puzzletron( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str, + has_moe_layers: bool, +): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_puzzletron_multiprocess_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_puzzletron_multiprocess_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str, + has_moe_layers: bool, + rank: int, + size: int, +): + # Set seed BEFORE dist.setup() to ensure reproducibility across all processes + set_seed(SEED) + dist.setup(timeout=timedelta(10)) + + # Setup the test model and data. + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, hf_model_name, hybrid_override_pattern + ) + hydra_config_dir = project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + model_basename = hf_model_name.split("/")[1] + hydra_config_name = f"{hf_model_name}/{model_basename}" + + # Convert the model using AnyModel converter. 
+ if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=converter, + ) + dist.barrier() + + # Compress the model using a one-click approach + hydra_cfg = puzzletron.puzzletron( + str(hydra_config_dir), hydra_config_name, str(puzzle_dir), str(dataset_path) + ) + + # + # Check assertions + # + if rank == 0: + if has_moe_layers: + # assertions for the score_pruning_activations step 1 (MoE models only) + rank_filepath = ( + f"pruning/pruning_scores/expert_removal/10samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file(), f"Expected {rank_filepath} to exist" + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/num_experts_8").exists() + + # assertions for the mip_and_realize_models step 6 + # Find the MIP solution directory dynamically (e.g., stats_num_local_experts_*) + mip_solutions_dir = puzzle_dir / "mip/puzzle_solutions" + solution_dirs = [ + d + for d in mip_solutions_dir.iterdir() + if d.is_dir() and d.name.startswith("stats_num_local_experts_") + ] + assert len(solution_dirs) == 1, ( + f"Expected exactly one stats_num_local_experts_* directory, found: {[d.name for d in solution_dirs]}" + ) + solution_dir = solution_dirs[0] + + solution_0_ckpt_config_path = ( + solution_dir / "solutions--checkpoints/solution_0/config.json" + ) + assert solution_0_ckpt_config_path.exists() + assert (solution_dir / "solutions.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_model_name, tolerance=0.01) + else: + # assertions for the score_pruning_activations step 1 (FFN pruning) + _assert_score_pruning_activations(puzzle_dir, hf_model_name) + + # assertions for the pruning_ckpts step 2 + assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists() + + # assertions for the mip_and_realize_models step 6 + _assert_mip_solutions(puzzle_dir, hf_model_name) + + # assertions for the build_library_and_stats step 4 + assert (puzzle_dir / 
"replacement_library.json").is_file() + _assert_subblock_stats_anymodel(hf_model_name, hydra_cfg) + + # assertions for the scoring step 5 + solution_0_filepath = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + assert solution_0_filepath.exists() + + dist.cleanup() + + +def _assert_subblock_stats_anymodel(hf_model_name: str, hydra_cfg) -> None: + """Minimal subblock_stats checks and teacher memory / param regression values.""" + assert (Path(hydra_cfg.puzzle_dir) / "subblock_stats.json").is_file() + teacher_mem_mib = get_teacher_memory_from_subblock_stats(hydra_cfg) + teacher_num_params = get_teacher_num_params_from_subblock_stats(hydra_cfg) + + assert abs(teacher_mem_mib - EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]) < 1e-6, ( + f"Teacher memory mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_MEMORY_MIB[hf_model_name]}, got {teacher_mem_mib}" + ) + assert abs(teacher_num_params - EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]) < 1e-6, ( + f"Teacher num_params mismatch for {hf_model_name}: " + f"expected {EXPECTED_TEACHER_NUM_PARAMS[hf_model_name]}, got {teacher_num_params}" + ) + + +def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): + """Assertions for the score_pruning_activations step 1.""" + rank = dist.rank() + rank_filepath = f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth" + assert (puzzle_dir / rank_filepath).is_file() + + pruning_scores = torch.load(puzzle_dir / rank_filepath) + + layer_names = list(pruning_scores.keys()) + expected = EXPECTED_PRUNING_VALUES[hf_model_name] + size = dist.size() + + if hf_model_name == "mistralai/Mistral-Small-24B-Instruct-2501" and size == 1: + warnings.warn("Mistral-Small score assertions only work for 2 GPUs") + return + + if expected is not None: + # In multi-GPU: layers are distributed across ranks + # Each rank processes len(expected) // size layers + expected_layers_per_rank = len(expected) // size + assert 
len(layer_names) == expected_layers_per_rank, ( + f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + ) + # Check each layer's values + for i, layer_name in enumerate(layer_names): + layer_data = pruning_scores[layer_name] + # Calculate global layer index from rank and local index + global_idx = rank * expected_layers_per_rank + i + assert layer_data["score"][0].item() == expected[global_idx]["score"], ( + layer_name, + layer_data["score"][0].item(), + expected[global_idx]["score"], + global_idx, + ) + assert ( + layer_data["channels_importance_ascending"][0].item() + == expected[global_idx]["channels"] + ) + else: + print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===") + print(f'"{hf_model_name}": [') + for layer_name in layer_names: + layer_data = pruning_scores[layer_name] + score = layer_data["score"][0].item() + channels = layer_data["channels_importance_ascending"][0].item() + print(f' {{"score": {score}, "channels": {channels}}},') + print("],") + print("===") + pytest.fail(f"Expected pruning values not found for {hf_model_name}") + + +def _assert_lm_loss(puzzle_dir: Path, hf_model_name: str, tolerance: float = 0.01): + """Validate lm_loss for a model solution.""" + solution_0_path = ( + puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json" + ) + with open(solution_0_path) as f: + validation = json.load(f) + + actual_lm_loss = validation["lm_loss"]["avg"] + expected_lm_loss = EXPECTED_LM_LOSS.get(hf_model_name) + if expected_lm_loss is not None: + assert abs(actual_lm_loss - expected_lm_loss) < tolerance, ( + f"lm_loss mismatch: expected {expected_lm_loss}, got {actual_lm_loss}" + ) + else: + # Print value for new models - update EXPECTED_LM_LOSS with this + print(f"\n=== LM_LOSS for {hf_model_name} ===") + print(f'"{hf_model_name}": {actual_lm_loss},') + print("===") + + +def _assert_mip_solutions(puzzle_dir: Path, hf_model_name: str): + """Assertions 
for the mip_and_realize_models step.""" + mip_dir = puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB" + + assert (mip_dir / "solutions.json").exists() + assert (mip_dir / "solutions--checkpoints/solution_0/config.json").exists() + + # Validate lm_loss + _assert_lm_loss(puzzle_dir, hf_model_name) + + +# Expected pruning activation values per model +# Each model maps to a list with one {"score": ..., "channels": ...} dict per FFN layer +EXPECTED_PRUNING_VALUES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + {"score": 73, "channels": 95}, + {"score": 440, "channels": 174}, + ], + "meta-llama/Llama-3.2-3B-Instruct": [ + {"score": 79, "channels": 95}, + {"score": 428, "channels": 174}, + ], + "mistralai/Mistral-Small-24B-Instruct-2501": [ + {"score": 73, "channels": 95}, + {"score": 431, "channels": 174}, + ], + # NemotronH with pattern "*-" has only 1 FFN layer (the "-" layer) + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": [ + {"score": 70, "channels": 509}, + ], + # nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 uses MoE expert pruning, not FFN pruning + "Qwen/Qwen2.5-7B-Instruct": [ + {"score": 96, "channels": 433}, + {"score": 485, "channels": 105}, + ], + "Qwen/Qwen3-8B": [ + {"score": 208, "channels": 51}, + {"score": 475, "channels": 266}, + ], +} + + +# Expected lm_loss values per model +EXPECTED_LM_LOSS = { + "meta-llama/Llama-3.1-8B-Instruct": 4.706878662109375, + "meta-llama/Llama-3.2-3B-Instruct": 4.816886901855469, + "mistralai/Mistral-Small-24B-Instruct-2501": 4.709150314331055, + # TODO: not reproducible in CI, skipping for now + # "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 4.733944892883301, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 4.79390811920166, + "openai/gpt-oss-20b": 4.689250946044922, + "Qwen/Qwen2.5-7B-Instruct": 4.778186798095703, + "Qwen/Qwen3-8B": 4.733874320983887, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 4.65625, +} + + +# Expected teacher memory from subblock_stats (MiB) +EXPECTED_TEACHER_MEMORY_MIB = { + "meta-llama/Llama-3.1-8B-Instruct": 
395.60205078125, + "meta-llama/Llama-3.2-3B-Instruct": 395.60205078125, + "mistralai/Mistral-Small-24B-Instruct-2501": 395.60205078125, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 202.10107421875, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 202.10107421875, + "openai/gpt-oss-20b": 437.302490234375, + "Qwen/Qwen2.5-7B-Instruct": 386.228515625, + "Qwen/Qwen3-8B": 395.60302734375, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 406.11865234375, +} + + +# Expected total teacher params from subblock_stats +EXPECTED_TEACHER_NUM_PARAMS = { + "meta-llama/Llama-3.1-8B-Instruct": 6082816.0, + "meta-llama/Llama-3.2-3B-Instruct": 6082816.0, + "mistralai/Mistral-Small-24B-Instruct-2501": 6082816.0, + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16": 5295872.0, + "nvidia/NVIDIA-Nemotron-Nano-12B-v2": 5295872.0, + "openai/gpt-oss-20b": 27945856.0, + "Qwen/Qwen2.5-7B-Instruct": 1168384.0, + "Qwen/Qwen3-8B": 6083328.0, + "Qwen/Qwen3-VL-30B-A3B-Instruct": 11596544.0, +} diff --git a/tests/unit/torch/puzzletron/test_convert_anymodel.py b/tests/unit/torch/puzzletron/test_convert_anymodel.py new file mode 100644 index 0000000000..f27cb9c9b9 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_convert_anymodel.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import create_tiny_qwen3_dir +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + +def test_convert_anymodel(tmp_path): + input_dir = create_tiny_qwen3_dir(tmp_path, with_tokenizer=True) + output_dir = tmp_path / "qwen3-0.6b-anymodel" + convert_model(input_dir, output_dir, converter="qwen3") + + descriptor = ModelDescriptorFactory.get("qwen3") + with deci_x_patcher(descriptor): + _ = AutoModelForCausalLM.from_pretrained(output_dir) diff --git a/tox.ini b/tox.ini index 80299d814d..7948a19f5c 100644 --- a/tox.ini +++ b/tox.ini @@ -66,6 +66,10 @@ commands_pre = # Install cupy-cuda13x for INT4 ONNX quantization (default is cupy-cuda12x) pip uninstall -y cupy-cuda12x pip install cupy-cuda13x + + # Install mamba and causal-conv1d for Nemotron tests + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + pip install --no-build-isolation git+https://github.com/Dao-AILab/causal-conv1d.git commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" python -m pytest tests/gpu