
adds NVFP4 Fused Adam support #2797

Open
jomitchellnv wants to merge 12 commits into NVIDIA:main from jomitchellnv:jm/nvfp4-block-fused-adam

@jomitchellnv
Contributor

Description

Summary

  • Add FSDP2 all-gather hooks (fsdp_pre_all_gather / fsdp_post_all_gather) to NVFP4Tensor, enabling end-to-end
    FSDP2 training with NVFP4BlockScaling
  • Add aten.as_strided, aten.slice, and aten.record_stream dispatch handlers required by FSDP2's internal tensor
    operations
  • Remove NVFP4BlockScaling xfails from FSDP2 integration tests now that the hooks are in place

Details

Without FSDP2 hooks, FSDP2 attempts data_ptr() on the NVFP4Tensor wrapper subclass and crashes. This PR follows
the Float8BlockwiseQTensor FSDP2 hooks pattern (the closest analog since NVFP4 also stores columnwise data
transposed), with NVFP4-specific adjustments:

  • FP4 packing: Data last dim is K//2 (two FP4 values packed per uint8 byte)
  • Scale dtype: uint8 (vs float32 for Float8Blockwise)
  • Block size: 16 (NVFP4_BLOCK_SCALING_SIZE) vs 128 for Float8Blockwise
  • Scale padding: Rowwise scale dim0 padded to round_up(M, 128), columnwise scale dim1 padded to
    round_up(ceil(M/16), 4) — both unpadded before all-gather and repadded after
  • Amax tensors: _amax_rowwise and _amax_columnwise (shape (1,)) passed via metadata rather than all-gathered,
    since they're scalar and identical across ranks
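The padding rules above reduce to small integer formulas. Below is a minimal pure-Python sketch of those formulas, not Transformer Engine code. The helper names (round_up, ceil_div, packed_data_cols, rowwise_scale_rows, columnwise_scale_cols, gathered_rowwise_scale_rows) are hypothetical, and the sketch only tracks the dimensions listed above. It illustrates why the rowwise scale is unpadded before the all-gather: if you concatenate padded shards, you get more rows than round_up(full_M, 128).

```python
NVFP4_BLOCK_SCALING_SIZE = 16  # block size stated in the PR description


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return ((x + multiple - 1) // multiple) * multiple


def ceil_div(a: int, b: int) -> int:
    """Integer ceil(a / b) without going through floats."""
    return (a + b - 1) // b


def packed_data_cols(k: int) -> int:
    """Last dim of the packed FP4 data: two FP4 values per uint8 byte."""
    return k // 2


def rowwise_scale_rows(m: int) -> int:
    """Padded dim0 of the rowwise scale: round_up(M, 128)."""
    return round_up(m, 128)


def columnwise_scale_cols(m: int) -> int:
    """Padded dim1 of the columnwise scale: round_up(ceil(M / 16), 4)."""
    return round_up(ceil_div(m, NVFP4_BLOCK_SCALING_SIZE), 4)


def gathered_rowwise_scale_rows(shard_m: int, world_size: int, unpad: bool) -> int:
    """dim0 of the rowwise scale after concatenating every rank's shard along dim0."""
    per_rank = shard_m if unpad else rowwise_scale_rows(shard_m)
    return per_rank * world_size


shard_m, world_size = 64, 2
full_m = shard_m * world_size
# Unpadded shards give 2 * 64 = 128 rows, matching round_up(128, 128) = 128 after repadding.
print(gathered_rowwise_scale_rows(shard_m, world_size, unpad=True), rowwise_scale_rows(full_m))
# Padded shards give 2 * 128 = 256 rows, which no longer matches the full tensor's layout.
print(gathered_rowwise_scale_rows(shard_m, world_size, unpad=False))
```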

Test plan

  • pytest tests/pytorch/test_nvfp4_fsdp2_hooks.py -v — 19 single-GPU unit tests covering round-trip shapes, data
    integrity, dequantize correctness, in-place update path, swizzled-scale rejection, and dispatch handlers
  • PYTHONPATH=... torchrun --nproc_per_node=2 -m pytest
    tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -k "fp8_master_weights and NVFP4" — multi-GPU
    FSDP2 + NVFP4 integration
  • PYTHONPATH=... pytest tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_tests -v — full
    FSDP2 fused_adam regression

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@greptile-apps
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR adds FSDP2 all-gather hooks (fsdp_pre_all_gather / fsdp_post_all_gather) to NVFP4Tensor, along with the required aten.as_strided, aten.slice, and aten.record_stream dispatch handlers, enabling end-to-end FSDP2 training with NVFP4BlockScaling. The implementation follows the Float8BlockwiseQTensor pattern with NVFP4-specific adjustments for FP4 packing, uint8 scale dtype, and block-size-16 alignment. All five issues raised in the prior review round (setup_class decorator, slice arg names, rowwise_data None guard, shard_M alignment assertion, and as_strided storage_offset validation) are properly addressed.

Confidence Score: 5/5

Safe to merge — the core FSDP2 hook logic is correct and all prior P0/P1 concerns are resolved; remaining findings are minor style issues.

All five previously raised issues have been properly addressed: @classmethod decorator, slice arg naming, rowwise_data assertion, shard_M alignment assertion, and as_strided storage_offset validation. The two remaining findings are P2: a missing f-string prefix that produces a slightly unhelpful error message, and an inconsistent silent fallthrough for partial aten.slice calls vs the explicit NotImplementedError used by as_strided. Neither affects correctness of the FSDP2 training path.

transformer_engine/pytorch/tensor/nvfp4_tensor.py (minor: f-string prefix on line 676, slice fallthrough on lines 707-718)

Important Files Changed

  • transformer_engine/pytorch/tensor/nvfp4_tensor.py: Adds fsdp_pre_all_gather / fsdp_post_all_gather hooks and aten dispatch handlers (as_strided, slice, record_stream); one missing f-string prefix and an inconsistent silent fallthrough on partial aten.slice.Tensor calls.
  • tests/pytorch/test_nvfp4_fsdp2_hooks.py: 19 single-GPU unit tests covering round-trip shapes, data integrity, dequantize correctness, in-place update, swizzled-scale rejection, and all three dispatch handlers; previously flagged @staticmethod/cls issue resolved to @classmethod.
  • tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py: Multi-GPU FSDP2 + FusedAdam integration tests extended to include NVFP4BlockScaling; xfails removed where appropriate and retained for known non-supported paths (no-master-weights, no-meta-device).
  • tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py: Model sharding test helper updated to include NVFP4BlockScaling as a valid recipe choice; change is additive and straightforward.

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant NVFP4Tensor
    participant AllGather

    FSDP2->>NVFP4Tensor: fsdp_pre_all_gather()
    Note over NVFP4Tensor: assert shard_M % 16 == 0<br/>assert _rowwise_data is not None
    NVFP4Tensor-->>NVFP4Tensor: unpad rowwise_scale_inv[:shard_M, :]
    NVFP4Tensor->>FSDP2: (rowwise_data, rowwise_scale_inv), metadata

    FSDP2->>AllGather: all-gather rowwise tensors
    AllGather-->>FSDP2: concatenated tensors (world_size x shard)

    FSDP2->>NVFP4Tensor: fsdp_post_all_gather(all_gather_outputs, metadata)
    Note over NVFP4Tensor: repad rowwise_scale_inv to round_up(full_M, 128)
    alt out is None (first iteration)
        NVFP4Tensor-->>NVFP4Tensor: construct new NVFP4Tensor
    else out is not None (subsequent iterations)
        NVFP4Tensor-->>NVFP4Tensor: update _rowwise_data, _rowwise_scale_inv in-place
    end
    opt columnwise_usage
        NVFP4Tensor-->>NVFP4Tensor: _create_columnwise() locally
    end
    NVFP4Tensor->>FSDP2: (NVFP4Tensor, all_gather_outputs)

    FSDP2->>NVFP4Tensor: aten.as_strided (identity check)
    FSDP2->>NVFP4Tensor: aten.slice (full-dim identity check)
    FSDP2->>NVFP4Tensor: aten.record_stream (propagate to inner tensors)
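The rowwise path in the diagram can be traced with a toy model. In this hypothetical pure-Python sketch, ToyNVFP4, pre_all_gather, all_gather and post_all_gather are illustrative stand-ins, not the TE or FSDP2 APIs. Tensors are nested lists, the columnwise path is omitted, and the amax travels in metadata because the PR describes it as identical across ranks. The sketch follows the steps in the diagram: assert alignment, unpad the scale, concatenate along dim0, repad, then either construct a new tensor or update one in place.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

BLOCK = 16  # NVFP4_BLOCK_SCALING_SIZE


def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple


@dataclass
class ToyNVFP4:
    """Hypothetical stand-in for NVFP4Tensor; each tensor is a list of rows."""
    rowwise_data: Optional[List[List[int]]]  # logical shape (M, K // 2)
    rowwise_scale_inv: List[List[int]]       # dim0 padded to round_up(M, 128)
    amax_rowwise: float                      # shape-(1,) amax, sent via metadata


def pre_all_gather(t: ToyNVFP4) -> Tuple[Tuple[list, list], Tuple[float]]:
    assert t.rowwise_data is not None, "FSDP2 requires rowwise data"
    shard_m = len(t.rowwise_data)
    assert shard_m % BLOCK == 0, "shard_M must be a multiple of 16"
    unpadded_scale = t.rowwise_scale_inv[:shard_m]  # drop the padding rows
    return (t.rowwise_data, unpadded_scale), (t.amax_rowwise,)


def all_gather(per_rank: List[Tuple[list, list]]) -> Tuple[list, list]:
    """Simulate the all-gather: concatenate each tensor along dim0 in rank order."""
    data = [row for data_shard, _ in per_rank for row in data_shard]
    scale = [row for _, scale_shard in per_rank for row in scale_shard]
    return data, scale


def post_all_gather(outputs, metadata, out: Optional[ToyNVFP4] = None) -> ToyNVFP4:
    data, scale = outputs
    (amax,) = metadata
    pad_rows = round_up(len(data), 128) - len(scale)
    scale = scale + [[0] * len(scale[0]) for _ in range(pad_rows)]  # repad
    if out is None:  # first iteration: build a new tensor
        return ToyNVFP4(data, scale, amax)
    out.rowwise_data, out.rowwise_scale_inv = data, scale  # later iterations: in place
    return out


# Two ranks, each holding a 16-row shard whose scale is padded to 128 rows.
shards = [
    ToyNVFP4(
        rowwise_data=[[rank, rank] for _ in range(16)],
        rowwise_scale_inv=[[rank + 1] for _ in range(128)],
        amax_rowwise=1.0,
    )
    for rank in range(2)
]
sent = [pre_all_gather(s) for s in shards]
gathered = all_gather([tensors for tensors, _ in sent])
full = post_all_gather(gathered, sent[0][1])  # amax metadata is the same on every rank
```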

Reviews (8): Last reviewed commit: "X"

Comment on lines +68 to +71
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @pytest.mark.parametrize("shape", _test_shapes)
Contributor


P1 setup_class will raise TypeError at runtime

setup_class is decorated with @staticmethod, but its signature accepts a cls parameter. Pytest invokes class-level setup by calling TestNVFP4FSDP2Hooks.setup_class() with no arguments. Because this is a static method, Python does not inject cls, so the call has 0 arguments against a 1-argument signature — raising TypeError: setup_class() missing 1 required positional argument: 'cls' before any test in the class executes.

The seed initialisation would silently never run. Change the decorator to @classmethod:

Suggested change
-        torch.manual_seed(42)
-        torch.cuda.manual_seed(42)
-    @pytest.mark.parametrize("shape", _test_shapes)
+    @classmethod
+    def setup_class(cls) -> None:
+        torch.manual_seed(42)
+        torch.cuda.manual_seed(42)

Contributor Author


Good catch, fixed. Changed to @classmethod so pytest actually invokes the seed initialization.

Comment on lines +726 to +732
if func == aten.slice.Tensor:
    tensor = args[0]
    dim = args[1]
    start = args[2]
    length = args[3]
    if start == 0 and length == tensor.size(dim):
        return NVFP4Tensor.make_like(tensor)
Contributor


P2 slice.Tensor handler uses wrong variable name for end and skips step check

The fourth positional argument of aten.slice.Tensor is end, not length. The variable name is misleading, but more importantly the handler unconditionally indexes args[3] without checking whether a fifth step argument is provided and equals 1. A non-unit step with start=0, end=size would still pass the existing check and incorrectly return an unmodified NVFP4Tensor, silently skipping the stride.

Suggested change
-if func == aten.slice.Tensor:
-    tensor = args[0]
-    dim = args[1]
-    start = args[2]
-    length = args[3]
-    if start == 0 and length == tensor.size(dim):
-        return NVFP4Tensor.make_like(tensor)
+if func == aten.slice.Tensor:
+    tensor = args[0]
+    dim = args[1] if len(args) > 1 else 0
+    start = args[2] if len(args) > 2 else None
+    end = args[3] if len(args) > 3 else None
+    step = args[4] if len(args) > 4 else 1
+    if step == 1 and (start is None or start == 0) and (end is None or end >= tensor.size(dim)):
+        return NVFP4Tensor.make_like(tensor)

Contributor Author


Fixed. Renamed length → end, added defensive len(args) checks for the optional positional args, and added a step == 1 guard. It also handles the case where start/end are None (the ATen defaults).
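The fixed argument handling can be tested without a GPU. Below is a hypothetical pure-Python version: `sizes` stands in for tensor.size(), the first element of `args` stands in for the tensor, and the defaults follow the aten::slice.Tensor schema (dim=0, start=None, end=None, step=1). An open-ended Python slice may reach ATen with a very large `end` value, which is why the check uses >= instead of ==.

```python
import sys
from typing import Optional, Sequence, Tuple


def parse_slice_args(args: Sequence) -> Tuple[int, Optional[int], Optional[int], int]:
    """Unpack aten.slice.Tensor positional args (self, dim, start, end, step),
    falling back to the schema defaults when trailing args are omitted."""
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else None
    end = args[3] if len(args) > 3 else None
    step = args[4] if len(args) > 4 else 1
    return dim, start, end, step


def is_full_slice(sizes: Sequence[int], args: Sequence) -> bool:
    """True only when the slice keeps every element along `dim`, in order."""
    dim, start, end, step = parse_slice_args(args)
    return (
        step == 1
        and (start is None or start == 0)
        and (end is None or end >= sizes[dim])
    )


sizes = (512, 256)
print(is_full_slice(sizes, ("t", 0, 0, sys.maxsize)))  # open-ended slice: identity
print(is_full_slice(sizes, ("t", 0, 0, 512, 2)))       # step 2 keeps only half the rows
```

The step case is the bug flagged above: with only the start/end check, `("t", 0, 0, 512, 2)` would pass as an identity slice.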

Comment on lines +592 to +597
# Always send both orientations (GEMM needs both for fwd/bwd)
rowwise_usage = True
sharded_tensors = (rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
    sharded_tensors += (columnwise_data, columnwise_scale_inv)
Contributor


P2 Hardcoded rowwise_usage = True can silently pass None to FSDP2

rowwise_usage is hardcoded to True without checking whether _rowwise_data is actually populated. If an NVFP4Tensor is created with only columnwise data (rowwise=False), rowwise_data will be None but will still be included in sharded_tensors. FSDP2 would then try to all-gather None, which will either crash or silently corrupt data.

Consider guarding this with an explicit check:

rowwise_usage = self._rowwise_data is not None

Or, if both orientations are always required for FSDP2 training, add an assertion to catch misconfigured tensors early:

rowwise_usage = True
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)

Contributor Author


Added an assertion to fail early if _rowwise_data is None:

assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)

Comment on lines +591 to +592

# Always send both orientations (GEMM needs both for fwd/bwd)
Collaborator

@vthumbe1503 vthumbe1503 Mar 24, 2026


@jomitchellnv we shouldn't need both orientations. The forward pass needs only rowwise data, and the backward pass needs only columnwise data.

Can you refer to this PR for the optimizations: #2789
Right now the columnwise all-gather implementation transposes, all-gathers in pre_all_gather, and then transposes again after post_all_gather, which is expensive.
You can improve the columnwise all-gather in one of two ways:

  1. Always all-gather rowwise data, similar to the PR above. If columnwise usage is needed, you can just transpose it in post_all_gather.
  2. All-gather only columnwise data if columnwise usage is set, and then interleave the stacked columnwise data.

Option 1 is fine to have for now.
I would also add assertions for 2D scaling, similar to the FP8 block-scaling implementation.
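The two options differ in where elements land after the gather, and a small integer matrix makes that visible. This pure-Python sketch uses hypothetical helpers (not TE code) and ignores quantization entirely, tracking only element positions. A dim0 all-gather of transposed shards stacks (K, M/world_size) blocks vertically, so option 2 has to interleave them back along dim1. Option 1 needs just one transpose after the gather.

```python
from typing import List

Matrix = List[List[int]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def shard_rows(m: Matrix, world_size: int) -> List[Matrix]:
    """Split along dim0, the way FSDP2 shards a weight across ranks."""
    step = len(m) // world_size
    return [m[i * step:(i + 1) * step] for i in range(world_size)]


def concat_dim0(shards: List[Matrix]) -> Matrix:
    """What a dim0 all-gather produces: shards stacked in rank order."""
    return [row for shard in shards for row in shard]


def interleave_stacked(stacked: Matrix, world_size: int) -> Matrix:
    """Undo dim0 stacking of transposed shards: split into world_size blocks
    and concatenate the blocks along dim1 instead."""
    blocks = shard_rows(stacked, world_size)
    return [[x for block in blocks for x in block[r]] for r in range(len(blocks[0]))]


weight = [[r * 10 + c for c in range(3)] for r in range(4)]  # (M, K) = (4, 3)
world_size = 2
shards = shard_rows(weight, world_size)

# Option 1: all-gather rowwise shards, then transpose once.
option1 = transpose(concat_dim0(shards))
# Option 2: all-gather transposed shards; the result is stacked with shape (world_size * K, M / world_size).
stacked = concat_dim0([transpose(s) for s in shards])
option2 = interleave_stacked(stacked, world_size)
```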

Contributor Author


Thanks for the pointer to #2789 — that makes sense. I'll adopt option 1: only all-gather rowwise data +
scales, and derive columnwise locally in post_all_gather if columnwise_usage is set. Will update.

On the 2D scaling assertion: NVFP4 doesn't have _is_2D_scaled like Float8Blockwise, but I can add a guard on with_2d_quantization. Could you clarify what layout constraint you're thinking of? The NVFP4 rowwise scale has M in dim0 (round_up(M, 128), ...) regardless of that flag, so it should be compatible with dim0 all-gather either way. Want to make sure I add the right check.

Collaborator


Yes, exactly: with_2d_quantization should be asserted to be True. If with_2d_quantization is False, then columnwise data is not derivable from rowwise data/scales (although this is not a common use case).

Contributor Author


Added an assertion in fsdp_pre_all_gather that with_2d_quantization == True when columnwise data is included in the all-gather. This ensures we fail early rather than silently passing None columnwise tensors.
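One way to see why 2D quantization matters is to compare block absolute maxima, which are what block scales are derived from. The sketch below is a simplified illustration of the reviewer's point, not TE's quantization kernel. It ignores FP8 scale encoding and the global amax, and it uses a toy block size of 2 instead of 16 so the matrices stay small. With 2D tiles, the scale grid for the transposed matrix is just the transpose of the original grid. With 1D blocks along the last dimension, the same element falls into blocks with different amaxes in the two orientations.

```python
from typing import List

Matrix = List[List[int]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def block_amax_1d(m: Matrix, block: int) -> Matrix:
    """amax over consecutive `block`-wide chunks of each row (1D scaling)."""
    return [
        [max(abs(v) for v in row[j:j + block]) for j in range(0, len(row), block)]
        for row in m
    ]


def block_amax_2d(m: Matrix, block: int) -> Matrix:
    """amax over `block` x `block` tiles (2D scaling)."""
    return [
        [
            max(abs(m[r][c]) for r in range(i, i + block) for c in range(j, j + block))
            for j in range(0, len(m[0]), block)
        ]
        for i in range(0, len(m), block)
    ]


x = [[1, 2, 3, 4], [5, 6, 7, 8], [-9, 1, 0, 2], [3, 4, 5, -6]]
B = 2  # toy block size; NVFP4 uses 16

# 2D: the tile grid of x^T is the transpose of the tile grid of x.
print(block_amax_2d(transpose(x), B) == transpose(block_amax_2d(x, B)))
# 1D: element (0, 0) sits in a block with amax 2 rowwise but amax 5 columnwise.
print(block_amax_1d(x, B)[0][0], block_amax_1d(transpose(x), B)[0][0])
```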

jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch 2 times, most recently from 178c7c3 to 238f2df on March 24, 2026 21:50
Comment on lines +666 to +671
if current_m_blocks < target_m_blocks:
    columnwise_scale_inv = torch.nn.functional.pad(
        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
    )
elif current_m_blocks > target_m_blocks:
    columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
Contributor


P1 Silent data corruption when shard_M is not a multiple of NVFP4_BLOCK_SCALING_SIZE

The elif current_m_blocks > target_m_blocks trim branch silently discards valid scale data rather than surfacing an alignment error.

Here is why this happens: fsdp_pre_all_gather computes m_blocks = ceil(shard_M / 16) per shard. After all-gather, the concatenated total is world_size * ceil(shard_M / 16). The target is round_up(ceil(full_M / 16), 4). When shard_M % 16 != 0, world_size * ceil(shard_M / 16) can exceed ceil(full_M / 16), and once it also exceeds the padded target, the trim condition fires.

Concrete example: full_M = 136, world_size = 8, shard_M = 17.

  • Each rank: ceil(17/16) = 2 m-blocks
  • All-gathered: 16 m-blocks
  • Target: round_up(ceil(136/16), 4) = round_up(9, 4) = 12
  • elif trims from 16 → 12, discarding 4 real m-block scale columns

Beyond the trim, this represents a deeper problem: each rank's m-block 0 covers its local rows 0–15, not the same global rows. When shard_M % 16 != 0, the per-rank m-blocks do not align with global m-blocks, and the all-gathered scale tensor is fundamentally scrambled before the trim even runs.

The fix is to assert the constraint at the top of fsdp_pre_all_gather rather than silently masking the symptom:

assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, (
    f"FSDP2 requires shard_M ({shard_M}) to be a multiple of "
    f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). "
    "Ensure the weight's row count is divisible by world_size * 16."
)

And the elif trim should then become an unreachable assertion (or be removed), so any unexpected mismatch is caught loudly:

Suggested change
-if current_m_blocks < target_m_blocks:
-    columnwise_scale_inv = torch.nn.functional.pad(
-        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
-    )
-elif current_m_blocks > target_m_blocks:
-    columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
+if current_m_blocks < target_m_blocks:
+    columnwise_scale_inv = torch.nn.functional.pad(
+        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
+    )
+else:
+    assert current_m_blocks == target_m_blocks, (
+        f"Unexpected m_block count after all-gather: "
+        f"got {current_m_blocks}, expected {target_m_blocks}"
+    )

Contributor Author


Fixed. Added an assertion in fsdp_pre_all_gather that shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, which
guarantees m-blocks align across ranks after all-gather. Converted the elif trim to an unreachable assertion —
if the shard alignment holds, current_m_blocks can never exceed target_m_blocks.
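Here is the m-block arithmetic from this thread as a small pure-Python sketch. The helper names are hypothetical. m_block_counts reproduces the full_M = 136, world_size = 8 example. block_start_rows lists the global row where each gathered m-block begins, which shows the misalignment the review describes when shard_M is not a multiple of 16. When shard_M is aligned, the gathered count can still fall short of the padded target, and that is the case the remaining pad branch handles.

```python
BLOCK = 16  # NVFP4_BLOCK_SCALING_SIZE


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def round_up(x: int, multiple: int) -> int:
    return ceil_div(x, multiple) * multiple


def m_block_counts(shard_m: int, world_size: int) -> tuple:
    """(m-blocks after gathering every rank's columnwise scale, padded target)."""
    gathered = world_size * ceil_div(shard_m, BLOCK)
    target = round_up(ceil_div(shard_m * world_size, BLOCK), 4)
    return gathered, target


def block_start_rows(shard_m: int, world_size: int) -> list:
    """Global row at which each gathered m-block starts, in rank order."""
    return [
        rank * shard_m + b * BLOCK
        for rank in range(world_size)
        for b in range(ceil_div(shard_m, BLOCK))
    ]


print(m_block_counts(17, 8))    # the review's example: more gathered blocks than the target
print(block_start_rows(17, 2))  # rank 1's first block starts at global row 17, not 16 or 32
print(m_block_counts(16, 2))    # aligned, but still padded up to a multiple of 4
```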

Comment on lines +727 to +736
if func == aten.as_strided.default:
    tensor = args[0]
    shape = args[1]
    strides = args[2]
    if (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)
        and tuple(shape) == tuple(tensor.size())
    ):
        return NVFP4Tensor.make_like(tensor)
Contributor


P1 as_strided handler silently falls through for non-standard calls

When the shape/stride check does not match (e.g., FSDP2 applies as_strided with a non-unit stride, a shape mismatch, or a non-zero storage offset), the handler does not return anything and falls through to super().__torch_dispatch__(). That super call is unlikely to know how to handle an NVFP4Tensor, so the result will be wrong or raise an obscure error.

The storage_offset argument (args[3] when present) is also never inspected — if FSDP2 ever supplies a non-zero offset, the identity check still fires and silently returns the unmodified tensor, ignoring the offset.

Since the intent is to handle only the no-op identity case and let everything else fall through deliberately, consider making that explicit with a comment or raising NotImplementedError for the non-identity case so failures are surfaced early rather than producing undefined behaviour from the super dispatch.

Contributor Author


Fixed. as_strided now validates storage_offset == 0 and raises NotImplementedError for any non-identity call
instead of silently falling through.
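Below is a minimal, hypothetical version of the hardened check. It works on plain shape tuples, so it runs without TE or a GPU, and `size` stands in for tensor.size(). In the aten::as_strided schema, storage_offset is optional and None means "keep the current offset." This sketch treats None like 0, which is an assumption about how the real handler should behave. Every other call raises NotImplementedError instead of falling through.

```python
from typing import Optional, Sequence


def check_as_strided_identity(
    size: Sequence[int],
    shape: Sequence[int],
    strides: Sequence[int],
    storage_offset: Optional[int] = None,
) -> bool:
    """Accept only a contiguous 2D view identical to the current tensor."""
    is_identity = (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)
        and tuple(shape) == tuple(size)
        and storage_offset in (None, 0)  # assumption: None behaves like offset 0 here
    )
    if not is_identity:
        raise NotImplementedError(
            f"as_strided(shape={tuple(shape)}, strides={tuple(strides)}, "
            f"storage_offset={storage_offset}) is not supported for this tensor"
        )
    return True


print(check_as_strided_identity((512, 128), (512, 128), (128, 1), 0))
```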

jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch from 96cf20f to 0222882 on April 6, 2026 18:28
@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@pytest.mark.parametrize("layer_type", ["LayerNormLinear", "TransformerLayer"])
def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
if recipe_name in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
if recipe_name == "Float8BlockScaling" and fp8_init:
Collaborator


Does fp8 block scaling still fail? You had fixed this in #2753, right?

Contributor Author


removed it!

"""Verify that GEMM-swizzled scales raise NotImplementedError."""
shape = (512, 256)
qt = _make_nvfp4_tensor(shape)
qt._with_gemm_swizzled_scales = True
Collaborator


The right way to do this is to set quantizer.optimize_for_gemm = True; the quantized tensor then automatically gets its _with_gemm_swizzled_scales attribute set to True, and the scales are actually swizzled.

Here we are just setting the attribute to be True without actually swizzling the scales.

Contributor Author


I replaced the direct qt._with_gemm_swizzled_scales = True hack with the proper API: creating an
NVFP4Quantizer with quantizer.optimize_for_gemm = True and quantizing through it. Since NVFP4's C++ path doesn't wire up optimize_for_gemm yet (hardcoded false with TODO at quantizer.cpp:1731 and :1993), the test gracefully skips with pytest.skip() explaining why, and will auto-activate once the C++ TODO is addressed.

        assert torch.equal(result._amax_columnwise, orig_amax_col)

    @pytest.mark.parametrize("shape", _test_shapes)
    def test_round_trip_dequantize(self, shape: Tuple[int, int]):
Collaborator


There is a lot of code duplication with the previous round trip test without the dequantize. Could we merge this test into that one?

Contributor Author


Merged test_round_trip_dequantize into test_round_trip_data_integrity. The dequantize check
(orig_deq, result.dequantize(), assert_close) was added at the end of the data integrity test. The separate
test_round_trip_dequantize method was removed entirely. This eliminates the duplicated setup (tensor creation,
pre_all_gather, simulate, post_all_gather).

@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator

/te-ci L1 pytorch

Jonathan Mitchell and others added 10 commits April 9, 2026 15:23
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
  1. nvfp4_tensor.py — as_strided hardening + with_2d_quantization assertion
  2. test_nvfp4_fsdp2_hooks.py — remove unused tex import
  3. run_fsdp2_model.py — remove stale Float8BlockScaling xfail

  PR comments to leave:
  - greptile P1 (as_strided): Fixed — validates storage_offset == 0 and raises NotImplementedError for
  non-identity calls.
  - vthumbe1503 (2D quantization): Added assertion that with_2d_quantization == True when columnwise data is in
  the all-gather.
  - greptile P2 (unused import): Removed.
  - vthumbe1503 (Float8BlockScaling xfail): Correct, fixed in NVIDIA#2753 — removed the xfail

Signed-off-by: Jonathan Mitchell <jomitchell@r6515-0097.ipp1a1.colossus.nvidia.com>
  1. Enable with_2d_quantization in test helper to satisfy
     fsdp_pre_all_gather assertion added in 339ec09
  2. Merge test_round_trip_dequantize into
     test_round_trip_data_integrity to reduce duplication
  3. Use quantizer.optimize_for_gemm API instead of directly
     setting _with_gemm_swizzled_scales in swizzled scales
     rejection test (skips until C++ wires up NVFP4 support)

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
  Derive columnwise data locally via _create_columnwise()
  in post_all_gather instead of all-gathering both
  orientations. This halves the all-gather communication
  volume, matching the pattern used by
  Float8BlockwiseQTensor. Updates tests accordingly.

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch from 11c5f08 to 0471224 on April 9, 2026 22:23
@vthumbe1503
Collaborator

/te-ci L1 pytorch

  1. _check_fp8_fsdp2_allgather: zeros_like(local_tensor) returns an
     NVFP4Tensor (not a plain tensor) because NVFP4Tensor.__torch_dispatch__
     intercepts empty_like/zero_. This causes a dtype mismatch in
     dist.all_gather. Fix by dequantizing first, then creating plain-tensor
     output buffers.
  2. Re-add Float8BlockScaling + fp8_init xfail removed in d4b7337 — scale
     inverse padding during FSDP2 all-gather slice ops is still unhandled

Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
X
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch from aa64e7a to 7a81c83 on April 10, 2026 23:15
@jomitchellnv
Contributor Author

/te-ci L1 pytorch
