diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index a7e3aedfc..134da88e2 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -130,6 +130,91 @@ hf auth login 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). +## Bypass Distillation (Local Knowledge Distillation) + +Bypass distillation (also called Blockwise Local Distillation or BLD) is an **optional** pipeline stage that trains alternative transformer block configurations using per-block knowledge distillation from the teacher model. It significantly improves the quality of aggressively compressed models by producing better "puzzle pieces" for the MIP solver. + +### When to use bypass + +Bypass distillation is only necessary for **aggressive compression**. For mild pruning (e.g., reducing FFN intermediate size by less than 25%), weight-initialization-based pruning alone usually produces good results. Use bypass when: + +- **Heavy FFN pruning**: the target `intermediate_size` is ≤ 1/8 of the teacher's width. + For example, on Llama-3.1-8B (teacher `intermediate_size=14336`), run bypass for sizes ≤ 1792. + For milder reductions (e.g., to 3072 = ~21%), bypass improves quality but may not be essential. +- **KV head compression**: the number of `num_key_value_heads` is being significantly reduced + (e.g., from 8 to 2 or fewer). The AverageKV initialization provides a good starting point, + but bypass distillation recovers additional accuracy. + +### Time cost + +Bypass distillation is a full training loop — plan for several hours per configuration when +using ~1B training tokens on H100 GPUs. Total time scales with `len(bypass.configs) × training_tokens`. +This is comparable to lightweight fine-tuning. + +### Sequential execution + +Each entry in `bypass.configs` trains **sequentially** (one config at a time). There is no +parallelism across configurations — if you have 3 configs, they run one after the other within +a single pipeline invocation. Distribute across different jobs if time is a constraint. + +### Configuration + +Add a `bypass` section to your config YAML (or include `bypass/defaults.yaml` via Hydra defaults). +Key parameters: + +| Parameter | Description | Default | +|---|---|---| +| `training.learning_rate` | Initial learning rate | `1e-4` | +| `training.training_tokens` | Total training tokens per config | `1e+9` (1B) | +| `training.micro_batch_size` | Batch size per step | `2` | +| `data.block_size` | Sequence length | `512` | +| `model_factory.gqa_init_mode` | KV head init strategy (`AverageKV`, `RandomKV`) | `AverageKV` | +| `model_factory.mlp_init_mode` | FFN init strategy (`Truncate`, `PruneByActivationsLog`) | `Truncate` | +| `model_factory.keys_to_learn` | Which params to train (`subblock_ffn`, `subblock_attention`, `entire_block`) | computed | +| `configs` | List of configurations to train sequentially | — | + +### Training multiple configurations + +Use `bypass.configs` to train multiple block configurations in a single run. Each entry +overrides `model.model_config_overrides` and optionally `model_factory.keys_to_learn`: + +```yaml +bypass: + training: + training_tokens: 1e+9 # ~1B tokens per config + configs: + - model_config_overrides: + ffn: + - intermediate_size: 1792 # ~1/8 of 14336 — bypass strongly recommended + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn + - model_config_overrides: + ffn: + - intermediate_size: 3584 # ~1/4 of 14336 — bypass optional but helpful + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn +``` + +Trained checkpoints are automatically symlinked into `$PUZZLE_DIR/ckpts/` where the replacement +library builder picks them up in the next pipeline stage. + +### Weights & Biases logging + +Enable W&B to track per-block distillation loss and validation metrics during training: + +```yaml +bypass: + wandb_log: true + wandb: + project: my-puzzletron-project + entity: my-org +``` + +W&B logs iteration number, token count, learning rate, and per-block loss at each log interval. +If `wandb` is not installed, logging is silently disabled and training continues normally. + ## Re-run MIP Search with different constraints If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 21903db16..29174ce88 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: defaults # comment out to run without bypass - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml new file mode 100644 index 000000000..7a0be3789 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1,130 @@ +# @package bypass +# Bypass Distillation Configuration +# This config defines parameters for blockwise local distillation (BLD), +# which trains alternative transformer block configurations using per-block +# knowledge distillation from a teacher model. + +# Runtime Configuration +dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability +seed: 42 # Random seed for reproducibility + +# Experiment Tracking +experiment_id: # Unique identifier for this experiment. Will be dynamically set +experiment_dir: # Directory for this experiment. Will be dynamically set +iter_num: 1 # Current iteration number +step_num: 1 # Current step number within iteration +token_count: 0 # Token count tracker (auto-updated during training) + +# Data Configuration +data: + data_column: "messages" + block_size: 512 # Sequence length (tokens per sample) + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true # Load preprocessed data from disk or from stream + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: null # Samples per GPU during distributed eval (auto if null) + shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data + +# Training Configuration +training: + learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001) + training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check) + micro_batch_size: 2 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} # Auto-calculated warmup steps + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 # Use for debugging or to skip few batches which cause crashes or optimization issues. + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 5 + +# Model Loading Configuration +resume_checkpoint_path: null # Path to resume training from checkpoint +find_last_ckpt_for_resume: True # Auto-resume by finding last checkpoint (bool) +parameter_count: null +init_checkpoint_path: null # Path to initialize weights from + +model: + student_weights_dtype: "bf16" # Student model weight precision + + model_overrides: + delete_old_checkpoints: true # Clean up old checkpoints to save disk space + save_interval_seconds: 12900 # Save checkpoint every ~3.5 hours + save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled) + save_checkpoint_when_done: true # Save final checkpoint when training completes + +# Architecture modifications for student model + model_config_overrides: + ffn: + - intermediate_size: + no_op: # Disable FFN entirely (true/false) + attention: + - num_key_value_heads: # Number of kv-heads (for GQA) + no_op: # Disable attention entirely (true/false) + +# Model Factory Configuration - Controls student model creation and initialization +model_factory: + factory: bypass_factory_fn # Unified factory supporting all layer types + block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss + gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode + mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode + mlp_init_config: # Configuration for MLP initialization (if needed) + activations_log_dir: null # Directory with activation statistics (required for PruneByActivationsLog) + linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc. + submodule_for_loss_calculation: null # Specific submodule for loss calc. + keys_to_learn: null # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically. + +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false # Enable validation to exercise all code paths +best_val_loss: 1e+9 # Track best validation loss achieved + +# Performance Optimization +compile: false # Use PyTorch compilation +disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available) +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false # Save initial checkpoint before training +disable_checkpoint_save: false # Disable all checkpoint saving +save_best_ckpt: true # Save checkpoint when validation improves +kill_after_first_save: false # Exit after first checkpoint save (for testing) +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Multiple bypass configurations to train sequentially. +# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn. +# If empty or absent, a single run uses the settings above. +configs: + - model_config_overrides: + ffn: + - intermediate_size: 3072 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn + - model_config_overrides: + ffn: + - intermediate_size: 5888 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 5bb04818e..4e62bfb78 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -41,6 +41,7 @@ import modelopt.torch.puzzletron.mip.sweep as sweep import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import _total_steps from modelopt.torch.puzzletron.tools.hydra_utils import ( initialize_hydra_config_for_dir, register_hydra_resolvers, @@ -74,7 +75,6 @@ def run_full_puzzletron(hydra_config_path: str): Args: config_path: Path to the YAML configuration file """ - mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") dist.setup(timeout=timedelta(10)) # Register Hydra custom resolvers (needed for config resolution) @@ -84,12 +84,15 @@ def run_full_puzzletron(hydra_config_path: str): hydra_config_dir = str(hydra_config_path.parent) hydra_config_name = hydra_config_path.stem - # Load hydra config + # Load hydra config to determine total step count (bypass adds one step) hydra_cfg = initialize_hydra_config_for_dir( config_dir=hydra_config_dir, config_name=hydra_config_name, overrides=[], ) + N = _total_steps(hydra_cfg) + + mprint(f"Puzzletron Progress 1/{N}: starting puzzletron pipeline") # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) @@ -120,7 +123,7 @@ def run_full_puzzletron(hydra_config_path: str): ) dist.cleanup() - mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)") def run_mip_only(hydra_config_path: str): diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py index 4cc4356c8..9bcddad18 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -160,6 +160,19 @@ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: """ raise NotImplementedError + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def uses_autocast() -> bool: """Whether this model supports torch.autocast. diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 55d9ef56c..50dc2db4b 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -34,6 +34,10 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import ( + KVHeadsLayerDescriptor, + KVHeadsPruningMixIn, +) from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn @@ -52,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -253,4 +266,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py new file mode 100644 index 000000000..f1cea0afe --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation (blockwise local distillation) for the PUZZLE framework. + +This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer +block configurations using per-block knowledge distillation from a teacher model. +""" + +from .training_loop import launch_bypass_distillation diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py new file mode 100644 index 000000000..52ef8e884 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint utilities for bypass distillation.""" + +import re +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Type, Union + +import torch +from omegaconf import DictConfig +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump + +from .stitched_model_factory import StitchedModuleDescriptor + + +def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the latest checkpoint directory within a run parent directory.""" + run_parent_dir = Path(run_parent_dir) + + # Check for the "latest" directory + latest_dir = run_parent_dir / "latest" + if latest_dir.exists() and (latest_dir / "saving_completed").exists(): + return str(latest_dir) + + # If "latest" doesn't exist, look explicitly into directories with `*iter-*` + candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()] + + if not candidate_dirs: + return None + + def get_iter_num(dir_name): + match = re.search(r"iter-(\d+)", dir_name.name) + return int(match.group(1)) if match else 0 + + checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True) + for latest_dir in checkpoint_dirs: + if (latest_dir / "saving_completed").exists(): + return str(latest_dir) + return None + + +def load_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_path: str | Path, + verbose=True, +) -> None: + """Load local state from a checkpoint. + + Loads both optimizer and state dicts into stitched module descriptors. + Modifies stitched_module_descriptors in place. + """ + device = torch.device(f"cuda:{dist.local_rank()}") + load_dir = Path(checkpoint_path) + + if not load_dir.exists(): + raise RuntimeError(f'Can\'t load local state. "{load_dir}" does not exist.') + + for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = load_dir / "stitched" / f"{stitched_module_name}.state_dict.pth" + if verbose: + mprint(f"Loading state dict for module {stitched_module_name} from {state_dict_path}") + loaded_state_dict = torch.load(state_dict_path, map_location=device) + loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} + + stitched_module.load_state_dict(loaded_state_dict) + del loaded_state_dict + + if optimizer is not None: + optimizer_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" + ) + if verbose: + mprint( + f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}" + ) + loaded_optimizer_state = torch.load(optimizer_state_path, map_location=device) + optimizer.load_state_dict(loaded_optimizer_state) + del loaded_optimizer_state + + +def _save_local_file(obj, save_path: Path | str, overwrite=True): + save_path = Path(save_path) + if save_path.exists(): + if not overwrite: + mprint(f'WARNING: Local save path "{save_path}" already exists. Skipping') + return + torch.save(obj, save_path) + + +def _save_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + overwrite=True, + verbose=True, +) -> None: + save_dir = Path(checkpoint_dir) / "stitched" + + if dist.is_master(): + save_dir.mkdir(parents=True, exist_ok=True) + + # Main process creates the directory, so we must wait for it to finish + dist.barrier() + + for stitched_module_name, stitched_module_descriptor in tqdm( + stitched_module_descriptors.items() + ): + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = save_dir / f"{stitched_module_name}.state_dict.pth" + if verbose: + aprint(f"Saving state dict for module {stitched_module_name} to {state_dict_path}") + state_dict = { + **stitched_module_descriptor.owned_parameters, + **stitched_module_descriptor.owned_buffers, + } + _save_local_file(state_dict, state_dict_path, overwrite=overwrite) + + if optimizer is not None: + optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" + if verbose: + mprint( + f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" + ) + _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + + dist.barrier() + + +def save_bypass_checkpoint( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + reference_checkpoint_dir: Optional[Path] = None, +) -> None: + """Save a bypass distillation checkpoint.""" + checkpoint_dir = Path(checkpoint_dir) + mprint("Starting checkpoint save") + mprint(f"Saving checkpoint to {checkpoint_dir}") + + # Save stitched module states + _save_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=checkpoint_dir, + overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, + verbose=dist.is_master() and False, + ) + # Save as HF checkpoint + save_checkpoint(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor) + + if dist.is_master(): + # Create 'latest' symlink + latest_symlink = Path(cfg.bypass.experiment_dir) / "latest" + latest_symlink.unlink(missing_ok=True) + latest_symlink.symlink_to(checkpoint_dir.name) + # Save config args json + json_dump(cfg.bypass, checkpoint_dir / "args.json") + # Save completed file + completed_file = checkpoint_dir / "saving_completed" + completed_file.touch() + + dist.barrier() + mprint("Checkpoint save done") diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py new file mode 100644 index 000000000..3715078bb --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for bypass distillation.""" + +from pathlib import Path + +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + + +def set_experiment_id(cfg: DictConfig) -> None: + """Set the experiment ID based on the model config overrides.""" + if cfg.bypass.experiment_id is None: + overrides = cfg.bypass.model.model_config_overrides + if "ffn" in overrides: + ffn_override = overrides.ffn[0] + if "intermediate_size" in ffn_override: + # Dense FFN model: identify by FFN size and attention heads + cfg.bypass.experiment_id = "bypass_ffn_{}_heads_{}".format( + ffn_override["intermediate_size"], + overrides.attention[0]["num_key_value_heads"], + ) + else: + # MoE model: identify by number of experts per layer + cfg.bypass.experiment_id = "bypass_experts_{}".format( + ffn_override["moe"]["num_local_experts"] + ) + elif "attention" in overrides: + # Attention-only bypass: identify by number of KV heads + cfg.bypass.experiment_id = "bypass_heads_{}".format( + overrides.attention[0]["num_key_value_heads"] + ) + + +def set_experiment_dir(cfg: DictConfig) -> None: + """Set the experiment directory for the bypass run.""" + cfg.bypass.experiment_dir = Path(cfg.puzzle_dir) / "bypass" / "bypass_runs" / cfg.bypass.experiment_id + if dist.is_master(): + cfg.bypass.experiment_dir.mkdir(parents=True, exist_ok=True) + + +def get_distributed_modules_ownership(module_count: int, world_size: int) -> list[int]: + """Map module (block) indices to GPU ranks for pipeline-parallel distribution.""" + modules_process_ownership: list[int] = [] + + for i in range(world_size): + num_modules_for_process = module_count // world_size + if i < module_count % world_size: + num_modules_for_process += 1 + + modules_process_ownership.extend([i] * num_modules_for_process) + + return modules_process_ownership diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py new file mode 100644 index 000000000..3fb1b2835 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data classes for bypass distillation training.""" + +import dataclasses +from typing import TypeAlias + + +IterNum: TypeAlias = int +GlobalRank: TypeAlias = int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class IterStatistics: + step_num: int + token_count: int + iter_duration: float + lr: float + clipping_count: int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class LocalTrainingStats: + iter_num: int + stitched_module_losses: dict[str, float] + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TimeToSaveSignal: + step_num: int diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py new file mode 100644 index 000000000..710935388 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -0,0 +1,619 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating stitched teacher/student models for bypass distillation.""" + +import copy +import dataclasses +import re +from argparse import Namespace +from collections import OrderedDict +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, Sequence, Type + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import AdamW, Optimizer +from transformers import PretrainedConfig, PreTrainedModel + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + GQAInitMode, + LinearInitMode, + MlpInitMode, +) +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + FunctionTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, + always_true_predicate, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model +from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype + +StitchedModulesProcessOwnership = list[int] +SyncDistributedModelWeightsFn = Callable[[], None] +Config = Mapping[str, Any] +Args = Namespace + + +@dataclasses.dataclass +class StitchedModuleDescriptor: + stitched_module: StitchedModule + owned_parameters: dict[str, torch.nn.Parameter] + owned_buffers: dict[str, torch.Tensor] + optimizer: Optional[Optimizer] = None + grad_scaler: Optional[GradScaler] = None + + +def default_factory( + teacher_model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + config: Config, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + raise NotImplementedError() + + +StitchedModelFactoryFn = type(default_factory) + +_SUBBLOCK_KEYS_TO_LEARN = frozenset({"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"}) + + +def _set_keys_to_learn( + model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + keys_to_learn: str | Sequence[str], +) -> None: + """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``. + + * A **sequence of strings** (not a bare ``str``): each string is a full parameter + name; gradients are enabled only where ``named_parameters()`` names match exactly. + * A **single string**: if it is ``"subblock_ffn"``, ``"subblock_attention"``, or + ``"entire_block"``, enables gradients for the corresponding descriptor weight + groups; otherwise ``re.search`` is applied to each parameter name. + """ + # If keys_to_learn is a sequence of strings. + if isinstance(keys_to_learn, Sequence) and not isinstance(keys_to_learn, str): + param_names = set(keys_to_learn) + # If keys_to_learn is a single string. + else: + # If keys_to_learn is a single string that is a subblock key. + if keys_to_learn in _SUBBLOCK_KEYS_TO_LEARN: + lm_config = descriptor.get_language_model_config(model.config) + weight_groups = descriptor.get_weight_groups( + model.state_dict().keys(), lm_config.num_hidden_layers + ) + + attn_group_names = [ + group_name + for group_name in weight_groups.keys() + if group_name.endswith("_attention") + ] + ffn_group_names = [ + group_name + for group_name in weight_groups.keys() + if group_name.endswith("_ffn") + ] + if keys_to_learn == "subblock_attention": + group_names = attn_group_names + elif keys_to_learn == "subblock_ffn": + group_names = ffn_group_names + elif keys_to_learn == "subblock_mamba": + group_names = attn_group_names # Mamba params live in _attention groups + else: # entire_block + group_names = attn_group_names + ffn_group_names + + block_configs = getattr(lm_config, "block_configs", None) + + param_names = [] + for group_name in group_names: + # For hybrid models (e.g. NemotronH), a single "_attention" group + # name can contain either Mamba SSM params *or* GQA params depending + # on the block. Use the block config — not the keys_to_learn string + # — to decide whether each block belongs to the current subblock type. + if block_configs is not None: + m = re.match(r"block_(\d+)_attention", group_name) + if m: + block_idx = int(m.group(1)) + if block_idx < len(block_configs): + is_mamba = ( + getattr(block_configs[block_idx].attention, "mamba", None) + is not None + ) + # subblock_attention → GQA blocks only (not Mamba) + # subblock_mamba → Mamba blocks only (not GQA) + # entire_block → all blocks (no filtering) + if keys_to_learn == "subblock_attention" and is_mamba: + continue + if keys_to_learn == "subblock_mamba" and not is_mamba: + continue + param_names.extend(weight_groups[group_name]) + param_names = set(param_names) + # If keys_to_learn is a single string that is not a subblock key, treat as regex. + else: + param_names = { + param_name + for param_name, _ in model.named_parameters() + if re.search(keys_to_learn, param_name) + } + # In pipeline-parallel training a rank may own only blocks that don't match + # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention + # bypass has no GQA params after the _mamba rename). That is a valid state: + # all its blocks will produce NaN loss and be excluded from statistics. + if not param_names: + return + + # Set requires_grad to True for the selected parameters. + for param_name, param in model.named_parameters(): + if param_name in param_names and torch.is_floating_point(param): + param.requires_grad_(True) + + +def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]: + all_non_persistent = set() + for module_name, submodule in module.named_modules(): + for buffer_name in submodule._non_persistent_buffers_set: + full_name = f"{module_name}.{buffer_name}" if module_name else buffer_name + all_non_persistent.add(full_name) + return all_non_persistent + + +def bypass_factory_fn( + teacher_model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + cfg: DictConfig, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + """Unified factory function for bypass (blockwise local) distillation. + + Handles all layer types — FFN, attention (GQA/MHA), MoE experts, Mamba, and whole blocks — + through a single pipeline. Behavior is driven entirely by ``model_factory`` config fields: + + - ``mlp_init_mode``: how student FFN / MoE weights are initialised + - ``"ExpertRemoval"``: select top-N experts from teacher (MoE models) + - ``"Truncate"`` / ``"PruneByActivationsLog"``: prune FFN channels (dense models) + - ``"CopyAsIs"``: copy weights unchanged (attention-only or Mamba-only runs) + - ``gqa_init_mode``: how attention KV heads are initialised (optional, default ``AverageKV``). + Irrelevant when the student has the same number of KV heads as the teacher. + - ``keys_to_learn``: which parameters to train. + Accepts ``"subblock_ffn"``, ``"subblock_attention"``, ``"entire_block"``, or a regex string. + + The stitching logic (pipeline-parallel per-block KD) is architecture-agnostic and unchanged + regardless of which layer type is being distilled. + + Args: + teacher_model: The teacher model to use for stitching. + descriptor: Model descriptor for layer naming and pruning mixin lookup. + cfg: The bypass config section. + model_blocks_process_ownership: Ownership mapping of model blocks to process ranks. + student_model: Optionally provided pre-built student model (skips initialisation). + + Returns: + Tuple of (student_model, teacher_stitched, teacher_val_stitched, + student_val_stitched, stitched_module_descriptors, student_config) + """ + device = torch.device(f"cuda:{dist.local_rank()}") + model_config_overrides = cfg.model.model_config_overrides + + block_loss_func = { + "normalized_mse_loss": normalized_mse_loss, + "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, + "batched_normalized_mse_loss": batched_normalized_mse_loss, + }[cfg.model_factory.block_loss_func] + mprint(f"{block_loss_func.__name__=}") + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + # Initialize student_model + if student_model is None: + mprint("Creating student model from teacher model") + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if isinstance(model_config_overrides, DictConfig): + config_to_override = OmegaConf.to_container(model_config_overrides, resolve=True) + else: + config_to_override = model_config_overrides + mprint(f"{config_to_override=}") + student_model_config = update_model_config( + model_config=teacher_model.config, + model_config_overrides=config_to_override, + ) + student_model_config.use_cache = False + + mprint(f"Student model config:\n {format_block_configs(student_model_config)}") + + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + runtime = Namespace( + device=device, + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + ) + + with deci_x_patcher( + model_descriptor=descriptor, + block_configs=getattr(student_model_config, "block_configs", None), + ): + student_model = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=student_model_config, + owned_block_indexes=owned_block_indexes, + device=device, + ) + student_model._init_weights(student_model) + + student_weights_dtype = parse_dtype(cfg.model.student_weights_dtype) + descriptor.init_rotary_embedding(student_model, runtime) + student_model.type(student_weights_dtype) + + mlp_init_mode = MlpInitMode(cfg.model_factory.mlp_init_mode or MlpInitMode.CopyAsIs) + + # For expert removal, use the model-specific pruning mixin so that model-specific + # key paths (e.g. backbone.layers.{i}.mixer for Nemotron-H vs model.layers.{i}.mlp + # for GPT-OSS) are handled correctly. For all other init modes the legacy inline + # key logic in create_child_state_dict is sufficient. + _mixins = [] + if mlp_init_mode == MlpInitMode.ExpertRemoval: + _expert_mixin = descriptor.pruning_mixins().get("experts_removal") + if _expert_mixin is not None: + _mixins.append(_expert_mixin) + + # If any attention layer has fewer KV heads in the student than the teacher, use the + # model-specific KV heads mixin so that k_proj/v_proj weights are correctly sliced + # rather than copied verbatim from the (larger) teacher state dict. + _kv_mixin = descriptor.pruning_mixins().get("kv_heads") + if _kv_mixin is not None: + _student_kv = [ + b.attention.num_key_value_heads + for b in student_model_config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + _teacher_kv = [ + b.attention.num_key_value_heads + for b in teacher_model.config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + if _student_kv != _teacher_kv: + _mixins.append(_kv_mixin) + + if len(_mixins) == 0: + pruning_mixin = None + elif len(_mixins) == 1: + pruning_mixin = _mixins[0] + else: + pruning_mixin = _mixins + + # GQA init mode is optional: only relevant when the student has fewer KV heads than + # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. + gqa_init_mode = GQAInitMode( + cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV) + ) + + student_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=teacher_model.state_dict(), + new_state_dict=student_model.state_dict(), + original_config=teacher_model.config, + new_config=student_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=cfg.model_factory.mlp_init_config, + owned_block_indexes=owned_block_indexes, + linear_init_mode=LinearInitMode( + cfg.model_factory.linear_init_mode or LinearInitMode.Random + ), + ) + + # Load student state dict + missing_keys, unexpected_keys = student_model.load_state_dict( + student_state_dict, strict=False + ) + assert len(unexpected_keys) == 0, f"{unexpected_keys=}" + # GQA models have learnable logit parameters not present in the teacher state dict; + # allow those to be absent and assert nothing else is missing. + non_gqa_missing = [k for k in missing_keys if not re.search(r"gqa_\w+_logits", k)] + assert len(non_gqa_missing) == 0, f"Unexpected missing keys: {non_gqa_missing}" + + else: + mprint("Student model provided explicitly, not using teacher model to instantiate") + student_model_config = student_model.config + + # Set up training parameters + lm_config = descriptor.get_language_model_config(student_model_config) + all_block_indices = list(range(lm_config.num_hidden_layers)) + + student_model.requires_grad_(False) + keys_to_learn = cfg.model_factory.keys_to_learn + mprint(f"Keys to learn: {keys_to_learn}") + + _set_keys_to_learn(model=student_model, descriptor=descriptor, keys_to_learn=keys_to_learn) + + dist.barrier() + mprint(f"Global rank: {dist.rank()}, {owned_block_indexes=}") + dist.barrier() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + dist.barrier() + + min_owned_index = min(owned_block_indexes) + max_owned_index = max(owned_block_indexes) + prev_rank: Optional[int] = ( + None + if min_owned_index == min(all_block_indices) + else model_blocks_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index == max(all_block_indices) + else model_blocks_process_ownership[max_owned_index + 1] + ) + + teacher_parameters = set(teacher_model.parameters()) + teacher_buffers = set(teacher_model.buffers()) + + # Setup the student model's submodules for knowledge distillation training + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.device(device): + stitched_module_descriptors = OrderedDict[str, StitchedModuleDescriptor]() + submodule_for_loss_calculation = cfg.model_factory.submodule_for_loss_calculation + + teacher_target = ModuleTarget("teacher", teacher_model) + teacher_stitcher = Needle() + teacher_val_stitcher = Needle() + + student_target = ModuleTarget("student", student_model) + student_val_stitcher = Needle() + + for local_block_index, global_block_index in enumerate(sorted(owned_block_indexes)): + module_name = descriptor.layer_block_name(global_block_index) + module = student_model.get_submodule(module_name) + + submodule_name = "" + submodule_input_descriptor = submodule_name + submodule_output_descriptor = submodule_name + + if submodule_for_loss_calculation is not None: + assert hasattr(module, submodule_for_loss_calculation) + submodule_output_descriptor = submodule_for_loss_calculation + + input_descriptor = f"{module_name}.{submodule_input_descriptor}".rstrip(".") + output_descriptor = f"{module_name}.{submodule_output_descriptor}".rstrip(".") + + # Receive activations from previous rank + if global_block_index > 0 and local_block_index == 0 and prev_rank is not None: + teacher_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + teacher_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + student_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="student_activations", adapter=lambda x: InputArgs(x) + ), + student_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + # Send activations to next rank or register model output + if local_block_index + 1 == len(owned_block_indexes): + if next_rank is None: + student_val_stitcher.stitch( + student_target.output(name=""), + ExternalTarget().output("model_output"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=""), + ExternalTarget().output("model_output"), + ) + else: + teacher_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + student_val_stitcher.stitch( + student_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="student_activations"), + ) + + # Bypass training stitches + teacher_stitcher.stitch( + teacher_target.input(name=input_descriptor), + ExternalTarget().input(name=input_descriptor), + ).stitch( + teacher_target.output(name=output_descriptor), + ExternalTarget().output(name=output_descriptor), + ) + + # Create the student block stitched module + student_stitched_module_loss_target = FunctionTarget( + "module_loss_func", block_loss_func + ) + student_stitched_module_name = f"block_{global_block_index}" + student_submodule_target = ModuleTarget("student_submodule", module) + student_stitched_module = ( + Needle() + .stitch( + ExternalTarget().input(name=input_descriptor), + student_submodule_target.input(name=submodule_input_descriptor), + ) + .stitch( + ExternalTarget().output( + name=output_descriptor, + adapter=lambda v: InputArgs(target=v) + if not isinstance(v, tuple) + else InputArgs(target=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_submodule_target.output( + name=submodule_output_descriptor, + adapter=lambda v: InputArgs(input=v) + if not isinstance(v, tuple) + else InputArgs(input=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_stitched_module_loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot( + ignore_extra_overrides=True, + capture_cache_outputs_predicate=always_true_predicate, + ) + ) + + assert "learning_rate" in cfg.training + num_trainable_params = sum( + p.requires_grad and submodule_name in p_name + for p_name, p in student_stitched_module.named_parameters() + if "dummy_param" not in p_name # exclude placeholder params + ) + # Do NOT enable dummy params: blocks with no real trainable parameters + # (e.g. Mamba blocks during an attention-only bypass run) should produce + # NaN loss so they are excluded from statistics — identical to the + # optimizer=None path in the training loop. + + student_module_parameters = { + p_name: p + for p_name, p in student_stitched_module.named_parameters() + if p not in teacher_parameters and "dummy_param" not in p_name + } + student_module_buffers = { + p_name: p + for p_name, p in student_stitched_module.named_buffers() + if p not in teacher_buffers + and p_name not in _get_all_non_persistent_buffers_set(student_stitched_module) + } + + trainable_params = { + p_name: p + for p_name, p in student_module_parameters.items() + if p.requires_grad + } + + optimizer = ( + AdamW( + list(trainable_params.values()), + lr=cfg.training.learning_rate, + weight_decay=cfg.training.weight_decay, + betas=(cfg.training.beta1, cfg.training.beta2), + fused=True, + ) + if len(trainable_params) > 0 + else None + ) + + grad_scaler = ( + None + if optimizer is None + else GradScaler(device=device.type, enabled=cfg.training.use_grad_scaling) + ) + + stitched_module_descriptors[student_stitched_module_name] = StitchedModuleDescriptor( + stitched_module=student_stitched_module, + owned_parameters=student_module_parameters, + owned_buffers=student_module_buffers, + optimizer=optimizer, + grad_scaler=grad_scaler, + ) + + teacher_stitched_module = teacher_stitcher.knot(ignore_extra_overrides=True) + teacher_val_stitched_module = teacher_val_stitcher.knot(ignore_extra_overrides=True) + student_val_stitched_module = student_val_stitcher.knot(ignore_extra_overrides=True) + + return ( + student_model, + teacher_stitched_module, + teacher_val_stitched_module, + student_val_stitched_module, + stitched_module_descriptors, + student_model_config, + ) + + + +# Backward-compatible name aliases +gqa_factory_fn = bypass_factory_fn +moe_factory_fn = bypass_factory_fn diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py new file mode 100644 index 000000000..349bb27f5 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -0,0 +1,951 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation training loop for per-block knowledge distillation. + +This module implements the blockwise local distillation (BLD) stage of the PUZZLE framework. +It trains alternative transformer block configurations using per-block knowledge distillation +from a teacher model, producing a library of "puzzle pieces" with different efficiency/performance +trade-offs. +""" + +import logging +import math +import os +import shutil +import sys +import time +import traceback +from collections import OrderedDict, defaultdict +from pathlib import Path +from statistics import mean +from typing import Optional, Type, cast + +import datasets +import torch +import torch.distributed +import transformers +from omegaconf import DictConfig +from torch.utils.data.dataloader import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizerBase, PretrainedConfig + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule +from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_load +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses + +from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint +from .bypass_utils import get_distributed_modules_ownership, set_experiment_dir, set_experiment_id +from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal +from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership + +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module + +time_start = time.time() + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: + """Top-level entry point for bypass distillation stage. + + Supports multiple bypass configurations via ``bypass.configs`` list. + Each entry overrides ``bypass.model.model_config_overrides`` and optionally + ``bypass.model_factory.keys_to_learn``, then runs a full bypass training. + + If ``bypass.configs`` is absent or empty, runs a single bypass training + with the settings already in ``bypass``. + + Args: + hydra_cfg: The full Hydra configuration with a 'bypass' section. + """ + configs_list = hydra_cfg.bypass.get("configs", None) + + if not configs_list: + # Single config mode — run once with whatever is in bypass already + mprint("Starting bypass distillation (single config)") + run_bypassed_training(hydra_cfg) + mprint("Bypass distillation completed") + return + + mprint(f"Starting bypass distillation sweep ({len(configs_list)} configs)") + for i, override in enumerate(configs_list): + mprint(f"Bypass config {i + 1}/{len(configs_list)}: {override}") + + # Apply overrides for this run + if "model_config_overrides" in override: + hydra_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + hydra_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + + # Reset per-run state so each config starts fresh + hydra_cfg.bypass.experiment_id = None + hydra_cfg.bypass.iter_num = 1 + hydra_cfg.bypass.step_num = 1 + hydra_cfg.bypass.token_count = 0 + hydra_cfg.bypass.best_val_loss = 1e9 + hydra_cfg.bypass.training.clipping_count = 0 + + run_bypassed_training(hydra_cfg) + mprint(f"Bypass config {i + 1}/{len(configs_list)} completed") + + mprint("Bypass distillation sweep completed") + + +def train( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + teacher_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + stitched_modules_process_ownership: StitchedModulesProcessOwnership, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + student_model_config: PretrainedConfig, + skip_first_batches: int = 0, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +) -> None: + """Inner training loop for bypass distillation.""" + device = torch.device(f"cuda:{dist.local_rank()}") + + dist.barrier() + + time_last_save = time_start + iter_t0 = time.time() + + resumed_iter_num = cfg.bypass.iter_num + mprint(f"resumed_iter_num: {resumed_iter_num}") + + # Number of total stitched modules + global_stitched_modules_count = len(stitched_modules_process_ownership) + # Number of stitched modules per process + num_stitched_modules_per_process = [ + sum(1 for x in stitched_modules_process_ownership if x == owner_rank) + for owner_rank in range(dist.size()) + ] + # Indices of stitched modules owned by the current process + owned_stitched_module_indices = [ + i + for i, owner in enumerate(stitched_modules_process_ownership) + if owner == dist.rank() + ] + mprint(f"{global_stitched_modules_count=}") + mprint(f"{num_stitched_modules_per_process=}") + dist.barrier() + + if dist.is_master(): + # {iter_num: {stitched_module_name: loss}} + stitched_losses_history = dict[IterNum, dict[str, float]]() + else: + stitched_losses_history = None + + # Save checkpoint before training starts + if cfg.bypass.save_checkpoint_before_training and not cfg.bypass.disable_checkpoint_save: + subdir_name = f"start-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=cfg.bypass.experiment_dir / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + # Track statistics for each iteration + iter_stats_history: dict[IterNum, IterStatistics] = {} + + # Create fake input ids for the teacher model + fake_input_ids = fake_tensor( + torch.ones( + size=(cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + device=device, + ) + ) + + # Get pipeline neighbor ranks + min_owned_index = min(owned_stitched_module_indices) + max_owned_index = max(owned_stitched_module_indices) + prev_rank: Optional[int] = ( + None + if min_owned_index - 1 < 0 + else stitched_modules_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index + 1 >= global_stitched_modules_count + else stitched_modules_process_ownership[max_owned_index + 1] + ) + + torch.cuda.synchronize() + + mprint(f'Grad scaling status: {"enabled" if cfg.bypass.training.use_grad_scaling else "disabled"}') + + train_iterator = iter(train_dataloader) + + mprint("Waiting for everyone before training starts") + dist.barrier() + + step_to_save = None + # Track best loss value for each block + best_losses_by_name = dict[str, float]() + best_steps_by_name = dict[str, int]() + # Buffer variables + input_ids = torch.zeros(1, 1, dtype=torch.int64) + + aprint( + f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}" + ) + + # Train loop start + while True: + time_now = time.time() + # Check if we've reached the maximum number of steps + if cfg.bypass.step_num >= cfg.bypass.training.max_steps: + if ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and not cfg.bypass.disable_checkpoint_save + ): + mprint("Saving final checkpoint before training completion") + subdir_name = f"final-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=cfg.bypass.experiment_dir / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + existing_ckpt_paths = list(Path(cfg.bypass.experiment_dir).glob("iter-*")) + for old_ckpt_path in existing_ckpt_paths: + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + break + + is_accumulating = cfg.bypass.iter_num % cfg.bypass.training.grad_accumulation_steps != 0 + # Determine and set the learning rate for this iteration + lr = ( + _get_lr(cfg, cfg.bypass.step_num) + if cfg.bypass.training.decay_lr + else cfg.bypass.training.learning_rate + ) + for stitched_module_descriptor in stitched_module_descriptors.values(): + optimizer = stitched_module_descriptor.optimizer + if optimizer is not None: + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + if dist.is_master(): + train_data = next(train_iterator) + input_ids = train_data["input_ids"] + input_ids = input_ids.to(device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad(): + teacher_input_ids = input_ids if prev_rank is None else fake_input_ids + teacher_output = teacher_stitched_model({}, {}, teacher_input_ids) + + input_overrides = teacher_output.captured_inputs + output_overrides = teacher_output.captured_outputs + + del teacher_output + + input_overrides["teacher_inputs"] = InputArgs(fake_input_ids) + + iter_stitched_module_losses: dict[str, float] = {} + + for local_stitched_module_index, ( + stitched_module_name, + stitched_module_descriptor, + ) in enumerate(stitched_module_descriptors.items()): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + assert grad_scaler is not None + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + stitched_module_output = stitched_module( + input_overrides=input_overrides, + output_overrides=output_overrides, + ) + stitched_module_loss = stitched_module_output.captured_outputs["loss"] + del stitched_module_output + grad_scaler.scale(stitched_module_loss).backward() + else: + stitched_module_loss = torch.full( + [1], fill_value=torch.nan, dtype=torch.float32 + ) + + iter_stitched_module_losses[stitched_module_name] = ( + stitched_module_loss.to("cpu").item() + ) + + del stitched_module_loss + + if not is_accumulating: + if optimizer is not None: + grad_clip = cfg.bypass.training.grad_clip + if grad_clip is not None: + if cfg.bypass.training.grad_clip_type == "norm": + grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=stitched_module.parameters(), + max_norm=grad_clip, + ) + if grad_norm > grad_clip: + cfg.bypass.training.clipping_count += 1 + elif cfg.bypass.training.grad_clip_type == "value": + max_abs_grad_per_param = [ + p.grad.abs().max().item() + for p in stitched_module.parameters() + if p.grad is not None + ] + max_abs_grad = ( + max(max_abs_grad_per_param) + if len(max_abs_grad_per_param) > 0 + else 0.0 + ) + if max_abs_grad > grad_clip: + cfg.bypass.training.clipping_count += 1 + torch.nn.utils.clip_grad_value_( + parameters=stitched_module.parameters(), + clip_value=grad_clip, + ) + else: + raise RuntimeError( + f"Invalid {cfg.bypass.training.grad_clip_type}" + ) + + assert grad_scaler is not None + grad_scaler.step(optimizer) + grad_scaler.update() + optimizer.zero_grad(set_to_none=True) + + # Collect losses from all ranks using all_gather_object + local_training_stats = LocalTrainingStats( + iter_num=cfg.bypass.iter_num, + stitched_module_losses=iter_stitched_module_losses, + ) + all_training_stats = [None] * dist.size() + torch.distributed.all_gather_object(all_training_stats, local_training_stats) + + if dist.is_master(): + if cfg.bypass.iter_num == resumed_iter_num: + mprint(f"Starting from iter {cfg.bypass.iter_num}") + + # Merge all stats into the losses history + assert stitched_losses_history is not None + merged_losses: dict[str, float] = {} + for stats in all_training_stats: + if stats is not None: + merged_losses.update(stats.stitched_module_losses) + stitched_losses_history[cfg.bypass.iter_num] = merged_losses + + cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter + iter_t1 = time.time() + iter_duration = iter_t1 - iter_t0 + iter_stats_history[cfg.bypass.iter_num] = IterStatistics( + token_count=cfg.bypass.token_count, + iter_duration=iter_duration, + step_num=cfg.bypass.step_num, + lr=lr, + clipping_count=cfg.bypass.training.clipping_count, + ) + iter_t0 = iter_t1 + + # Time-based save signal (broadcast from master) + save_signal = [step_to_save] + if dist.is_master(): + if cfg.bypass.model.model_overrides.save_interval_seconds is not None: + time_now = time.time() + if time_now - time_last_save >= cfg.bypass.model.model_overrides.save_interval_seconds: + mprint( + f"Time to save! {cfg.bypass.model.model_overrides.save_interval_seconds=}, " + f"{time_last_save=}, {time_now=}" + ) + step_to_save = cfg.bypass.step_num + 5 + save_signal = [step_to_save] + time_last_save = time_now + + torch.distributed.broadcast_object_list(save_signal, src=0) + step_to_save = save_signal[0] + + # Logging + if dist.is_master(): + assert stitched_losses_history is not None + while len(stitched_losses_history) >= cfg.bypass.training.log_interval: + lowest_iter = next(iter(stitched_losses_history.keys())) + + log_chunk = { + it: losses + for it, losses in stitched_losses_history.items() + if it - lowest_iter < cfg.bypass.training.log_interval + } + if len(log_chunk) < cfg.bypass.training.log_interval: + break + + highest_iter = list(log_chunk.keys())[-1] + highest_iter_stats = iter_stats_history[highest_iter] + + losses_by_name = defaultdict[str, list[float]](lambda: []) + for losses in log_chunk.values(): + for name, loss in losses.items(): + losses_by_name[name].append(loss) + + losses_by_name_avg = { + name: mean(losses) for name, losses in losses_by_name.items() + } + + # Update best losses tracking + for name, current_loss in losses_by_name_avg.items(): + if name not in best_losses_by_name or current_loss < best_losses_by_name[name]: + best_losses_by_name[name] = current_loss + best_steps_by_name[name] = highest_iter + + chunk_iter_durations = [ + iter_stats_history[it].iter_duration for it in log_chunk.keys() + ] + avg_chunk_iter_duration = mean(chunk_iter_durations) + avg_token_speed = cfg.bypass.training.tokens_per_iter / avg_chunk_iter_duration + mprint( + f"iter {highest_iter}/{cfg.bypass.training.max_steps:,}:" + f" avg_iter_time={avg_chunk_iter_duration * 1000:.2f}ms" + f" avg_token_speed={avg_token_speed:,.0f}[tok/s]" + ) + mprint( + format_stitched_losses( + losses_dict=losses_by_name_avg, + best_steps_dict=best_steps_by_name, + best_values_dict=best_losses_by_name, + step_number=highest_iter, + title="Stitched Module Losses", + ) + ) + + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.log( + { + "iter": highest_iter, + "step": highest_iter_stats.step_num, + "token_count": highest_iter_stats.token_count, + "token_speed": avg_token_speed, + "lr": highest_iter_stats.lr, + "grad_clipping": highest_iter_stats.clipping_count, + }, + step=highest_iter, + ) + except ImportError: + pass + + for it in log_chunk.keys(): + del iter_stats_history[it] + del stitched_losses_history[it] + + # Validation + if ( + not is_accumulating + and (cfg.bypass.step_num % cfg.bypass.training.eval_interval) == 0 + and val_dataloader is not None + ): + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + losses, _ = calculate_losses_pipeline( + stitched_model=student_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + + val_loss = float("inf") + if losses is not None and "lm_loss" in losses: + val_loss = losses["lm_loss"]["avg"] + mprint(f"Validation loss at iter {cfg.bypass.iter_num}: {val_loss:.4f}") + + # Broadcast val_loss so all ranks agree on checkpoint decisions + val_loss_tensor = torch.tensor([val_loss], device=device) + torch.distributed.broadcast(val_loss_tensor, src=dist.size() - 1) + val_loss = val_loss_tensor.item() + + if val_loss < cfg.bypass.best_val_loss: + cfg.bypass.best_val_loss = val_loss + if not cfg.bypass.disable_checkpoint_save and cfg.bypass.save_best_ckpt: + subdir_name = f"best-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=cfg.bypass.experiment_dir / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + if cfg.bypass.kill_after_first_save: + raise RuntimeError( + "Done saving checkpoint, kill_after_first_save=True" + ) + + # Checkpoint saving (step-based or time-based) + if not is_accumulating and ( + (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0 + or step_to_save == cfg.bypass.step_num + or ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps + ) + ): + if not cfg.bypass.disable_checkpoint_save: + if (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0: + mprint("Saving step-interval checkpoint") + elif step_to_save == cfg.bypass.step_num: + mprint("Saving time-based checkpoint") + elif ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps - 100 + ): + mprint("Saving final checkpoint") + + subdir_name = f"iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=cfg.bypass.experiment_dir / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + if cfg.bypass.kill_after_first_save: + dist.barrier() + raise RuntimeError( + "Done saving checkpoint, kill_after_first_save=True" + ) + + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + existing_ckpt_paths = list( + Path(cfg.bypass.experiment_dir).glob("iter-*") + ) + for old_ckpt_path in existing_ckpt_paths: + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + + cfg.bypass.iter_num += 1 + if not is_accumulating: + cfg.bypass.step_num += 1 + + mprint("Finished successfully!") + + +# Learning rate decay scheduler (cosine with warmup) +def _get_lr(cfg: DictConfig, step: int) -> float: + # 1) linear warmup for warmup_steps steps + if step <= cfg.bypass.training.warmup_steps: + lr = cfg.bypass.training.learning_rate * step / cfg.bypass.training.warmup_steps + # 2) if step > lr_decay_steps, return min learning rate + elif step > cfg.bypass.training.lr_decay_steps: + lr = cfg.bypass.training.min_lr + # 3) in between, use cosine decay down to min learning rate + else: + decay_ratio = (step - cfg.bypass.training.warmup_steps - 1) / ( + cfg.bypass.training.lr_decay_steps - cfg.bypass.training.warmup_steps + ) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + lr = cfg.bypass.training.min_lr + coeff * ( + cfg.bypass.training.learning_rate - cfg.bypass.training.min_lr + ) + + return lr + + +def run_bypassed_training(cfg: DictConfig): + """Setup and orchestrate bypass distillation training.""" + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN + ) + + # Suppress debug messages from HuggingFace libraries + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + device = torch.device(f"cuda:{dist.local_rank()}") + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + teacher_model_config = load_model_config(cfg.teacher_dir, trust_remote_code=trust_remote_code) + + try: + mprint("Waiting for distributed setup...") + dist.barrier() + + if cfg.bypass.disable_initial_validate: + cfg.bypass.validate_teacher_model = False + cfg.bypass.validate_student_model = False + + if cfg.bypass.teacher_model_load_on_cpu: + assert not cfg.bypass.validate_teacher_model, ( + "Teacher model validation is too slow on CPU" + ) + + num_hidden_layers = descriptor.get_language_model_config( + teacher_model_config + ).num_hidden_layers + + model_blocks_process_ownership = get_distributed_modules_ownership( + module_count=num_hidden_layers, + world_size=dist.size(), + ) + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + cfg.teacher_dir = str(Path(cfg.teacher_dir).expanduser()) + teacher_model_config = load_model_config( + cfg.teacher_dir, + model_config_overrides={"use_cache": False}, + trust_remote_code=trust_remote_code, + ) + + student_model = None + if cfg.bypass.init_checkpoint_path is not None: + mprint(f"Loading student model from {cfg.bypass.init_checkpoint_path}") + student_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=cfg.bypass.init_checkpoint_path, + owned_block_indexes=owned_block_indexes, + ) + + cfg.bypass.training.min_lr = ( + cfg.bypass.training.learning_rate * cfg.bypass.training.min_lr_factor + ) + cfg.bypass.training.batch_size_per_iter = cfg.bypass.training.micro_batch_size + cfg.bypass.training.tokens_per_iter = ( + cfg.bypass.data.block_size * cfg.bypass.training.batch_size_per_iter + ) + cfg.bypass.training.max_steps = math.ceil( + cfg.bypass.training.training_tokens / cfg.bypass.training.tokens_per_iter + ) + cfg.bypass.training.max_iters = ( + cfg.bypass.training.max_steps * cfg.bypass.training.grad_accumulation_steps + ) + cfg.bypass.training.max_token_count = ( + cfg.bypass.training.max_iters * cfg.bypass.training.tokens_per_iter + ) + cfg.bypass.training.lr_decay_steps = cfg.bypass.training.max_steps + + if cfg.bypass.training.val_micro_batch_size is None: + cfg.bypass.training.val_micro_batch_size = cfg.bypass.training.micro_batch_size + + if cfg.bypass.training.warmup_steps is None: + cfg.bypass.training.warmup_steps = 0 + + mprint(f'\n{format_global_config(cfg.bypass, "Bypass Configurations")}') + mprint(f"Max token count: {cfg.bypass.training.max_token_count:,}") + + seed = cfg.bypass.seed + torch.manual_seed(seed) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.teacher_dir, + trust_remote_code=True, + token=True, + ) + + assert teacher_model_config is not None + + mprint( + f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}" + ) + teacher_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=cfg.teacher_dir, + owned_block_indexes=owned_block_indexes, + model_config=teacher_model_config, + ) + + teacher_model.requires_grad_(False) + + # Create dataloaders + from modelopt.torch.puzzletron.utils.data.dataloaders import ( + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + ) + + if cfg.bypass.data.eval_samples_per_process is not None: + max_eval_samples = cfg.bypass.data.eval_samples_per_process * dist.size() + else: + max_eval_samples = cfg.bypass.data.max_eval_samples + + load_dataset_fn = load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn + + train_dataloader = create_train_dataloader( + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset_path=cfg.dataset_path, + content_field=cfg.bypass.data.data_column, + fim_rate=cfg.bypass.data.fim_rate, + fim_spm_rate=cfg.bypass.data.fim_spm_rate, + micro_batch_size=cfg.bypass.training.micro_batch_size, + load_dataset_fn=load_dataset_fn, + keep_in_memory=cfg.bypass.data.keep_in_memory, + source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()), + bos_rate=cfg.bypass.data.bos_rate, + shuffle_seed=cfg.bypass.data.shuffle_train_data_seed, + ) + + val_dataloader = None + if not cfg.bypass.disable_validation: + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset=cfg.dataset_path, + content_field=cfg.bypass.data.data_column, + fim_rate=cfg.bypass.data.fim_rate, + fim_spm_rate=cfg.bypass.data.fim_spm_rate, + micro_batch_size=cfg.bypass.training.val_micro_batch_size, + eval_samples=max_eval_samples, + load_dataset_fn=load_dataset_fn, + dataset_name=cfg.bypass.data.val_dataset_name, + keep_in_memory=cfg.bypass.data.keep_in_memory, + source_datasets_to_discard=cfg.bypass.get( + "source_datasets_to_discard", tuple() + ), + bos_rate=cfg.bypass.data.bos_rate, + ) + + # Set ID from experiment configuration + set_experiment_id(cfg) + # Set directory for experiment ID + set_experiment_dir(cfg) + + dist.barrier() + + with torch.device(device): + stitched_model_factory_fn = cast( + stitched_model_factory_module.StitchedModelFactoryFn, + getattr(stitched_model_factory_module, cfg.bypass.model_factory.factory), + ) + ( + student_model, + teacher_stitched_model, + teacher_val_stitched_module, + student_val_stitched_model, + stitched_module_descriptors, + student_model_config, + ) = stitched_model_factory_fn( + teacher_model=teacher_model, + descriptor=descriptor, + cfg=cfg.bypass, + model_blocks_process_ownership=model_blocks_process_ownership, + student_model=student_model, + ) + + # Check whether to resume from checkpoint + resume_checkpoint_path = None + if cfg.bypass.resume_checkpoint_path is not None: + resume_checkpoint_path = cfg.bypass.resume_checkpoint_path + elif cfg.bypass.find_last_ckpt_for_resume: + _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir) + if _ckpt_dir is None: + mprint( + "Couldn't find any run dir for resume, assuming this is the first job" + ) + else: + mprint( + f"`cfg.bypass.find_last_ckpt_for_resume` is True. " + f"Auto-found a checkpoint to resume: `{_ckpt_dir}`" + ) + resume_checkpoint_path = _ckpt_dir + + if resume_checkpoint_path: + load_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_path=resume_checkpoint_path, + ) + + # Load resume ckpt bypass configs and extract resume iter_num + resume_cfg = DictConfig(json_load(Path(resume_checkpoint_path) / "args.json")) + + # Resume stats + cfg.bypass.iter_num = resume_cfg.iter_num + cfg.bypass.token_count = resume_cfg.token_count + cfg.bypass.step_num = resume_cfg.step_num + cfg.bypass.best_val_loss = resume_cfg.best_val_loss + cfg.bypass.training.clipping_count = resume_cfg.training.clipping_count + mprint(f"Resume from iter_num: {cfg.bypass.iter_num}") + + # Only copy wandb.run_id if it exists in resume config + if hasattr(resume_cfg, "wandb") and hasattr(resume_cfg.wandb, "run_id"): + cfg.bypass.wandb.run_id = resume_cfg.wandb.run_id + + cfg.bypass.save_checkpoint_before_training = False + cfg.bypass.validate_teacher_model = False + cfg.bypass.validate_student_model = False + + cfg.bypass.resume_checkpoint_path = resume_checkpoint_path + + # Initialize Weights and Biases + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.init( + project=cfg.bypass.wandb.project, + entity=cfg.bypass.wandb.entity, + config=dict(cfg.bypass), + ) + except ImportError: + mprint("wandb not installed, disabling wandb logging") + cfg.bypass.wandb_log = False + else: + mprint("Weights & Biases logging disabled (wandb_log=False)") + + if cfg.bypass.validate_teacher_model and val_dataloader is not None: + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + mprint("Evaluating teacher model:") + losses, _ = calculate_losses_pipeline( + stitched_model=teacher_val_stitched_module, + dataloader=val_dataloader, + descriptor=descriptor, + ) + if losses is not None: + mprint(f"Teacher validation losses: {losses}") + mprint("Evaluated teacher model") + + torch.cuda.empty_cache() + dist.barrier() + + parameter_count = sum(p.numel() for p in student_model.parameters()) + aprint(f"Model parameter count: {parameter_count:,}") + cfg.bypass.parameter_count = parameter_count + + dist.barrier() + mprint("Performing dummy runs on stitched modules:") + torch.cuda.synchronize() + with torch.no_grad(), torch.autocast( + device_type="cuda", dtype=torch.bfloat16 + ), torch.device(device): + input_ids = torch.ones( + (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + ) + dummy_fake_input_ids = fake_tensor(input_ids) + mprint(f"Dummy runs on stitched modules with shape: {dummy_fake_input_ids.shape=}") + teacher_output = teacher_stitched_model({}, {}, input_ids) + for stitched_module_descriptor in stitched_module_descriptors.values(): + stitched_module = stitched_module_descriptor.stitched_module + stitched_module( + input_overrides={ + **teacher_output.captured_inputs, + "teacher_inputs": InputArgs(dummy_fake_input_ids), + }, + output_overrides=teacher_output.captured_outputs, + ) + for name, param in stitched_module.named_parameters(recurse=True): + if "iter_num" in name: + param.data = torch.zeros_like(param.data) + del name, param + del input_ids, dummy_fake_input_ids, teacher_output + torch.cuda.synchronize() + dist.barrier() + + del teacher_model + + if cfg.bypass.validate_student_model and val_dataloader is not None: + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + mprint("Validating model before training:") + losses, _ = calculate_losses_pipeline( + stitched_model=student_val_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + if losses is not None: + mprint(f"Student validation losses: {losses}") + + dist.barrier() + torch.cuda.empty_cache() + dist.barrier() + + train( + cfg=cfg, + descriptor=descriptor, + student_model=student_model, + student_stitched_model=student_val_stitched_model, + teacher_stitched_model=teacher_stitched_model, + stitched_module_descriptors=stitched_module_descriptors, + stitched_modules_process_ownership=model_blocks_process_ownership, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + student_model_config=student_model_config, + skip_first_batches=cfg.bypass.training.skip_first_batches, + tokenizer=tokenizer, + ) + + aprint("Finished training successfully!") + dist.barrier() + + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + if isinstance(e, SystemExit): + raise e + else: + sys.exit(1) + + dist.barrier() + if dist.is_master(): + mprint("Realizing bypass checkpoints") + realize_bypass_checkpoints(cfg) + + +def realize_bypass_checkpoints(cfg: DictConfig): + """Create symlinks from bypass checkpoint directories to the ckpts directory.""" + checkpoint_dir = Path(cfg.bypass.experiment_dir) / "latest" + if not checkpoint_dir.exists(): + mprint(f"Could not find checkpoint directory: {checkpoint_dir}") + return + + ckpts_dir = Path(cfg.puzzle_dir) / "ckpts" + ckpts_dir.mkdir(parents=True, exist_ok=True) + + symlink_name = ckpts_dir / cfg.bypass.experiment_id + if symlink_name.exists() or symlink_name.is_symlink(): + symlink_name.unlink() + + symlink_name.symlink_to(checkpoint_dir, target_is_directory=True) + mprint(f"Created symlink: {symlink_name} -> {checkpoint_dir}") diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py index e5025dea7..042b2adce 100644 --- a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py @@ -27,6 +27,7 @@ import torch from torch import nn +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch.puzzletron.scoring.scoring as scoring @@ -92,10 +93,28 @@ class PuzzletronConfig(ModeloptBaseConfig): ) +def _total_steps(hydra_cfg) -> int: + """Return total pipeline step count: 9 with bypass, 8 without. + + Steps: + 1 starting (main.py) + 2 convert model + 3 score pruning activations + 4 prune checkpoints + [5 bypass distillation — only when bypass is configured] + 5/6 build replacement library & subblock stats + 6/7 calculate one block scores + 7/8 MIP and realize models + 8/9 completed (main.py) + """ + return 9 if hydra_cfg.get("bypass", None) is not None else 8 + + def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. - 3. Prune the model and save pruned checkpoints + 3. Prune the model and save pruned checkpoints. + 4. (Optional) Run bypass distillation. The output of this step will be used by mnt.search() to perform the NAS search. """ @@ -117,37 +136,70 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) - if dist.is_master(): - mprint( - "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" - ) - hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable - - # Get descriptor and converter from the hydra config - descriptor_name = hydra_cfg.descriptor - descriptor = ModelDescriptorFactory.get(descriptor_name) - converter = ConverterFactory.get(descriptor_name) + has_bypass = hydra_cfg.get("bypass", None) is not None + N = _total_steps(hydra_cfg) - converter.convert( - descriptor=descriptor, - input_dir=Path(config.input_model_path), - output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, - ) + # Step 2: Convert HuggingFace model to Puzzletron heterogeneous format + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + teacher_dir = Path(config.puzzle_dir) / hf_ckpt_teacher_dir + if dist.is_master(): + if (teacher_dir / "config.json").exists(): + mprint(f"Puzzletron Progress 2/{N}: teacher checkpoint already exists, skipping conversion") + else: + mprint(f"Puzzletron Progress 2/{N}: converting model to Puzzletron heterogeneous format (single-gpu)") + + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + # Auto-download from HuggingFace if path doesn't exist locally + input_model_path = config.input_model_path + if not Path(input_model_path).exists(): + from huggingface_hub import snapshot_download + + if input_model_path.startswith("https://huggingface.co/"): + model_id = "/".join(input_model_path.rstrip("/").split("/")[-2:]) + else: + model_id = input_model_path # assume HF model ID like "org/model-name" + mprint( + f"Downloading HuggingFace model '{model_id}' — this may take several minutes " + f"for large models. Other ranks are waiting at a barrier." + ) + input_model_path = snapshot_download(repo_id=model_id) + mprint(f"Downloaded to: {input_model_path}") + + converter.convert( + descriptor=descriptor, + input_dir=Path(input_model_path), + output_dir=teacher_dir, + ) dist.barrier() - # Score_pruning_activations (distributed processing) - mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)") - score_pruning_activations.launch_score_activations(hydra_cfg) + # Step 3: Score pruning activations (distributed processing) + activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir) + if activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth")): + mprint(f"Puzzletron Progress 3/{N}: pruning activation scores already exist, skipping scoring") + dist.barrier() + else: + mprint(f"Puzzletron Progress 3/{N}: scoring pruning activations (multi-gpu)") + score_pruning_activations.launch_score_activations(hydra_cfg) - # Prune the model and save pruned checkpoints + # Step 4: Prune the model and save pruned checkpoints (single process) + pruned_ckpts_dir = Path(hydra_cfg.pruning.pruned_ckpts_output_dir) if dist.is_master(): - mprint( - "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" - ) - pruning_ckpts.launch_prune_ckpt(hydra_cfg) + if pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir()): + mprint(f"Puzzletron Progress 4/{N}: pruned checkpoints already exist, skipping pruning") + else: + mprint(f"Puzzletron Progress 4/{N}: pruning the model and saving pruned checkpoints (single-gpu)") + pruning_ckpts.launch_prune_ckpt(hydra_cfg) dist.barrier() + # Step 5: Bypass distillation (optional, distributed processing) + if has_bypass: + mprint(f"Puzzletron Progress 5/{N}: running bypass distillation (multi-gpu)") + bypass_distillation.launch_bypass_distillation(hydra_cfg) + return model, {} @@ -218,18 +270,34 @@ def run_search(self) -> None: # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Build_library_and_stats (single process) + has_bypass = hydra_cfg.get("bypass", None) is not None + N = _total_steps(hydra_cfg) + # With bypass: library=6, scoring=7, mip=8 (out of 9) + # Without bypass: library=5, scoring=6, mip=7 (out of 8) + library_step = 6 if has_bypass else 5 + scoring_step = 7 if has_bypass else 6 + mip_step = 8 if has_bypass else 7 + + # Build replacement library and subblock statistics (single process) + puzzle_dir = Path(self.model.puzzle_dir) + replacement_library_path = puzzle_dir / "replacement_library.json" + subblock_stats_path = puzzle_dir / hydra_cfg.calc_subblock_stats.subblock_stats_filename if dist.is_master(): - mprint( - "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)" - ) - build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + if replacement_library_path.exists() and subblock_stats_path.exists(): + mprint( + f"Puzzletron Progress {library_step}/{N}: replacement library and subblock stats already exist, skipping" + ) + else: + mprint( + f"Puzzletron Progress {library_step}/{N}: building replacement library and subblock statistics (single-gpu)" + ) + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) dist.barrier() - # Calc_one_block_scores (distributed processing) - mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)") + # Calculate one block scores (distributed processing) + mprint(f"Puzzletron Progress {scoring_step}/{N}: calculating one block scores (multi-gpu)") scoring.launch_scoring(hydra_cfg) - # mip_and_realize_models (distributed processing) - mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + # MIP search and realize models (distributed processing) + mprint(f"Puzzletron Progress {mip_step}/{N}: running MIP and realizing models (multi-gpu)") mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index 82ba675c9..dbc40f082 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -44,6 +44,7 @@ class MlpInitMode(Enum): PruneByActivationsLog = "PruneByActivationsLog" ExpertRemoval = "ExpertRemoval" ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + MoEChannelPruning = "MoEChannelPruning" class LinearInitMode(Enum): diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 5a1484e07..457fef6df 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -20,6 +20,7 @@ import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations import modelopt.torch.puzzletron.build_library_and_stats as build_library_and_stats +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch.puzzletron.scoring.scoring as scoring @@ -62,6 +63,10 @@ def puzzletron( pruning_ckpts.launch_prune_ckpt(hydra_cfg) dist.barrier() + # Step 3: bypass distillation (optional, distributed processing) + if hydra_cfg.get("bypass", None) is not None: + bypass_distillation.launch_bypass_distillation(hydra_cfg) + # Step 4: build_library_and_stats (single process) if dist.is_master(): build_library_and_stats.launch_build_library_and_stats(hydra_cfg) diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 19c1bd6c8..6926ba1d9 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -429,3 +429,55 @@ def _get_group_kwarg_if_necessary() -> dict: torch.distributed.distributed_c10d._object_to_tensor ).parameters.keys() return dict(group=None) if "group" in arg_names else dict() + + +# ────────────────────────────────────────────────────────────────────────────── +# Loss functions for bypass distillation (blockwise local knowledge distillation) +# ────────────────────────────────────────────────────────────────────────────── + +Reduction = Literal["none", "mean", "sum"] + + +def normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + reduction: Reduction = "mean", + epsilon: float = 1e-6, +) -> torch.Tensor: + """MSE loss normalized by the variance of the target. + + Dividing by the target's self-MSE makes the loss scale-invariant, so that + blocks whose activations have large magnitude do not dominate training. + """ + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def vectorwise_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged.""" + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, + batch_dims: Sequence[int] = (0,), +) -> torch.Tensor: + """Like normalized_mse_loss, but normalization is done on non-batch dims, then averaged. + + Useful when activations within a batch item should be normalized independently + rather than normalizing across the full batch. + """ + norm_dims = list(set(range(input.ndim)) - set(batch_dims)) + norm_of_target_vectors = F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction="none" + ).mean(norm_dims) + loss = F.mse_loss(input, target, reduction="none").mean(norm_dims) / norm_of_target_vectors + return loss.mean() diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b30e7eefa..e7e6753d6 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -91,26 +91,28 @@ def _process_single_layer( keys_to_remove = {} layer_out_state_dict = {} - # Delegate to pruning_mixin if available + # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins) if pruning_mixin is not None: - _layer_out = pruning_mixin.prune_single_layer( - layer_idx=layer_idx, - parent_state_dict=parent_state_dict, - new_state_dict=new_state_dict, - original_config=original_config, - new_config=new_config, - gqa_init_mode=gqa_init_mode, - mlp_init_mode=mlp_init_mode, - mlp_init_config=mlp_init_config, - linear_init_mode=linear_init_mode, - ignored_keys=ignored_keys, - keys=keys, - is_original_mha=is_original_mha, - head_size=head_size, - hidden_size=hidden_size, - keys_to_remove=keys_to_remove, - ) - layer_out_state_dict.update(_layer_out) + _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin] + for _mixin in _mixins: + _layer_out = _mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) return layer_out_state_dict, keys_to_remove # Legacy inline processing (fallback when no pruning_mixin) @@ -801,7 +803,7 @@ def update_model_config( def override(item, item_overrides): if item_overrides is None: - return item_overrides + return item # None override means "keep original value" if dataclasses.is_dataclass(item): assert isinstance(item_overrides, dict) return dataclass_override(item, item_overrides) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 020afdfad..0afd5d5b6 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -22,7 +22,9 @@ import concurrent.futures import dataclasses import fcntl +import inspect import os +import shutil import time import warnings from collections import defaultdict @@ -368,6 +370,32 @@ def _build_safetensors_weight_map( return weight_map +def _copy_auto_map_code_files(model_config: PretrainedConfig, checkpoint_dir: Path) -> None: + """Copy custom modeling Python files referenced in auto_map to the checkpoint directory. + + PretrainedConfig.save_pretrained() only copies the config class's own source file. + This copies any additional files (e.g., modeling_*.py) also referenced in auto_map, + which are required when loading the checkpoint with trust_remote_code=True. + """ + if not hasattr(model_config, "auto_map"): + return + + # The config class's source file lives in the HF cache together with all other + # custom code files for this model. Walk the auto_map values to find every + # module file that needs to be present alongside config.json. + source_dir = Path(inspect.getfile(type(model_config))).parent + + module_files = { + f"{class_ref.split('.')[0]}.py" for class_ref in model_config.auto_map.values() + } + + for filename in module_files: + src = source_dir / filename + dst = Path(checkpoint_dir) / filename + if src.exists() and not dst.exists(): + shutil.copy(src, dst) + + def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: if hasattr(model_config, "block_configs"): model_config.block_configs = [ @@ -375,3 +403,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str for conf in model_config.block_configs ] model_config.save_pretrained(checkpoint_dir) + _copy_auto_map_code_files(model_config, Path(checkpoint_dir)) diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index 892d1f3c2..ce1ff033f 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -71,6 +71,54 @@ def load_streaming_fn( return dataset +def create_train_dataloader( + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "train", + keep_in_memory: bool = False, + shuffle_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + num_workers: int = 0, +) -> DataLoader: + """Create an infinite training DataLoader over ConstantLengthDataset.""" + if isinstance(dataset_path, str): + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) + else: + dataset = dataset_path + + train_data = dataset[dataset_name] + if shuffle_seed is not None: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + ) + + return DataLoader( + train_dataset, + batch_size=micro_batch_size, + pin_memory=True, + num_workers=num_workers, + ) + + def create_validation_dataloader( accelerator: Accelerator | None, seed: int, diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index fffc2a3a1..c1278f054 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -120,7 +120,14 @@ def __iter__(self) -> dict[str, torch.Tensor]: and {"content", "role"}.issubset(sample[0]) ): if len(sample) > 1: - sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + if getattr(self.tokenizer, "chat_template", None) is not None: + sample = self.tokenizer.apply_chat_template( + sample, tokenize=False + ) + else: + # Base models have no chat template — concatenate message + # contents separated by newlines as plain text. + sample = "\n".join(m["content"] for m in sample) else: sample = sample[0]["content"] else: diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index ff5bb6963..6a36886b0 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -332,6 +332,18 @@ def format_stitched_losses( if not losses_dict: return "❌ No losses found" + import math + + # Filter out nan entries — these are no-op blocks (e.g. Mamba) with no trainable parameters + losses_dict = {k: v for k, v in losses_dict.items() if not math.isnan(v)} + if best_steps_dict: + best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict} + if best_values_dict: + best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict} + + if not losses_dict: + return "❌ No trainable blocks found" + lines = [] # Calculate statistics diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml new file mode 100644 index 000000000..0d78205c1 --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml @@ -0,0 +1,99 @@ +# @package bypass +# Minimal bypass config for GPU integration tests. +# Uses tiny training budget (128 tokens) and tiny model (hidden_size=256, +# intermediate_size=512, 2 layers) to run fast on CI. + +dtype: "bf16" +seed: 42 +experiment_id: +experiment_dir: +iter_num: 1 +step_num: 1 +token_count: 0 + +data: + data_column: "conversation" + block_size: 64 + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 1 + eval_samples_per_process: + shuffle_train_data_seed: 42 + +training: + learning_rate: 1e-4 + training_tokens: 128 + micro_batch_size: 1 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 100 + +resume_checkpoint_path: +find_last_ckpt_for_resume: false +parameter_count: +init_checkpoint_path: + +model: + student_weights_dtype: "bf16" + model_overrides: + delete_old_checkpoints: true + save_interval_seconds: + save_interval: 1000000000 + save_checkpoint_when_done: true + model_config_overrides: + ffn: + - intermediate_size: + no_op: + attention: + - num_key_value_heads: + no_op: + +model_factory: + factory: bypass_factory_fn + block_loss_func: normalized_mse_loss + gqa_init_mode: AverageKV + mlp_init_mode: Truncate + mlp_init_config: + activations_log_dir: + linear_init_mode: FromTeacher + submodule_for_loss_calculation: + keys_to_learn: entire_block + +disable_initial_validate: true +validate_teacher_model: false +validate_student_model: false +disable_validation: true +best_val_loss: 1.0e+9 + +compile: false +disable_fa2: false +teacher_model_load_on_cpu: false + +save_checkpoint_before_training: false +disable_checkpoint_save: false +save_best_ckpt: true +kill_after_first_save: false +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py new file mode 100644 index 000000000..54673b415 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU integration tests for bypass distillation (blockwise local distillation). + +These tests verify that: +- Bypass distillation runs end-to-end with a tiny Llama model (hidden_size=256, + intermediate_size=512, num_layers=max(2, world_size)). +- FFN pruning, KV-head compression, and multi-config sweep all produce the expected + checkpoint symlinks in puzzle_dir/ckpts/. +- The bypass config injection pattern via OmegaConf works correctly for tests that + do not load a full bypass Hydra config file. + +Model parameters used throughout: + - teacher intermediate_size: 512 -> pruned to 256 (half) for FFN tests + - teacher num_key_value_heads: 8 -> pruned to 4 for KV-head tests + - training_tokens: 128, block_size: 64, micro_batch_size: 1 -> max_steps = 2 +""" + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import hydra +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SEED = 1234 +HF_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" +CONVERTER = "llama" +HYDRA_CONFIG_NAME = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + +# Teacher model dimensions (set by setup_test_model_and_data for Llama) +TEACHER_INTERMEDIATE_SIZE = 512 +TEACHER_NUM_KV_HEADS = 8 + +# Pruned sizes used in tests +PRUNED_INTERMEDIATE_SIZE = 256 # half of teacher +PRUNED_NUM_KV_HEADS = 4 # half of teacher + +# Training budget: 128 tokens / (64 block * 1 mbs) = 2 steps — completes fast +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 + + +# --------------------------------------------------------------------------- +# Helper: build the bypass config dict for injection into hydra_cfg +# --------------------------------------------------------------------------- + +def _make_bypass_cfg_dict( + intermediate_size: int = PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads: int = PRUNED_NUM_KV_HEADS, + configs_list: list | None = None, +) -> dict: + """Return a plain-dict bypass config suitable for OmegaConf.update injection. + + Args: + intermediate_size: FFN intermediate size for the student model. + num_key_value_heads: Number of KV heads for the student model. + configs_list: If provided, populates bypass.configs for a multi-config sweep. + Each entry is a dict with ``model_config_overrides`` and optionally + ``keys_to_learn``. + """ + cfg = { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + # The dummy test dataset stores conversations under the "conversation" column. + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + # Large eval_interval so validation is skipped during this short run. + # Validation is fully disabled anyway (disable_validation=True below). + "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": False, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + # Effectively disable step-interval saving; rely on save_checkpoint_when_done. + "save_interval": 1_000_000_000, + "save_checkpoint_when_done": True, + }, + "model_config_overrides": { + "ffn": [{"intermediate_size": intermediate_size, "no_op": None}], + "attention": [{"num_key_value_heads": num_key_value_heads, "no_op": None}], + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": "entire_block", + }, + # Disable all validation to keep tests fast. + "disable_initial_validate": True, + "validate_teacher_model": False, + "validate_student_model": False, + "disable_validation": True, + "best_val_loss": 1e9, + "compile": False, + "disable_fa2": False, + "teacher_model_load_on_cpu": False, + "save_checkpoint_before_training": False, + "disable_checkpoint_save": False, + "save_best_ckpt": True, + # Do NOT use kill_after_first_save — it raises RuntimeError which becomes sys.exit(1). + # Instead let the short training run (2 steps) complete naturally. + "kill_after_first_save": False, + "realize_best_or_latest": "best", + "wandb_log": False, + "wandb": {"project": None, "entity": None}, + } + + if configs_list is not None: + cfg["configs"] = configs_list + + return cfg + + +# --------------------------------------------------------------------------- +# Helper: load hydra config and run pruning prerequisites +# --------------------------------------------------------------------------- + +def _setup_hydra_cfg_and_pruning( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +) -> tuple: + """Set up the tiny model, convert it, score activations, and create pruning ckpts. + + This is the shared preamble for all bypass tests. Returns + ``(puzzle_dir, dataset_path, hydra_cfg)``. + + Steps performed: + 1. Create a small HF model and dummy dataset via ``setup_test_model_and_data``. + 2. Convert the HF checkpoint to AnyModel/DeciLM format (rank 0 only). + 3. Load the Hydra config with ``puzzle_dir`` and ``dataset_path`` overrides. + 4. Run ``score_pruning_activations`` (distributed). + 5. Run ``pruning_ckpts`` (rank 0 only) then barrier. + """ + set_seed(SEED) + dist.setup(timeout=timedelta(10)) + + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, HF_MODEL_NAME + ) + + hydra_config_dir = str( + project_root_path / "tests/gpu/torch/puzzletron/resources/configs" + ) + + # Step 0: Convert HF checkpoint to AnyModel/DeciLM format. + if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=CONVERTER, + ) + dist.barrier() + + # Step 1: Load Hydra config. + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=HYDRA_CONFIG_NAME, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Step 2: Score pruning activations (distributed). + score_pruning_activations.launch_score_activations(hydra_cfg) + + # Step 3: Create pruning checkpoints (rank 0 only). + if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + return puzzle_dir, dataset_path, hydra_cfg + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_bypass_ffn_pruning(project_root_path: Path, tmp_path: Path): + """Bypass distillation with FFN pruned to intermediate_size=256. + + Verifies that after training: + - The experiment directory ``bypass/bypass_runs/bypass_ffn_256_heads_4`` exists. + - A symlink ``ckpts/bypass_ffn_256_heads_4`` pointing into the experiment dir + is created by ``realize_bypass_checkpoints``. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_ffn_pruning_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_ffn_pruning_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Inject bypass config: prune FFN to 256, keep num_key_value_heads=4. + # experiment_id will be set dynamically to "bypass_ffn_256_heads_4". + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = ( + f"bypass_ffn_{PRUNED_INTERMEDIATE_SIZE}_heads_{PRUNED_NUM_KV_HEADS}" + ) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_ffn_pruning completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_kv_head_compression(project_root_path: Path, tmp_path: Path): + """Bypass distillation with KV heads reduced from 8 to 4, FFN kept at 512. + + The experiment_id is ``bypass_ffn_512_heads_4`` because both FFN and attention + overrides are specified (FFN is kept at teacher size, attention is halved). + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_kv_head_compression_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_kv_head_compression_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Keep FFN at teacher size (512) but halve KV heads (8 -> 4). + # experiment_id will be "bypass_ffn_512_heads_4". + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=TEACHER_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = ( + f"bypass_ffn_{TEACHER_INTERMEDIATE_SIZE}_heads_{PRUNED_NUM_KV_HEADS}" + ) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_kv_head_compression completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_multi_config_sequential(project_root_path: Path, tmp_path: Path): + """Bypass distillation sweep: two configs run sequentially via bypass.configs list. + + Config 0: FFN=256, heads=4 -> experiment_id ``bypass_ffn_256_heads_4`` + Config 1: FFN=512, heads=4 -> experiment_id ``bypass_ffn_512_heads_4`` + + Both symlinks must exist after the sweep completes. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_multi_config_sequential_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_multi_config_sequential_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Build base bypass config (model_config_overrides will be overwritten by configs list). + configs_list = [ + { + "model_config_overrides": { + "ffn": [{"intermediate_size": PRUNED_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + { + "model_config_overrides": { + "ffn": [{"intermediate_size": TEACHER_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + ] + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + configs_list=configs_list, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_ids = [ + f"bypass_ffn_{PRUNED_INTERMEDIATE_SIZE}_heads_{PRUNED_NUM_KV_HEADS}", + f"bypass_ffn_{TEACHER_INTERMEDIATE_SIZE}_heads_{PRUNED_NUM_KV_HEADS}", + ] + for experiment_id in expected_ids: + experiment_dir = puzzle_dir / "bypass/bypass_runs" / experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_multi_config_sequential completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_checkpoint_contents(project_root_path: Path, tmp_path: Path): + """Verify that a bypass checkpoint contains expected HuggingFace model files. + + After bypass completes, the checkpoint directory (reachable via the symlink at + ``ckpts/{experiment_id}``) must contain a ``config.json`` (saved by + ``save_checkpoint`` / ``save_bypass_checkpoint``). + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_checkpoint_contents_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_checkpoint_contents_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = ( + f"bypass_ffn_{PRUNED_INTERMEDIATE_SIZE}_heads_{PRUNED_NUM_KV_HEADS}" + ) + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink: {ckpt_symlink}" + ) + + # The symlink resolves to the latest checkpoint dir; verify HF config exists. + resolved = ckpt_symlink.resolve() + config_json = resolved / "config.json" + assert config_json.exists(), ( + f"Expected HuggingFace config.json inside checkpoint: {config_json}" + ) + + # The saving_completed marker must be present (set by save_bypass_checkpoint). + saving_completed = resolved / "saving_completed" + assert saving_completed.exists(), ( + f"Expected saving_completed marker inside checkpoint: {saving_completed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_checkpoint_contents completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index f3f49bed2..2ce97ef61 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -206,25 +206,36 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): size = dist.size() if expected is not None: - # In multi-GPU: layers are distributed across ranks - # Each rank processes len(expected) // size layers - expected_layers_per_rank = len(expected) // size - assert len(layer_names) == expected_layers_per_rank, ( - f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + # The test model has num_hidden_layers = max(2, size), so every rank owns at least + # one layer. Compute the actual expected count for *this* rank. + total_layers = max(2, size) + layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0) + assert len(layer_names) == layers_this_rank, ( + f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" ) - # Check each layer's values - for i, layer_name in enumerate(layer_names): - layer_data = pruning_scores[layer_name] - # Calculate global layer index from rank and local index - global_idx = rank * expected_layers_per_rank + i - assert layer_data["score"][0].item() == expected[global_idx]["score"] - assert ( - layer_data["channels_importance_ascending"][0].item() - == expected[global_idx]["channels"] + + # Numerical score checks are only meaningful when the expected table was + # collected with the same GPU count (same total_layers). When running on + # more GPUs than the table covers, skip the per-value assertions rather than + # failing: the layer-count check above already confirms the distribution is right. + if len(expected) == total_layers: + global_start = sum( + max(2, size) // size + (1 if r < max(2, size) % size else 0) + for r in range(rank) ) + for i, layer_name in enumerate(layer_names): + layer_data = pruning_scores[layer_name] + global_idx = global_start + i + assert layer_data["score"][0].item() == expected[global_idx]["score"] + assert ( + layer_data["channels_importance_ascending"][0].item() + == expected[global_idx]["channels"] + ) else: # Print values for new models - update EXPECTED_PRUNING_VALUES with these - print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===") + # Note: values depend on GPU count (num_hidden_layers = max(2, size)). + total_layers = max(2, size) + print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={total_layers}) ===") print(f'"{hf_model_name}": [') for layer_name in layer_names: layer_data = pruning_scores[layer_name] diff --git a/tests/unit/torch/puzzletron/__init__.py b/tests/unit/torch/puzzletron/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py new file mode 100644 index 000000000..759fb5fa3 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for normalized MSE loss functions in sewing_kit/utils.py.""" + +import pytest +import torch + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) + + +# --------------------------------------------------------------------------- +# normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_normalized_mse_loss_identical_tensors(): + """Identical input and target should produce a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 8) + loss = normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +def test_normalized_mse_loss_basic(): + """Loss should be positive and finite for random, non-identical tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target) + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_normalized_mse_loss_reduction_none(): + """With reduction='none' the output shape should match the input shape.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="none") + assert loss.shape == input_.shape + + +def test_normalized_mse_loss_reduction_sum(): + """With reduction='sum' the output should be a scalar tensor.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="sum") + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +# --------------------------------------------------------------------------- +# vectorwise_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_vectorwise_normalized_mse_loss_shape(): + """vectorwise_normalized_mse_loss should return a scalar for any 2-D input.""" + torch.manual_seed(42) + input_ = torch.randn(4, 16) + target = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +def test_vectorwise_normalized_mse_loss_identical(): + """Identical input and target should give a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +# --------------------------------------------------------------------------- +# batched_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_batched_normalized_mse_loss_basic(): + """Should return a scalar with a positive, finite value for random tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = batched_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_batched_normalized_mse_loss_custom_dims(): + """Custom batch_dims=(0, 1) on a 3-D tensor should still return a scalar.""" + torch.manual_seed(42) + input_ = torch.randn(2, 3, 8) + target = torch.randn(2, 3, 8) + loss = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1)) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + assert loss.item() > 0.0 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py new file mode 100644 index 000000000..c34bd017d --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for get_distributed_modules_ownership in bypass_utils.py.""" + +import pytest + +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import ( + get_distributed_modules_ownership, +) + + +def test_single_gpu_all_to_rank_0(): + """With world_size=1, all 4 modules should be assigned to rank 0.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=1) + assert ownership == [0, 0, 0, 0] + + +def test_even_distribution(): + """With world_size=2 and 4 modules, each rank should own exactly 2 modules.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 2 + assert len(ownership) == 4 + + +def test_uneven_distribution(): + """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1.""" + ownership = get_distributed_modules_ownership(module_count=3, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 1 + assert len(ownership) == 3 + + +@pytest.mark.parametrize( + "module_count, world_size", + [ + (1, 1), + (4, 1), + (4, 2), + (4, 4), + (7, 3), + (10, 4), + (1, 2), + ], +) +def test_total_equals_module_count(module_count, world_size): + """The length of the ownership list must always equal module_count.""" + ownership = get_distributed_modules_ownership( + module_count=module_count, world_size=world_size + ) + assert len(ownership) == module_count + + +def test_consecutive_ownership(): + """Each rank should own a contiguous block of indices (no interleaving).""" + ownership = get_distributed_modules_ownership(module_count=7, world_size=3) + # Verify that once we see a new rank, we never see the previous rank again. + seen_ranks = set() + prev_rank = ownership[0] + seen_ranks.add(prev_rank) + for rank in ownership[1:]: + if rank != prev_rank: + assert rank not in seen_ranks, ( + f"Rank {rank} appears non-consecutively in ownership list: {ownership}" + ) + seen_ranks.add(rank) + prev_rank = rank + + +def test_single_module(): + """With world_size=2 and only 1 module, rank 0 should be the sole owner.""" + ownership = get_distributed_modules_ownership(module_count=1, world_size=2) + assert ownership == [0] + assert len(ownership) == 1