1313 local shards on each rank's GPU.
14142. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
1515 (actual quantization happens in ``reset_parameters`` after sharding).
16- 3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
17- 4. ``FusedAdam`` with FP32 master weights for full-precision training updates.
16+ 3. ``preserve_high_precision_init_val`` -- Keeps the original high-precision
17+ (FP32) weight values on CPU so they can seed the optimizer's FP32 master
18+ weights, avoiding the precision loss of round-tripping through FP8.
19+ 4. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
20+ 5. ``FusedAdam`` with FP32 master weights for full-precision training updates.
1821
1922.. note::
2023 ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
3841from torch .distributed .tensor import DTensor
3942
4043import transformer_engine .pytorch as te
41- from transformer_engine .pytorch import QuantizedTensor
4244from transformer_engine .pytorch .module .base import TransformerEngineBaseModule
45+ from transformer_engine .pytorch .quantized_tensor import QuantizedTensor
4346
44- # ── Configuration (matches main.py) ──────────────────────────────────
47+ # ── Configuration ────────────────────────────────────────────────────
4548HIDDEN_SIZE = 256
4649FFN_HIDDEN_SIZE = 1024
4750NUM_ATTENTION_HEADS = 8
4851NUM_LAYERS = 3
4952SEQ_LEN = 32
5053BATCH_PER_RANK = 2
5154NUM_STEPS = 5
52- DTYPE = torch .bfloat16
55+ # DTYPE is used for both params_dtype and activation tensors in this example.
56+ # float32 is chosen for params_dtype so that the high-precision init values
57+ # (which seed the optimizer's FP32 master weights) avoid a lossy BF16→FP8→FP32
58+ # round-trip. Using float32 for activations as well keeps the example simple;
59+ # in production you would typically use BF16 activations inside te.autocast().
60+ DTYPE = torch .float32
5361
5462
5563def dist_print (msg ):
@@ -60,10 +68,6 @@ def dist_print(msg):
6068
6169def main ():
6270 # ── 1. Distributed setup ─────────────────────────────────────────
63- assert "TORCHELASTIC_RUN_ID" in os .environ , (
64- "This script must be launched with torchrun, e.g.:\n "
65- " torchrun --nproc-per-node 2 fully_shard.py"
66- )
6771 world_size = int (os .environ ["WORLD_SIZE" ])
6872 local_rank = int (os .environ ["LOCAL_RANK" ])
6973
@@ -74,10 +78,14 @@ def main():
7478 torch .manual_seed (42 )
7579 torch .cuda .manual_seed (42 )
7680
77- # ── 2. Create model on meta device (zero memory) ────────────────
78- # quantized_model_init sets the flag for FP8 weight initialization,
79- # but with device="meta" no actual memory is allocated yet.
80- with te .quantized_model_init (enabled = True ):
81+ # ── 2. Create model on meta device (zero memory) ─────────────────
82+ # quantized_model_init flags parameters for FP8 quantization.
83+ # preserve_high_precision_init_val=True saves the original FP32
84+ # values on CPU so they can seed optimizer master weights later,
85+ # avoiding the precision loss of dequantizing from FP8.
86+ # DTYPE is float32 so the preserved init values are exact FP32 copies,
87+ # letting us seed the FP32 optimizer master weights without precision loss.
88+ with te .quantized_model_init (enabled = True , preserve_high_precision_init_val = True ):
8189 model = torch .nn .Sequential (
8290 * [
8391 te .TransformerLayer (
@@ -93,52 +101,53 @@ def main():
93101 for _ in range (NUM_LAYERS )
94102 ]
95103 )
96-
97- # Verify all parameters are on meta device (no GPU memory used).
98- for name , param in model .named_parameters ():
99- assert param .device == torch .device ("meta" ), f"{ name } is not on meta device"
100104 dist_print ("Model created on meta device (zero GPU memory)." )
101105
102- # ── 3. FSDP2 sharding ────────────────────────────────────────────
103- # Apply sharding to the meta-device model. FSDP2 wraps parameters
106+ # ── 3. FSDP2 sharding ───────────────────────────────────────────
107+ # Apply sharding to the meta-device model. FSDP2 wraps parameters
104108 # as DTensors but no GPU memory is allocated yet.
105109 mesh = DeviceMesh ("cuda" , list (range (world_size )))
106110 for child in model .children ():
107111 fully_shard (child , mesh = mesh )
108112 fully_shard (model , mesh = mesh )
109113 dist_print ("FSDP2 sharding applied to meta-device model." )
110114
111- # ── 4. Materialize parameters on GPU ──────────────────────────────
115+ # ── 4. Materialize parameters on GPU ─────────────────────────────
112116 # reset_parameters() on each TE module materializes the local shard
113117 # on CUDA, applies weight initialization, and quantizes to FP8.
118+ # Because preserve_high_precision_init_val=True, the pre-quantization
119+ # FP32 values are saved on CPU for each local shard.
114120 for module in model .modules ():
115121 if isinstance (module , TransformerEngineBaseModule ):
116122 module .reset_parameters ()
123+ dist_print ("Parameters materialized on GPU." )
117124
118- # Post-materialization verification.
119- for name , param in model .named_parameters ():
120- assert isinstance (param , DTensor ), f"{ name } is not a DTensor after sharding"
121- qt_count = sum (
122- 1
123- for _ , p in model .named_parameters ()
124- if isinstance (p , DTensor ) and isinstance (p ._local_tensor , QuantizedTensor )
125- )
126- assert qt_count > 0 , "No QuantizedTensor local tensors after materialization"
127- dist_print (
128- f"Parameters materialized: { qt_count } FP8 (QuantizedTensor) weight params "
129- "wrapped in DTensors."
130- )
131-
132- # ── 5. Optimizer ─────────────────────────────────────────────────
125+ # ── 5. Optimizer with FP32 master weights ────────────────────────
133126 optimizer = te .optimizers .FusedAdam (
134127 model .parameters (),
135128 lr = 1e-3 ,
136129 master_weights = True ,
137130 master_weight_dtype = torch .float32 ,
138131 )
139- dist_print ("Using FusedAdam with master_weights=True." )
140132
141- # ── 6. Training loop ─────────────────────────────────────────────
133+ # ── 6. Seed master weights from high-precision init values ───────
134+ # By default, FusedAdam initializes master weights by dequantizing
135+ # the FP8 parameters, which introduces quantization noise. Instead,
136+ # we seed them from the original FP32 init values preserved in step 2.
137+ for name , param in model .named_parameters ():
138+ optimizer .initialize_state (param , store_param_remainders = False )
139+ local = param ._local_tensor if isinstance (param , DTensor ) else param
140+ if isinstance (local , QuantizedTensor ):
141+ hp_val = local .get_high_precision_init_val ()
142+ assert hp_val .dtype == DTYPE , f"HP val dtype { hp_val .dtype } , expected { DTYPE } "
143+ optimizer .set_scaled_state (
144+ param , "master_param" , hp_val .to (device = device , dtype = torch .float32 )
145+ )
146+ local .clear_high_precision_init_val ()
147+
148+ dist_print ("Optimizer master weights seeded from high-precision init values." )
149+
150+ # ── 7. Training loop ─────────────────────────────────────────────
142151 x = torch .randn (SEQ_LEN , BATCH_PER_RANK , HIDDEN_SIZE , dtype = DTYPE , device = device )
143152 target = torch .randn (SEQ_LEN , BATCH_PER_RANK , HIDDEN_SIZE , dtype = DTYPE , device = device )
144153
@@ -153,56 +162,22 @@ def main():
153162 optimizer .step ()
154163 dist_print (f" Step { step } : loss = { loss .item ():.6f} " )
155164
156- # ── 7. Post-training assertions ──────────────────────────────────
157- dist_print ("\n Verifying invariants ..." )
158-
159- qt_after = 0
160- for name , param in model .named_parameters ():
161- assert isinstance (param , DTensor ), f"{ name } lost DTensor wrapping"
162- if isinstance (param ._local_tensor , QuantizedTensor ):
163- qt_after += 1
164- assert qt_after > 0 , "No QuantizedTensor local tensors after training"
165- dist_print (f" { qt_after } params still have QuantizedTensor local tensors." )
166-
167- # Optimizer states: master weights and moments should be float32.
168- for param in model .parameters ():
169- state = optimizer .state [param ]
170- if "master_param" in state :
171- assert (
172- state ["master_param" ].dtype == torch .float32
173- ), f"Master weight dtype { state ['master_param' ].dtype } , expected float32"
174- assert state ["exp_avg" ].dtype == torch .float32 , "exp_avg should be float32"
175- assert state ["exp_avg_sq" ].dtype == torch .float32 , "exp_avg_sq should be float32"
176-
177- dist_print ("All assertions passed!" )
178- dist_print (" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor" )
179- dist_print (" - Optimizer master weights: float32" )
180- dist_print (" - Optimizer states (exp_avg, exp_avg_sq): float32" )
181-
182165 # ── 8. Distributed checkpoint: save and load ─────────────────────
183166 # torch.distributed.checkpoint (DCP) saves sharded state — each rank
184- # writes only its local shard. This preserves FP8 compute weights
185- # and the full optimizer state (master weights, moments, step count).
167+ # writes only its local shard, preserving FP8 compute weights and
168+ # the full optimizer state (master weights, moments, step count).
186169 import torch .distributed .checkpoint as dcp
187- from torch .distributed .checkpoint .state_dict import (
188- StateDictOptions ,
189- get_model_state_dict ,
190- get_optimizer_state_dict ,
191- )
192170
193- # Use a fixed path so all ranks agree on the checkpoint location.
194171 checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
195172 dist_print (f"\n Saving distributed checkpoint to { checkpoint_dir } ..." )
196173
197- # Save sharded checkpoint. DCP handles DTensor shards natively —
198- # each rank writes only its local shard to the filesystem.
199174 dcp .save (
200175 {"model" : model .state_dict (), "optimizer" : optimizer .state_dict ()},
201176 checkpoint_id = checkpoint_dir ,
202177 )
203178 dist_print (" Checkpoint saved (FP8 weights + optimizer state)." )
204179
205- # Load checkpoint back. Provide empty state dict containers with the
180+ # Load checkpoint back. Provide empty state dict containers with the
206181 # same structure; DCP fills them from the saved files.
207182 state_to_load = {"model" : model .state_dict (), "optimizer" : optimizer .state_dict ()}
208183 dcp .load (state_to_load , checkpoint_id = checkpoint_dir )
@@ -225,6 +200,11 @@ def main():
225200 # authoritative FP32 values (more precise than dequantizing FP8).
226201 # All ranks must participate in gathering; only rank 0 saves.
227202 from safetensors .torch import save_file
203+ from torch .distributed .checkpoint .state_dict import (
204+ StateDictOptions ,
205+ get_model_state_dict ,
206+ get_optimizer_state_dict ,
207+ )
228208
229209 full_opts = StateDictOptions (full_state_dict = True , cpu_offload = True )
230210
@@ -238,10 +218,10 @@ def main():
238218
239219 for key , value in full_model_state .items ():
240220 if key in opt_param_states and "master_param" in opt_param_states [key ]:
241- # Prefer optimizer's FP32 master weight (maintained throughout training) .
221+ # Prefer optimizer's FP32 master weight.
242222 fp32_state [key ] = opt_param_states [key ]["master_param" ].float ()
243- elif isinstance (value , QuantizedTensor ):
244- # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off) .
223+ elif isinstance (value , te . QuantizedTensor ):
224+ # Fallback: dequantize FP8 → FP32.
245225 fp32_state [key ] = value .dequantize ().float ()
246226 else :
247227 # Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
@@ -251,14 +231,7 @@ def main():
251231 save_file (fp32_state , save_path )
252232 dist_print (f"\n Saved FP32 model ({ len (fp32_state )} params) to { save_path } " )
253233
254- # Quick verification: all saved tensors are float32.
255- from safetensors .torch import load_file
256-
257- loaded = load_file (save_path )
258- for k , v in loaded .items ():
259- assert v .dtype == torch .float32 , f"{ k } : expected float32, got { v .dtype } "
260- dist_print (f" Verified: all { len (loaded )} tensors are float32." )
261-
234+ dist .barrier () # wait for rank 0 to finish file I/O
262235 dist .destroy_process_group ()
263236
264237
0 commit comments