Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@ def __len__(self):
return len(self.dumped_files)

def __getitem__(self, i) -> dict[str, torch.Tensor]:
offline_data = torch.load(self.dumped_files[i])
try:
offline_data = torch.load(self.dumped_files[i])
except Exception as e:
print(
f"[ERROR] Failed to load file at index={i}, "
f"path='{self.dumped_files[i]}', error={e}. "
"Reusing data from previous index (i-1)."
)
return self.__getitem__(i - 1)

labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
labels[..., :-1] = offline_data["input_ids"][..., 1:]
Expand Down Expand Up @@ -154,7 +162,7 @@ def make_eagle_supervised_data_module(
assert not data_args.vlm_processor, "Offline data is not supported for VLM."

offline_data_path = Path(data_args.offline_data_path)
dumped_files = [str(p) for p in offline_data_path.glob("*.pt")]
dumped_files = [str(p) for p in offline_data_path.rglob("*.pt")]
if not dumped_files:
raise ValueError(f"No .pt files found in {data_args.offline_data_path}")

Expand Down
40 changes: 32 additions & 8 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TORCH_COMPILE="${1#*=}"
;;
--use_fake_base_for_offline*)
if [[ "$1" != *=* ]]; then shift; fi
USE_FAKE_BASE_FOR_OFFLINE="${1#*=}"
;;
--trust_remote_code*)
if [[ "$1" != *=* ]]; then shift; fi
TRUST_REMOTE_CODE="${1#*=}"
;;
--fsdp*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand All @@ -134,9 +146,16 @@ set -x

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
NUM_NODES=${NUM_NODES:-1}
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
if [[ "$NUM_NODES" != 1 ]]; then
#Multi Node Training
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
else
#Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES
TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())")
echo "Total GPUs: $TOTAL_GPU (Single Node Training)"
fi
# Calculate save_steps
DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU))

Expand All @@ -151,6 +170,7 @@ SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
DATA=${DATA:-""}
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
Expand All @@ -165,6 +185,9 @@ MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}

USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"}
TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"}
FSDP=${FSDP:-"False"}

if [[ "$MODE" == "eagle3" ]]; then
if [[ -n "$EAGLE_CONFIG" ]]; then
Expand All @@ -182,10 +205,10 @@ if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory."
exit 1
else
OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1"
DATA_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1"
fi
else
OFFLINE_TRAINING_ARGS=""
DATA_ARGS="--data_path $DATA"
fi


Expand All @@ -195,7 +218,7 @@ else
VLM_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 ]]; then
if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
Expand Down Expand Up @@ -245,15 +268,16 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--lr_scheduler_type linear \
--logging_steps $LOG_STEPS \
--tf32 True \
--data_path $DATA \
$DATA_ARGS \
--disable_tqdm $DISABLE_TQDM \
--estimate_ar $ESTIMATE_AR \
--ar_validate_steps $AR_VALIDATE_STEPS \
--mix_hidden_states $MIX_HIDDEN_STATES \
--disable_torch_compile $DISABLE_TORCH_COMPILE \
--use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \
--trust_remote_code $TRUST_REMOTE_CODE \
$DRAFT_VOCAB_CACHE_ARGS \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
$SPECULATIVE_ARGS \
$FSDP_ARGS \
--cp_size $CP_SIZE \
Expand Down
36 changes: 20 additions & 16 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.utils import (
load_vlm_or_llm_with_kwargs,
patch_transformers5_params_loading,
)
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0

torch.manual_seed(0)
Expand All @@ -60,11 +57,18 @@
@dataclass
class ModelArguments:
model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
use_fake_base_for_offline: bool = field(
default=False, metadata={"help": "Whether to use fake base for offline training."}
)
trust_remote_code: bool = field(
default=False, metadata={"help": "Whether to trust remote code."}
)


@dataclass
class DataArguments:
data_path: str = field(
default=None,
metadata={"help": "Path to the training data."},
)
eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
Expand Down Expand Up @@ -153,6 +157,8 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
if not data_args.data_path and not data_args.offline_data_path:
raise ValueError("Either data_path or offline_data_path must be provided.")
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
Expand All @@ -178,29 +184,27 @@ def train():

if checkpoint:
with patch_transformers5_params_loading():
_, model = load_vlm_or_llm_with_kwargs(
checkpoint, torch_dtype="auto", trust_remote_code=True
model = load_vlm_or_llm(
checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(
checkpoint, trust_remote_code=model_args.trust_remote_code
)
else:
# To avoid OOM for large models, we load and convert model on CPU first.
# Model will be moved to GPU during HF trainer.init().
offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
model_config, model = load_vlm_or_llm_with_kwargs(
model = load_vlm_or_llm(
model_args.model_name_or_path,
use_fake_base=model_args.use_fake_base_for_offline,
use_offline_training=use_offline_training,
torch_dtype="auto",
device_map="cpu",
trust_remote_code=True,
**offline_kwargs,
trust_remote_code=model_args.trust_remote_code,
)
if use_offline_training:
# When doing offline training, we need to set num_hidden_layers
# since we override it when loading the model for space savings
model.config.num_orig_hidden_layers = model_config.num_hidden_layers
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
model_max_length=training_args.training_seq_len,
trust_remote_code=True,
trust_remote_code=model_args.trust_remote_code,
)
if training_args.mode == "medusa":
config = {
Expand Down
4 changes: 2 additions & 2 deletions examples/speculative_decoding/scripts/ar_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import modelopt.torch.opt as mto
from modelopt.torch.speculative.plugins.transformers import HFARValidation
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
from modelopt.torch.speculative.utils import load_vlm_or_llm

mto.enable_huggingface_checkpointing()

Expand Down Expand Up @@ -72,7 +72,7 @@ def main():

accelerator = Accelerator()
# Load model and tokenizer
_, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto")
model = load_vlm_or_llm(args.model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
model.eval()
model = accelerator.prepare(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import modelopt.torch.opt as mto
from modelopt.torch.export import export_speculative_decoding
from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs
from modelopt.torch.speculative.utils import load_vlm_or_llm


def parse_args():
Expand All @@ -38,7 +38,7 @@ def parse_args():
mto.enable_huggingface_checkpointing()

args = parse_args()
_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto")
model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
model.eval()
with torch.inference_mode():
export_speculative_decoding(
Expand Down
Loading
Loading