Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions configs/rwkv7_380M.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"model_type": "rwkv7",
"hidden_size": 1024,
"num_hidden_layers": 24,
"head_dim": 64,
"hidden_ratio": 4,
"hidden_act": "sqrelu",
"vocab_size": 32000,
"decay_low_rank_dim": 64,
"gate_low_rank_dim": 128,
"a_low_rank_dim": 64,
"v_low_rank_dim": 16,
"tie_word_embeddings": false,
"fuse_cross_entropy": true,
"fuse_norm": true,
"use_l2warp": false,
"attn_mode": "chunk",
"bos_token_id": 1,
"eos_token_id": 2
}
18 changes: 14 additions & 4 deletions flame/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,21 +264,31 @@ def main(job_config: JobConfig):
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for m in model_parts:
# apply SPMD-style PT-D techniques
train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
# Materialize and initialize weights before applying parallelisms.
# Some models (e.g. RWKV-7) compute position-dependent init values
# as regular tensors and assign via .data, which is incompatible
# with DTensors created by FSDP. Cost: full model in fp32
# temporarily on each rank (~4 bytes/param). This fits comfortably
# for models up to ~10B on 40GB GPUs or ~20B on 80GB GPUs. For
# larger models, set init_device="cpu" (e.g. via
# --training.enable_cpu_offload).
m.to_empty(device=init_device)
with torch.no_grad():
m.post_init()
# apply SPMD-style PT-D techniques
train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
m.train()

# confirm that user will be able to view loss metrics on the console
ensure_pp_loss_visible(parallel_dims, job_config, color)
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
# Materialize and initialize weights before parallelization
# (see PP path comment above for rationale and memory thresholds).
model.to_empty(device=init_device)
with torch.no_grad():
model.post_init()
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
Comment on lines 287 to +291
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Moving to_empty() and post_init() before parallelize_fn() introduces a significant memory risk for large models. In the non-pipeline-parallel path, this causes the entire model to be materialized on the init_device (typically GPU) for every rank before FSDP or Tensor Parallelism can shard the parameters.\n\nFor models that are larger than a single GPU's memory (e.g., 70B+ models on 80GB GPUs), this will lead to an immediate Out-Of-Memory (OOM) error during initialization. The previous order was more memory-efficient as it allowed FSDP to create sharded meta-tensors that were materialized only as shards.\n\nWhile this change fixes the .data= assignment issue for specific models like RWKV-7, it regresses the scalability of the training script. If this behavior is necessary for specific models, consider making it optional or ensuring init_device is set to 'cpu' when training very large models.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The largest config in the repo is 7B (28 GB in fp32 at init, since FSDP's bf16 casting applies after parallelization). For FSDP-only setups the threshold where init OOMs is >10B params on 40GB GPUs or >20B on 80GB GPUs, below current flame model range. For larger scales, init_device="cpu" (already supported in train.py) is a straightforward fallback with minimal overhead (measured 5s at 7B vs. hours of training, see updated comment in code).

model.train()

model_parts = [model]
Expand Down