Skip to content

Commit 6d45b3a

Browse files
Simplify megatron to hf export logic
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent c750237 commit 6d45b3a

4 files changed

Lines changed: 12 additions & 108 deletions

File tree

examples/puzzletron/mbridge_distillation/README.md

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.
7171
--data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \
7272
--output_dir /path/to/distilled/checkpoint \
7373
--hf_export_path /path/to/exported/hf/model \
74-
--hf_model meta-llama/Llama-3.1-8B-Instruct \
7574
--seq_length 4096 \
7675
--tp_size 8 \
7776
--pp_size 1 \

examples/puzzletron/mbridge_distillation/distill_hf.py

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,9 @@
4444
from megatron.core.datasets.utils import get_blend_from_list
4545
from megatron.core.distributed import DistributedDataParallelConfig
4646

47+
# Import to register heterogeneous bridges (side effect)
48+
import modelopt.torch.puzzletron.export.mbridge # noqa: F401
4749
import modelopt.torch.utils.distributed as dist
48-
from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import (
49-
export_to_hf_and_copy_config,
50-
)
5150
from modelopt.torch.utils import print_rank_0
5251

5352
SEED = 1234
@@ -129,13 +128,6 @@ def get_args():
129128
"If provided, exports last iteration checkpoint to HF format after distillation."
130129
),
131130
)
132-
parser.add_argument(
133-
"--hf_model",
134-
type=str,
135-
required=True,
136-
help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
137-
"Should match the base architecture of the student model.",
138-
)
139131
args = parser.parse_args()
140132

141133
# Sanity checks
@@ -272,13 +264,15 @@ def _build_model_provider(hf_path):
272264

273265
# Only rank 0 exports
274266
if is_rank_0:
275-
export_to_hf_and_copy_config(
276-
student_hf_path=args.student_hf_path,
277-
checkpoint_dir=checkpoint_dir,
278-
train_iters=args.train_iters,
279-
hf_export_path=args.hf_export_path,
280-
hf_model=args.hf_model,
281-
trust_remote_code=args.trust_remote_code,
267+
bridge = AutoBridge.from_hf_pretrained(
268+
args.student_hf_path, trust_remote_code=args.trust_remote_code
269+
)
270+
os.makedirs(os.path.join(args.hf_export_path, "subblocks_safetensors"), exist_ok=True)
271+
bridge.export_ckpt(
272+
megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
273+
hf_path=args.hf_export_path,
274+
show_progress=True,
275+
strict=True,
282276
)
283277

284278

modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path):
3232
and runs mbridge distillation. The models are created with reduced size for faster testing.
3333
Models are converted to include block_configs.
3434
"""
35+
tmp_path = Path("/tmp/test_distill_hf")
3536
# Prepare student and teacher models
3637
student_hf_path, teacher_hf_path = _prepare_student_and_teacher_models(
3738
project_root_path, tmp_path
@@ -74,7 +75,6 @@ def test_distill_hf(project_root_path: Path, tmp_path: Path):
7475
eval_iters=0,
7576
log_interval=5,
7677
hf_export_path=hf_export_dir,
77-
hf_model="Qwen/Qwen3-0.6B",
7878
)
7979

8080
run_example_command(cmd_parts, example_path="puzzletron/mbridge_distillation")

0 commit comments

Comments (0)