From e04cf57c58132cfb85b9864796d18df6fd9505a8 Mon Sep 17 00:00:00 2001 From: puigde Date: Thu, 9 Apr 2026 13:46:25 +0200 Subject: [PATCH 1/3] Initialize weights before parallelization Models with custom position-dependent initialization (e.g. RWKV-7) assign values via .data=, which is incompatible with DTensors created by FSDP. Move to_empty() and post_init() before parallelize_fn() so that weight initialization operates on plain tensors. The temporary full-model allocation is freed once FSDP shards the parameters. Closes #22 --- flame/train.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/flame/train.py b/flame/train.py index 461fcc09..97fcb4d9 100644 --- a/flame/train.py +++ b/flame/train.py @@ -264,21 +264,26 @@ def main(job_config: JobConfig): # We need to iterate through model_parts to apply SPMD parallelisms, compilation, # optimizer, and checkpointing for m in model_parts: - # apply SPMD-style PT-D techniques - train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) + # Initialize weights before parallelization so that models with + # custom .data= init (e.g. RWKV-7) operate on plain tensors + # instead of DTensors. The temporary full-size allocation is + # freed once FSDP shards the parameters. m.to_empty(device=init_device) with torch.no_grad(): m.post_init() + # apply SPMD-style PT-D techniques + train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.train() # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + # Initialize weights before parallelization (see PP path above). 
model.to_empty(device=init_device) with torch.no_grad(): model.post_init() + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.train() model_parts = [model] From 365ef8bf9ee222da762284825c94e758759b01ba Mon Sep 17 00:00:00 2001 From: puigde Date: Thu, 9 Apr 2026 13:48:19 +0200 Subject: [PATCH 2/3] Add RWKV-7 380M config --- configs/rwkv7_380M.json | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 configs/rwkv7_380M.json diff --git a/configs/rwkv7_380M.json b/configs/rwkv7_380M.json new file mode 100644 index 00000000..8313a84c --- /dev/null +++ b/configs/rwkv7_380M.json @@ -0,0 +1,20 @@ +{ + "model_type": "rwkv7", + "hidden_size": 1024, + "num_hidden_layers": 24, + "head_dim": 64, + "hidden_ratio": 4, + "hidden_act": "sqrelu", + "vocab_size": 32000, + "decay_low_rank_dim": 64, + "gate_low_rank_dim": 128, + "a_low_rank_dim": 64, + "v_low_rank_dim": 16, + "tie_word_embeddings": false, + "fuse_cross_entropy": true, + "fuse_norm": true, + "use_l2warp": false, + "attn_mode": "chunk", + "bos_token_id": 1, + "eos_token_id": 2 +} From 05eda4a2039d8cc424a869c594826b122c84167a Mon Sep 17 00:00:00 2001 From: puigde Date: Thu, 9 Apr 2026 14:40:31 +0200 Subject: [PATCH 3/3] Document init memory thresholds and CPU fallback Expand inline comments to explain the memory cost of init-before-parallelize (~4 bytes/param in fp32) and the OOM thresholds (~10B on 40GB, ~20B on 80GB). Point users to --training.enable_cpu_offload as a fallback for larger models. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- flame/train.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/flame/train.py b/flame/train.py index 97fcb4d9..65b1f716 100644 --- a/flame/train.py +++ b/flame/train.py @@ -264,10 +264,14 @@ def main(job_config: JobConfig): # We need to iterate through model_parts to apply SPMD parallelisms, compilation, # optimizer, and checkpointing for m in model_parts: - # Initialize weights before parallelization so that models with - # custom .data= init (e.g. RWKV-7) operate on plain tensors - # instead of DTensors. The temporary full-size allocation is - # freed once FSDP shards the parameters. + # Materialize and initialize weights before applying parallelisms. + # Some models (e.g. RWKV-7) compute position-dependent init values + # as regular tensors and assign via .data, which is incompatible + # with DTensors created by FSDP. Cost: full model in fp32 + # temporarily on each rank (~4 bytes/param). That copy alone fills + # the device at ~10B params on 40GB GPUs or ~20B on 80GB GPUs. For + # models near or beyond that, set init_device="cpu" (e.g. via + # --training.enable_cpu_offload). m.to_empty(device=init_device) with torch.no_grad(): m.post_init() @@ -278,7 +282,8 @@ def main(job_config: JobConfig): # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # Initialize weights before parallelization (see PP path above). + # Materialize and initialize weights before parallelization + # (see PP path comment above for rationale and memory thresholds). model.to_empty(device=init_device) with torch.no_grad(): model.post_init()