[WIP] add mindspeed support #6689
Changes from all commits
ad2e944
8b42957
97dd8a8
28f10c9
a8ed577
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,45 @@ | ||||||||
| PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' \ | ||||||||
|
Collaborator
It is recommended to place the examples in the examples/npu/megatron directory. In the future, the npu directory can provide more NPU-related scripts. It is recommended to add a section on NPU environment installation in the Megatron-SWIFT Quick Start documentation. |
||||||||
| NPROC_PER_NODE=4 \ | ||||||||
| ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 \ | ||||||||
| megatron sft \ | ||||||||
| --model Qwen/Qwen3-VL-30B-A3B-Instruct \ | ||||||||
| --load_safetensors true \ | ||||||||
| --save_safetensors true \ | ||||||||
| --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ | ||||||||
| --load_from_cache_file true \ | ||||||||
| --train_type full \ | ||||||||
| --freeze_llm false \ | ||||||||
| --freeze_vit true \ | ||||||||
| --freeze_aligner true \ | ||||||||
| --split_dataset_ratio 0.01 \ | ||||||||
| --tensor_model_parallel_size 2 \ | ||||||||
| --sequence_parallel true \ | ||||||||
| --expert_model_parallel_size 2 \ | ||||||||
| --pipeline_model_parallel_size 2 \ | ||||||||
| --decoder_first_pipeline_num_layers 5 \ | ||||||||
| --moe_aux_loss_coeff 1e-3 \ | ||||||||
| --micro_batch_size 1 \ | ||||||||
| --global_batch_size 4 \ | ||||||||
| --finetune true \ | ||||||||
| --lr 1e-4 \ | ||||||||
| --lr_warmup_fraction 0.05 \ | ||||||||
| --min_lr 1e-5 \ | ||||||||
| --max_epochs 1 \ | ||||||||
| --save megatron_output/Qwen3-VL-30B-A3B \ | ||||||||
| --eval_interval 200 \ | ||||||||
| --save_interval 200 \ | ||||||||
| --vit_gradient_checkpointing true \ | ||||||||
| --max_length 2048 \ | ||||||||
| --num_workers 8 \ | ||||||||
| --dataset_num_proc 8 \ | ||||||||
| --no_save_optim true \ | ||||||||
| --no_save_rng true \ | ||||||||
| --moe_grouped_gemm true \ | ||||||||
| --moe_shared_expert_overlap true \ | ||||||||
| --packing true \ | ||||||||
| --cross_entropy_loss_fusion true \ | ||||||||
| --recompute_granularity full \ | ||||||||
| --recompute_method uniform \ | ||||||||
| --recompute_num_layers 1 \ | ||||||||
| --attention_backend flash | ||||||||
| # --moe_permute_fusion true \ | ||||||||
|
Contributor
This commented-out line includes an unnecessary trailing backslash. Additionally, the file is missing a newline character at the end. According to POSIX standards, a text file should conclude with a newline to ensure compatibility with various command-line tools.
Suggested change
|
||||||||
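The two issues raised in this comment can be checked mechanically. The following is a minimal sketch only; `lint_launch_script` is a hypothetical helper written for illustration and is not part of this PR or of the repository.

```python
def lint_launch_script(text: str) -> list[str]:
    """Report the two problems named in the review comment.

    Hypothetical helper for illustration; not part of the PR.
    """
    problems = []

    # POSIX defines a text line as ending in a newline, so a non-empty
    # file whose last character is not "\n" has an incomplete last line.
    if text and not text.endswith("\n"):
        problems.append("missing final newline")

    # A backslash at the end of a comment line does not continue a
    # command; it only suggests that a continuation was intended.
    for lineno, line in enumerate(text.splitlines(), start=1):
        stripped = line.strip()
        if stripped.startswith("#") and stripped.endswith("\\"):
            problems.append(f"line {lineno}: comment ends with a backslash")

    return problems


if __name__ == "__main__":
    # Tail of a script shaped like the one under review (abbreviated).
    tail = "megatron sft \\\n    --attention_backend flash\n# --moe_permute_fusion true \\"
    for problem in lint_launch_script(tail):
        print(problem)
```

A check like this could run in a pre-commit hook, although an existing linter or editor setting that enforces a final newline would serve the same purpose.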
The logic for setting `use_flash_attn` can be simplified for better clarity and to avoid redundancy. It's clearer to modify `args.use_flash_attn` directly and then create the `megatron_args` dictionary from the updated `args` object. This avoids modifying both the dictionary and the `args` object separately for the same value.
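
The code this comment refers to is not shown in this excerpt, so the following is only a sketch of the suggested pattern. The `Args` dataclass, the `build_megatron_args` function, and the rule that derives `use_flash_attn` from `attention_backend` are all assumptions made for illustration; only the names `args`, `use_flash_attn`, and `megatron_args` come from the comment.

```python
from dataclasses import asdict, dataclass


@dataclass
class Args:
    """Hypothetical stand-in for the real argument object."""
    attention_backend: str = "flash"
    use_flash_attn: bool = False
    micro_batch_size: int = 1


def build_megatron_args(args: Args) -> dict:
    # Assumed rule, for illustration only: enable flash attention
    # when the requested backend is "flash".
    args.use_flash_attn = args.attention_backend == "flash"

    # Build the dictionary after the update, so the value is written in
    # one place and the dict and the args object cannot disagree.
    megatron_args = asdict(args)
    return megatron_args


if __name__ == "__main__":
    args = Args(attention_backend="flash")
    megatron_args = build_megatron_args(args)
    print(megatron_args["use_flash_attn"], args.use_flash_attn)
```

Because the dictionary is derived from `args` after the update, there is a single assignment to maintain, which is the redundancy the comment asks to remove.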