Skip to content

Commit aa64e7a

Browse files
authored
Merge branch 'main' into jm/nvfp4-block-fused-adam
2 parents 0471224 + 53fefa4 commit aa64e7a

File tree

4 files changed

+220
-94
lines changed

4 files changed

+220
-94
lines changed

examples/jax/collective_gemm/test_gemm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,20 @@ def run_gemm_tests(args, mesh=None):
151151
jax.block_until_ready(gathered_output)
152152

153153
if args.enable_result_check and args.process_id == 0:
154+
# CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32).
155+
# With catastrophic cancellation the output is near zero while the absolute diff can
156+
# reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer
157+
# activations at O(8) scale), which exceeds the previous atol=1e-5. The 2x
158+
# margin (0.125) covers this worst-case 1-ULP absolute difference.
159+
is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization
160+
rtol = 1e-2 if is_cgemm_rs_bf16 else None
161+
atol = 0.125 if is_cgemm_rs_bf16 else None
154162
assert_allclose(
155-
gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)
163+
gathered_ref_output,
164+
gathered_output,
165+
dtype=get_tolerance_dtype(quantizer_set),
166+
rtol=rtol,
167+
atol=atol,
156168
)
157169

158170

examples/pytorch/quantized_model_init/fully_shard.py

Lines changed: 58 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
local shards on each rank's GPU.
1414
2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
1515
(actual quantization happens in ``reset_parameters`` after sharding).
16-
3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
17-
4. ``FusedAdam`` with FP32 master weights for full-precision training updates.
16+
3. ``preserve_high_precision_init_val`` -- Keeps the original FP32 weight
17+
values on CPU so they can seed the optimizer's FP32 master weights,
18+
avoiding the precision loss of round-tripping through FP8.
19+
4. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
20+
5. ``FusedAdam`` with FP32 master weights for full-precision training updates.
1821
1922
.. note::
2023
``fuse_wgrad_accumulation`` is **not** used here. That feature writes
@@ -38,18 +41,23 @@
3841
from torch.distributed.tensor import DTensor
3942

4043
import transformer_engine.pytorch as te
41-
from transformer_engine.pytorch import QuantizedTensor
4244
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
45+
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
4346

44-
# ── Configuration (matches main.py) ──────────────────────────────────
47+
# ── Configuration ────────────────────────────────────────────────────
4548
HIDDEN_SIZE = 256
4649
FFN_HIDDEN_SIZE = 1024
4750
NUM_ATTENTION_HEADS = 8
4851
NUM_LAYERS = 3
4952
SEQ_LEN = 32
5053
BATCH_PER_RANK = 2
5154
NUM_STEPS = 5
52-
DTYPE = torch.bfloat16
55+
# DTYPE is used for both params_dtype and activation tensors in this example.
56+
# float32 is chosen for params_dtype so that the high-precision init values
57+
# (which seed the optimizer's FP32 master weights) avoid a lossy BF16→FP8→FP32
58+
# round-trip. Using float32 for activations as well keeps the example simple;
59+
# in production you would typically use BF16 activations inside te.autocast().
60+
DTYPE = torch.float32
5361

5462

5563
def dist_print(msg):
@@ -60,10 +68,6 @@ def dist_print(msg):
6068

6169
def main():
6270
# ── 1. Distributed setup ─────────────────────────────────────────
63-
assert "TORCHELASTIC_RUN_ID" in os.environ, (
64-
"This script must be launched with torchrun, e.g.:\n"
65-
" torchrun --nproc-per-node 2 fully_shard.py"
66-
)
6771
world_size = int(os.environ["WORLD_SIZE"])
6872
local_rank = int(os.environ["LOCAL_RANK"])
6973

@@ -74,10 +78,14 @@ def main():
7478
torch.manual_seed(42)
7579
torch.cuda.manual_seed(42)
7680

77-
# ── 2. Create model on meta device (zero memory) ────────────────
78-
# quantized_model_init sets the flag for FP8 weight initialization,
79-
# but with device="meta" no actual memory is allocated yet.
80-
with te.quantized_model_init(enabled=True):
81+
# ── 2. Create model on meta device (zero memory) ─────────────────
82+
# quantized_model_init flags parameters for FP8 quantization.
83+
# preserve_high_precision_init_val=True saves the original FP32
84+
# values on CPU so they can seed optimizer master weights later,
85+
# avoiding the precision loss of dequantizing from FP8.
86+
# We set DTYPE to float32 since these weights will actually be initialized as FP8,
87+
# but we want to seed the optimizer states (which will be in FP32) with the FP32 values.
88+
with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
8189
model = torch.nn.Sequential(
8290
*[
8391
te.TransformerLayer(
@@ -93,52 +101,53 @@ def main():
93101
for _ in range(NUM_LAYERS)
94102
]
95103
)
96-
97-
# Verify all parameters are on meta device (no GPU memory used).
98-
for name, param in model.named_parameters():
99-
assert param.device == torch.device("meta"), f"{name} is not on meta device"
100104
dist_print("Model created on meta device (zero GPU memory).")
101105

102-
# ── 3. FSDP2 sharding ───────────────────────────────────────────
103-
# Apply sharding to the meta-device model. FSDP2 wraps parameters
106+
# ── 3. FSDP2 sharding ───────────────────────────────────────────
107+
# Apply sharding to the meta-device model. FSDP2 wraps parameters
104108
# as DTensors but no GPU memory is allocated yet.
105109
mesh = DeviceMesh("cuda", list(range(world_size)))
106110
for child in model.children():
107111
fully_shard(child, mesh=mesh)
108112
fully_shard(model, mesh=mesh)
109113
dist_print("FSDP2 sharding applied to meta-device model.")
110114

111-
# ── 4. Materialize parameters on GPU ─────────────────────────────
115+
# ── 4. Materialize parameters on GPU ─────────────────────────────
112116
# reset_parameters() on each TE module materializes the local shard
113117
# on CUDA, applies weight initialization, and quantizes to FP8.
118+
# Because preserve_high_precision_init_val=True, the pre-quantization
119+
# FP32 values are saved on CPU for each local shard.
114120
for module in model.modules():
115121
if isinstance(module, TransformerEngineBaseModule):
116122
module.reset_parameters()
123+
dist_print("Parameters materialized on GPU.")
117124

118-
# Post-materialization verification.
119-
for name, param in model.named_parameters():
120-
assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
121-
qt_count = sum(
122-
1
123-
for _, p in model.named_parameters()
124-
if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
125-
)
126-
assert qt_count > 0, "No QuantizedTensor local tensors after materialization"
127-
dist_print(
128-
f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params "
129-
"wrapped in DTensors."
130-
)
131-
132-
# ── 5. Optimizer ─────────────────────────────────────────────────
125+
# ── 5. Optimizer with FP32 master weights ────────────────────────
133126
optimizer = te.optimizers.FusedAdam(
134127
model.parameters(),
135128
lr=1e-3,
136129
master_weights=True,
137130
master_weight_dtype=torch.float32,
138131
)
139-
dist_print("Using FusedAdam with master_weights=True.")
140132

141-
# ── 6. Training loop ─────────────────────────────────────────────
133+
# ── 6. Seed master weights from high-precision init values ───────
134+
# By default, FusedAdam initializes master weights by dequantizing
135+
# the FP8 parameters, which introduces quantization noise. Instead,
136+
# we seed them from the original BF16 init values preserved in step 2.
137+
for name, param in model.named_parameters():
138+
optimizer.initialize_state(param, store_param_remainders=False)
139+
local = param._local_tensor if isinstance(param, DTensor) else param
140+
if isinstance(local, QuantizedTensor):
141+
hp_val = local.get_high_precision_init_val()
142+
assert hp_val.dtype == DTYPE, f"HP val dtype {hp_val.dtype}, expected {DTYPE}"
143+
optimizer.set_scaled_state(
144+
param, "master_param", hp_val.to(device=device, dtype=torch.float32)
145+
)
146+
local.clear_high_precision_init_val()
147+
148+
dist_print("Optimizer master weights seeded from high-precision init values.")
149+
150+
# ── 7. Training loop ─────────────────────────────────────────────
142151
x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
143152
target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
144153

@@ -153,56 +162,22 @@ def main():
153162
optimizer.step()
154163
dist_print(f" Step {step}: loss = {loss.item():.6f}")
155164

156-
# ── 7. Post-training assertions ──────────────────────────────────
157-
dist_print("\nVerifying invariants ...")
158-
159-
qt_after = 0
160-
for name, param in model.named_parameters():
161-
assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
162-
if isinstance(param._local_tensor, QuantizedTensor):
163-
qt_after += 1
164-
assert qt_after > 0, "No QuantizedTensor local tensors after training"
165-
dist_print(f" {qt_after} params still have QuantizedTensor local tensors.")
166-
167-
# Optimizer states: master weights and moments should be float32.
168-
for param in model.parameters():
169-
state = optimizer.state[param]
170-
if "master_param" in state:
171-
assert (
172-
state["master_param"].dtype == torch.float32
173-
), f"Master weight dtype {state['master_param'].dtype}, expected float32"
174-
assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
175-
assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"
176-
177-
dist_print("All assertions passed!")
178-
dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
179-
dist_print(" - Optimizer master weights: float32")
180-
dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32")
181-
182165
# ── 8. Distributed checkpoint: save and load ─────────────────────
183166
# torch.distributed.checkpoint (DCP) saves sharded state — each rank
184-
# writes only its local shard. This preserves FP8 compute weights
185-
# and the full optimizer state (master weights, moments, step count).
167+
# writes only its local shard, preserving FP8 compute weights and
168+
# the full optimizer state (master weights, moments, step count).
186169
import torch.distributed.checkpoint as dcp
187-
from torch.distributed.checkpoint.state_dict import (
188-
StateDictOptions,
189-
get_model_state_dict,
190-
get_optimizer_state_dict,
191-
)
192170

193-
# Use a fixed path so all ranks agree on the checkpoint location.
194171
checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
195172
dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...")
196173

197-
# Save sharded checkpoint. DCP handles DTensor shards natively —
198-
# each rank writes only its local shard to the filesystem.
199174
dcp.save(
200175
{"model": model.state_dict(), "optimizer": optimizer.state_dict()},
201176
checkpoint_id=checkpoint_dir,
202177
)
203178
dist_print(" Checkpoint saved (FP8 weights + optimizer state).")
204179

205-
# Load checkpoint back. Provide empty state dict containers with the
180+
# Load checkpoint back. Provide empty state dict containers with the
206181
# same structure; DCP fills them from the saved files.
207182
state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
208183
dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
@@ -225,6 +200,11 @@ def main():
225200
# authoritative FP32 values (more precise than dequantizing FP8).
226201
# All ranks must participate in gathering; only rank 0 saves.
227202
from safetensors.torch import save_file
203+
from torch.distributed.checkpoint.state_dict import (
204+
StateDictOptions,
205+
get_model_state_dict,
206+
get_optimizer_state_dict,
207+
)
228208

229209
full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
230210

@@ -238,10 +218,10 @@ def main():
238218

239219
for key, value in full_model_state.items():
240220
if key in opt_param_states and "master_param" in opt_param_states[key]:
241-
# Prefer optimizer's FP32 master weight (maintained throughout training).
221+
# Prefer optimizer's FP32 master weight.
242222
fp32_state[key] = opt_param_states[key]["master_param"].float()
243-
elif isinstance(value, QuantizedTensor):
244-
# Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off).
223+
elif isinstance(value, te.QuantizedTensor):
224+
# Fallback: dequantize FP8 → FP32.
245225
fp32_state[key] = value.dequantize().float()
246226
else:
247227
# Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
@@ -251,14 +231,7 @@ def main():
251231
save_file(fp32_state, save_path)
252232
dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}")
253233

254-
# Quick verification: all saved tensors are float32.
255-
from safetensors.torch import load_file
256-
257-
loaded = load_file(save_path)
258-
for k, v in loaded.items():
259-
assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
260-
dist_print(f" Verified: all {len(loaded)} tensors are float32.")
261-
234+
dist.barrier() # wait for rank 0 to finish file I/O
262235
dist.destroy_process_group()
263236

264237

0 commit comments

Comments
 (0)