Conversation
/te-ci L1 pytorch
**Greptile Summary**

This PR adds FSDP2 all-gather hooks (

**Confidence Score: 5/5** — Safe to merge: the core FSDP2 hook logic is correct and all prior P0/P1 concerns are resolved; remaining findings are minor style issues. All five previously raised issues have been properly addressed: @classmethod decorator, slice arg naming, rowwise_data assertion, shard_M alignment assertion, and as_strided storage_offset validation. The two remaining findings are P2: a missing f-string prefix that produces a slightly unhelpful error message, and an inconsistent silent fallthrough for partial aten.slice calls vs the explicit NotImplementedError used by as_strided. Neither affects correctness of the FSDP2 training path.

**Important Files Changed:** transformer_engine/pytorch/tensor/nvfp4_tensor.py (minor: f-string prefix on line 676, slice fallthrough on lines 707-718)
**Sequence Diagram**

```mermaid
sequenceDiagram
    participant FSDP2
    participant NVFP4Tensor
    participant AllGather
    FSDP2->>NVFP4Tensor: fsdp_pre_all_gather()
    Note over NVFP4Tensor: assert shard_M % 16 == 0<br/>assert _rowwise_data is not None
    NVFP4Tensor-->>NVFP4Tensor: unpad rowwise_scale_inv[:shard_M, :]
    NVFP4Tensor->>FSDP2: (rowwise_data, rowwise_scale_inv), metadata
    FSDP2->>AllGather: all-gather rowwise tensors
    AllGather-->>FSDP2: concatenated tensors (world_size x shard)
    FSDP2->>NVFP4Tensor: fsdp_post_all_gather(all_gather_outputs, metadata)
    Note over NVFP4Tensor: repad rowwise_scale_inv to round_up(full_M, 128)
    alt out is None (first iteration)
        NVFP4Tensor-->>NVFP4Tensor: construct new NVFP4Tensor
    else out is not None (subsequent iterations)
        NVFP4Tensor-->>NVFP4Tensor: update _rowwise_data, _rowwise_scale_inv in-place
    end
    opt columnwise_usage
        NVFP4Tensor-->>NVFP4Tensor: _create_columnwise() locally
    end
    NVFP4Tensor->>FSDP2: (NVFP4Tensor, all_gather_outputs)
    FSDP2->>NVFP4Tensor: aten.as_strided (identity check)
    FSDP2->>NVFP4Tensor: aten.slice (full-dim identity check)
    FSDP2->>NVFP4Tensor: aten.record_stream (propagate to inner tensors)
```
Reviews (8): last reviewed commit "X"
```python
torch.manual_seed(42)
torch.cuda.manual_seed(42)

@pytest.mark.parametrize("shape", _test_shapes)
```
setup_class will TypeError at runtime
setup_class is decorated with @staticmethod, but its signature accepts a cls parameter. Pytest invokes class-level setup by calling TestNVFP4FSDP2Hooks.setup_class() with no arguments. Because this is a static method, Python does not inject cls, so the call has 0 arguments against a 1-argument signature — raising TypeError: setup_class() missing 1 required positional argument: 'cls' before any test in the class executes.
The seed initialisation would silently never run. Change the decorator to @classmethod:
Suggested change:

```python
@classmethod
def setup_class(cls) -> None:
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
```
Good catch — fixed. Changed to @classmethod so pytest actually invokes the seed initialization.
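The failure mode can be reproduced without pytest. This sketch (class names are illustrative) simulates pytest's zero-argument class-level call against both decorators:

```python
# pytest invokes class-level setup as TestClass.setup_class() with no
# arguments, so a @staticmethod with a `cls` parameter receives zero args.
class WithStaticmethod:
    @staticmethod
    def setup_class(cls) -> None:
        pass  # never reached under a zero-argument call

class WithClassmethod:
    @classmethod
    def setup_class(cls) -> None:
        pass  # cls is bound automatically

try:
    WithStaticmethod.setup_class()  # 0 args against 1 required positional
    static_raised = False
except TypeError:
    static_raised = True

WithClassmethod.setup_class()  # succeeds: Python injects cls
print(static_raised)  # True
```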
```python
if func == aten.slice.Tensor:
    tensor = args[0]
    dim = args[1]
    start = args[2]
    length = args[3]
    if start == 0 and length == tensor.size(dim):
        return NVFP4Tensor.make_like(tensor)
```
slice.Tensor handler uses wrong variable name for end and skips step check
The fourth positional argument of aten.slice.Tensor is end, not length. The variable name is misleading, but more importantly the handler unconditionally indexes args[3] without checking whether a fifth step argument is provided and equals 1. A non-unit step with start=0, end=size would still pass the existing check and incorrectly return an unmodified NVFP4Tensor, silently skipping the stride.
Suggested change:

```diff
 if func == aten.slice.Tensor:
     tensor = args[0]
-    dim = args[1]
-    start = args[2]
-    length = args[3]
-    if start == 0 and length == tensor.size(dim):
-        return NVFP4Tensor.make_like(tensor)
+    dim = args[1] if len(args) > 1 else 0
+    start = args[2] if len(args) > 2 else None
+    end = args[3] if len(args) > 3 else None
+    step = args[4] if len(args) > 4 else 1
+    if step == 1 and (start is None or start == 0) and (end is None or end >= tensor.size(dim)):
+        return NVFP4Tensor.make_like(tensor)
```
Fixed. Renamed length → end, added defensive len(args) checks for optional positional args, and added a step == 1 guard. Also handles the case where start/end are None (the ATen defaults).
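The fixed check can be exercised in isolation. This helper (hypothetical name, not from the PR) mirrors the logic: aten.slice.Tensor's positional args after the tensor are (dim, start, end, step), all optional, with step defaulting to 1 and an over-large end clamping to "to the end":

```python
# Returns True only when the slice is a full-dimension identity, i.e. the
# case where returning the tensor unchanged is safe.
def is_identity_slice(size_along_dim: int, *args) -> bool:
    dim = args[0] if len(args) > 0 else 0  # dim itself doesn't change the test
    start = args[1] if len(args) > 1 else None
    end = args[2] if len(args) > 2 else None
    step = args[3] if len(args) > 3 else 1
    return (
        step == 1
        and (start is None or start == 0)
        and (end is None or end >= size_along_dim)
    )

print(is_identity_slice(64, 0, 0, 64))         # True: exact full-dim slice
print(is_identity_slice(64, 0, 0, 2**63 - 1))  # True: over-large end clamps
print(is_identity_slice(64, 0, 0, 64, 2))      # False: non-unit step
print(is_identity_slice(64, 0, 16, 64))        # False: nonzero start
```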
```python
# Always send both orientations (GEMM needs both for fwd/bwd)
rowwise_usage = True
sharded_tensors = (rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
    sharded_tensors += (columnwise_data, columnwise_scale_inv)
```
Hardcoded `rowwise_usage = True` can silently pass None to FSDP2
rowwise_usage is hardcoded to True without checking whether _rowwise_data is actually populated. If an NVFP4Tensor is created with only columnwise data (rowwise=False), rowwise_data will be None but it will still be included in sharded_tensors. FSDP2 would then try to all-gather None, which will crash or corrupt silently.
Consider guarding this with an explicit check:
```python
rowwise_usage = self._rowwise_data is not None
```

Or, if both orientations are always required for FSDP2 training, add an assertion to catch misconfigured tensors early:
```python
rowwise_usage = True
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)
```
Added an assertion to fail early if `_rowwise_data` is None:

```python
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)
```
```python
# Always send both orientations (GEMM needs both for fwd/bwd)
```
@jomitchellnv we shouldn't need both orientations. For the forward pass only rowwise is needed, and for the backward pass only columnwise is needed.
Can you refer to this PR for the optimizations:
#2789
Right now the columnwise all-gather implementation is transposing, all-gathering in pre_allgather, and then transposing again after post_allgather, which is expensive.
You can make the columnwise all-gather better in one of two ways:
- Always all-gather rowwise data, similar to the above PR. If columnwise usage is needed, you can just transpose it in post_allgather.
- All-gather only columnwise data if columnwise usage is set, and then interleave the columnwise stacked data.

Option 1 is fine to have for now.
I would also add assertions for 2D scaling similar to the FP8 blockscaling implementation.
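Option 1 can be sketched in plain Python. This toy example (a 4x2 weight sharded on dim 0 across two ranks; all names illustrative) shows why only one orientation needs to be communicated — the columnwise copy is a local transpose of the gathered rowwise data:

```python
# Local transpose: columnwise layout is the rowwise layout with axes swapped.
def transpose(mat):
    return [list(row) for row in zip(*mat)]

# Two ranks each hold a rowwise shard of the full 4x2 weight.
shard_rank0 = [[1, 2], [3, 4]]
shard_rank1 = [[5, 6], [7, 8]]

# "All-gather" = concatenate shards along dim 0; only rowwise is transferred.
full_rowwise = shard_rank0 + shard_rank1

# Columnwise copy is derived locally in post_all_gather, not communicated,
# halving the all-gather volume versus sending both orientations.
full_columnwise = transpose(full_rowwise)

print(full_columnwise)  # [[1, 3, 5, 7], [2, 4, 6, 8]]
```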
Thanks for the pointer to #2789 — that makes sense. I'll adopt option 1: only all-gather rowwise data +
scales, and derive columnwise locally in post_all_gather if columnwise_usage is set. Will update.
On the 2D scaling assertion — NVFP4 doesn't have _is_2D_scaled like Float8Blockwise, but I can add a guard on with_2d_quantization. Could you clarify what layout constraint you're thinking of? The NVFP4 rowwise scale has M in dim0 (round_up(M, 128), ...) regardless of that flag, so it should be compatible with dim0 all-gather either way. Want to make sure I add the right check.
Yes, exactly: with_2d_quantization should be asserted to be True. If with_2d_quantization is False, then columnwise is not derivable from the rowwise data/scales (although this is not a common use case).
Added an assertion in fsdp_pre_all_gather that with_2d_quantization == True when columnwise data is included in the all-gather. This ensures we fail early rather than silently passing None columnwise tensors.
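A minimal sketch of such a guard (the attribute name comes from the discussion above; the surrounding function and class are hypothetical stand-ins, not the PR's code):

```python
# Fail early when columnwise usage is requested without 2-D quantization,
# since columnwise data then cannot be derived from rowwise data/scales.
def check_2d_quantization(quantizer, columnwise_usage: bool) -> None:
    if columnwise_usage:
        assert quantizer.with_2d_quantization, (
            "FSDP2 all-gather with columnwise usage requires "
            "with_2d_quantization=True on the quantizer"
        )

class FakeQuantizer:
    def __init__(self, with_2d: bool):
        self.with_2d_quantization = with_2d

check_2d_quantization(FakeQuantizer(True), columnwise_usage=True)  # passes
try:
    check_2d_quantization(FakeQuantizer(False), columnwise_usage=True)
    raised = False
except AssertionError:
    raised = True
print(raised)  # True
```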
Force-pushed 178c7c3 to 238f2df
```python
if current_m_blocks < target_m_blocks:
    columnwise_scale_inv = torch.nn.functional.pad(
        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
    )
elif current_m_blocks > target_m_blocks:
    columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
```
Silent data corruption when shard_M is not a multiple of NVFP4_BLOCK_SCALING_SIZE
The elif current_m_blocks > target_m_blocks trim branch silently discards valid scale data rather than surfacing an alignment error.
Here is why this happens: fsdp_pre_all_gather computes m_blocks = ceil(shard_M / 16) per shard. After all-gather, the concatenated total is world_size * ceil(shard_M / 16). The target is round_up(ceil(full_M / 16), 4). When shard_M % 16 != 0, world_size * ceil(shard_M / 16) > ceil(full_M / 16), and the trim condition fires.
Concrete example: full_M = 136, world_size = 8, shard_M = 17.
- Each rank: ceil(17/16) = 2 m-blocks
- All-gathered: 16 m-blocks
- Target: round_up(ceil(136/16), 4) = round_up(9, 4) = 12
- The elif trims from 16 → 12, discarding 4 real m-block scale columns
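The arithmetic in this example can be verified directly (round_up is a small stand-in helper, not the PR's implementation):

```python
import math

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

BLOCK = 16  # NVFP4_BLOCK_SCALING_SIZE
full_M, world_size = 136, 8
shard_M = full_M // world_size  # 17, not a multiple of 16

per_rank_blocks = math.ceil(shard_M / BLOCK)            # 2 m-blocks per rank
gathered_blocks = world_size * per_rank_blocks          # 16 after all-gather
target_blocks = round_up(math.ceil(full_M / BLOCK), 4)  # round_up(9, 4) = 12

# The trim branch fires because 16 > 12, discarding 4 real m-blocks.
print(per_rank_blocks, gathered_blocks, target_blocks)  # 2 16 12
```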
Beyond the trim, this represents a deeper problem: each rank's m-block 0 covers its local rows 0–15, not the same global rows. When shard_M % 16 != 0, the per-rank m-blocks do not align with global m-blocks, and the all-gathered scale tensor is fundamentally scrambled before the trim even runs.
The fix is to assert the constraint at the top of fsdp_pre_all_gather rather than silently masking the symptom:

```python
assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, (
    f"FSDP2 requires shard_M ({shard_M}) to be a multiple of "
    f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). "
    "Ensure the weight's row count is divisible by world_size * 16."
)
```

And the elif trim should then become an unreachable assertion (or be removed), so any unexpected mismatch is caught loudly:
Suggested change:

```diff
 if current_m_blocks < target_m_blocks:
     columnwise_scale_inv = torch.nn.functional.pad(
         columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
     )
-elif current_m_blocks > target_m_blocks:
-    columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
+else:
+    assert current_m_blocks == target_m_blocks, (
+        f"Unexpected m_block count after all-gather: "
+        f"got {current_m_blocks}, expected {target_m_blocks}"
+    )
```
Fixed. Added an assertion in fsdp_pre_all_gather that shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, which
guarantees m-blocks align across ranks after all-gather. Converted the elif trim to an unreachable assertion —
if the shard alignment holds, current_m_blocks can never exceed target_m_blocks.
```python
if func == aten.as_strided.default:
    tensor = args[0]
    shape = args[1]
    strides = args[2]
    if (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)
        and tuple(shape) == tuple(tensor.size())
    ):
        return NVFP4Tensor.make_like(tensor)
```
as_strided handler silently falls through for non-standard calls
When the shape/stride check does not match (e.g., FSDP2 applies as_strided with a non-unit stride, a shape mismatch, or a non-zero storage offset), the handler does not return anything and falls through to super().__torch_dispatch__(). That super call is unlikely to know how to handle an NVFP4Tensor, so the result will be wrong or raise an obscure error.
The storage_offset argument (args[3] when present) is also never inspected — if FSDP2 ever supplies a non-zero offset, the identity check still fires and silently returns the unmodified tensor, ignoring the offset.
Since the intent is to handle only the no-op identity case and let everything else fall through deliberately, consider making that explicit with a comment or raising NotImplementedError for the non-identity case so failures are surfaced early rather than producing undefined behaviour from the super dispatch.
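The no-op identity case being described can be captured in a standalone predicate (hypothetical helper, not the PR's code) so the storage_offset hole is easy to see:

```python
# True only for a 2-D, row-major, zero-offset as_strided whose shape matches
# the tensor — the single case where returning the tensor unchanged is safe.
def is_identity_as_strided(tensor_size, shape, strides, storage_offset=None):
    return (
        (storage_offset is None or storage_offset == 0)
        and len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)
        and tuple(shape) == tuple(tensor_size)
    )

print(is_identity_as_strided((4, 8), (4, 8), (8, 1)))     # True: no-op call
print(is_identity_as_strided((4, 8), (4, 8), (8, 1), 3))  # False: nonzero offset
print(is_identity_as_strided((4, 8), (4, 8), (16, 1)))    # False: non-contiguous
print(is_identity_as_strided((4, 8), (2, 8), (8, 1)))     # False: shape mismatch
```

Anything returning False would then raise NotImplementedError rather than falling through to the super dispatch.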
Fixed — as_strided now validates storage_offset == 0 and raises NotImplementedError for any non-identity call instead of silently falling through.
Force-pushed 96cf20f to 0222882
/te-ci L1 pytorch
```diff
 @pytest.mark.parametrize("layer_type", ["LayerNormLinear", "TransformerLayer"])
 def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
-    if recipe_name in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
+    if recipe_name == "Float8BlockScaling" and fp8_init:
```
Does FP8 block scaling still fail? You had fixed this in #2753, right?
| """Verify that GEMM-swizzled scales raise NotImplementedError.""" | ||
| shape = (512, 256) | ||
| qt = _make_nvfp4_tensor(shape) | ||
| qt._with_gemm_swizzled_scales = True |
The right way to do this is to set quantizer.optimize_for_gemm = True; the quantized tensor then automatically gets its _with_gemm_swizzled_scales attribute set to True, with the scales actually being swizzled.
Here we are just setting the attribute to True without actually swizzling the scales.
I replaced the direct qt._with_gemm_swizzled_scales = True hack with the proper API: creating an
NVFP4Quantizer with quantizer.optimize_for_gemm = True and quantizing through it. Since NVFP4's C++ path doesn't wire up optimize_for_gemm yet (hardcoded false with TODO at quantizer.cpp:1731 and :1993), the test gracefully skips with pytest.skip() explaining why, and will auto-activate once the C++ TODO is addressed.
```python
assert torch.equal(result._amax_columnwise, orig_amax_col)

@pytest.mark.parametrize("shape", _test_shapes)
def test_round_trip_dequantize(self, shape: Tuple[int, int]):
```
There is a lot of code duplication with the previous round trip test without the dequantize. Could we merge this test into that one?
Merged test_round_trip_dequantize into test_round_trip_data_integrity. The dequantize check (orig_deq, result.dequantize(), assert_close) was added at the end of the data integrity test. The separate test_round_trip_dequantize method was removed entirely. This eliminates the duplicated setup (tensor creation, pre_all_gather, simulate, post_all_gather).
/te-ci L1 pytorch

/te-ci L1 pytorch

1 similar comment

/te-ci L1 pytorch
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
for more information, see https://pre-commit.ci
1. nvfp4_tensor.py — as_strided hardening + with_2d_quantization assertion
2. test_nvfp4_fsdp2_hooks.py — remove unused tex import
3. run_fsdp2_model.py — remove stale Float8BlockScaling xfail

PR comments to leave:
- greptile P1 (as_strided): Fixed — validates storage_offset == 0 and raises NotImplementedError for non-identity calls.
- vthumbe1503 (2D quantization): Added assertion that with_2d_quantization == True when columnwise data is in the all-gather.
- greptile P2 (unused import): Removed.
- vthumbe1503 (Float8BlockScaling xfail): Correct, fixed in NVIDIA#2753 — removed the xfail.

Signed-off-by: Jonathan Mitchell <jomitchell@r6515-0097.ipp1a1.colossus.nvidia.com>
1. Enable with_2d_quantization in test helper to satisfy the fsdp_pre_all_gather assertion added in 339ec09
2. Merge test_round_trip_dequantize into test_round_trip_data_integrity to reduce duplication
3. Use the quantizer.optimize_for_gemm API instead of directly setting _with_gemm_swizzled_scales in the swizzled-scales rejection test (skips until C++ wires up NVFP4 support)
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Derive columnwise data locally via _create_columnwise() in post_all_gather instead of all-gathering both orientations. This halves the all-gather communication volume, matching the pattern used by Float8BlockwiseQTensor. Updates tests accordingly. Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Force-pushed 11c5f08 to 0471224
/te-ci L1 pytorch
1. _check_fp8_fsdp2_allgather: zeros_like(local_tensor) returns an
NVFP4Tensor (not a plain tensor) because NVFP4Tensor.__torch_dispatch__
intercepts empty_like/zero_. This causes a dtype mismatch in
dist.all_gather. Fix by dequantizing first, then creating plain-tensor
output buffers.
2. Re-add Float8BlockScaling + fp8_init xfail removed in d4b7337 — scale
inverse padding during FSDP2 all-gather slice ops is still unhandled
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
Force-pushed aa64e7a to 7a81c83
/te-ci L1 pytorch
Description
Summary
FSDP2 training with NVFP4BlockScaling
operations
Details
Without FSDP2 hooks, FSDP2 attempts data_ptr() on the NVFP4Tensor wrapper subclass and crashes. This PR follows
the Float8BlockwiseQTensor FSDP2 hooks pattern (the closest analog since NVFP4 also stores columnwise data
transposed), with NVFP4-specific adjustments:
- round_up(ceil(M/16), 4) — both unpadded before all-gather and repadded after
- since they're scalar and identical across ranks
Test plan
- integrity, dequantize correctness, in-place update path, swizzled-scale rejection, and dispatch handlers
- tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -k "fp8_master_weights and NVFP4" — multi-GPU FSDP2 + NVFP4 integration
- FSDP2 fused_adam regression
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: