diff --git a/configs/rwkv7_380M.json b/configs/rwkv7_380M.json new file mode 100644 index 00000000..8313a84c --- /dev/null +++ b/configs/rwkv7_380M.json @@ -0,0 +1,20 @@ +{ + "model_type": "rwkv7", + "hidden_size": 1024, + "num_hidden_layers": 24, + "head_dim": 64, + "hidden_ratio": 4, + "hidden_act": "sqrelu", + "vocab_size": 32000, + "decay_low_rank_dim": 64, + "gate_low_rank_dim": 128, + "a_low_rank_dim": 64, + "v_low_rank_dim": 16, + "tie_word_embeddings": false, + "fuse_cross_entropy": true, + "fuse_norm": true, + "use_l2warp": false, + "attn_mode": "chunk", + "bos_token_id": 1, + "eos_token_id": 2 +} diff --git a/flame/train.py b/flame/train.py index 461fcc09..65b1f716 100644 --- a/flame/train.py +++ b/flame/train.py @@ -264,21 +264,31 @@ def main(job_config: JobConfig): # We need to iterate through model_parts to apply SPMD parallelisms, compilation, # optimizer, and checkpointing for m in model_parts: - # apply SPMD-style PT-D techniques - train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) + # Materialize and initialize weights before applying parallelisms. + # Some models (e.g. RWKV-7) compute position-dependent init values + # as regular tensors and assign via .data, which is incompatible + # with DTensors created by FSDP. Cost: full model in fp32 + # temporarily on each rank (~4 bytes/param). This fits comfortably + # for models up to ~10B on 40GB GPUs or ~20B on 80GB GPUs. For + # larger models, set init_device="cpu" (e.g. via + # --training.enable_cpu_offload). m.to_empty(device=init_device) with torch.no_grad(): m.post_init() + # apply SPMD-style PT-D techniques + train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) m.train() # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + # Materialize and initialize weights before parallelization + # (see PP path comment above for rationale and memory thresholds). model.to_empty(device=init_device) with torch.no_grad(): model.post_init() + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) model.train() model_parts = [model]