Skip to content

[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP#2855

Merged
ksivaman merged 7 commits intoNVIDIA:mainfrom
ksivaman:fused_mxfp8_mlp_geglu
Apr 9, 2026
Merged

[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP#2855
ksivaman merged 7 commits intoNVIDIA:mainfrom
ksivaman:fused_mxfp8_mlp_geglu

Conversation

@ksivaman
Copy link
Copy Markdown
Member

@ksivaman ksivaman commented Apr 8, 2026

Description

This PR adds support for the grouped MLP fused path via cuDNN frontend when the scaled clamped swiglu activation is used. This activation is misnamed to "geglu" in cudnn-frontend.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Implement ScaledClampedSwiGLU in TE sequential.
  • Small refactor in swiglu.py.
  • Enable grouped MLP fused path when using this activation via cuDNN frontend.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 8, 2026 18:50
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 8, 2026

Greptile Summary

This PR introduces ScaledClampedQGeGLU — a clamped QGeGLU activation with per-row post-scaling — and enables the MXFP8 fused grouped MLP cuDNN path for it. The cuDNN kernel identifies this activation as "geglu" / "dgeglu" (a known misnaming on the cuDNN FE side, acknowledged in comments), and a minimum cuDNN FE version gate (>= 1.23.0) is checked before allowing fusion.

Confidence Score: 4/5

Mostly safe to merge; the missing limit boundary guard for cuDNN fusion eligibility (flagged in prior rounds) and the acknowledged numerics skip for nvfp4+geglu+bias remain unresolved.

The new ScaledClampedQGeGLU implementation, fused kernel dispatch strings, and cuDNN FE version gate are all correct. However, two concerns from prior review rounds persist without resolution: (1) the fusion eligibility check in fuse_grouped_mlp_ops validates alpha but not limit, so a ScaledClampedQGeGLU(limit=5.0, glu_interleave_size=32) would silently trigger the cuDNN kernel with the wrong clamp boundary; (2) the nvfp4 + geglu + bias test combination is skipped with an unlinked TODO. All findings from this review pass are P2 or lower.

transformer_engine/pytorch/ops/_common.py — fusion eligibility guard for ScaledClampedQGeGLU is incomplete.

Vulnerabilities

No security concerns identified. The changes are confined to numerical computation paths (activation functions and fused kernel dispatch) with no user-supplied data reaching system calls, SQL, or credential handling.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/swiglu.py Adds ScaledClampedQGeGLU subclassing _ScaledGLU, delegating to ClampedSwiGLU internals; _clamped is created without glu_interleave_size (intentional per PR design) but the private attribute is exposed for fusion eligibility checks in _common.py.
transformer_engine/pytorch/ops/_common.py Adds cuDNN FE version gate and alpha guard for ScaledClampedQGeGLU fusion eligibility; the matching limit guard is absent, meaning a non-default limit would silently run the cuDNN kernel with incorrect clamp boundaries.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Extends type annotation and stores _cudnn_act_func at construction time; passes "geglu" for ScaledClampedQGeGLU and "swiglu" for ScaledSwiGLU to the kernel call.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Mirrors forward changes: stores _cudnn_dact_func ("dgeglu" or "dswiglu") and substitutes it into the kernel call.
tests/pytorch/test_fusible_ops.py Adds unit test for ScaledClampedQGeGLU and parametrizes test_grouped_mlp with activation; includes a pytest.skip for the confirmed nvfp4+geglu+bias numerics defect with a TODO but no issue reference.
docs/api/pytorch.rst Adds autoapiclass entry for ScaledClampedQGeGLU, correctly placed alongside sibling GLU ops.
transformer_engine/pytorch/ops/basic/init.py Adds ScaledClampedQGeGLU to the basic package re-export, making it accessible via te_ops.ScaledClampedQGeGLU.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuse_grouped_mlp_ops\nscanning window] --> B{window matches\nGroupedLinear + GLU + GroupedLinear?}
    B -- No --> C[advance window, no fusion]
    B -- Yes --> D{GLU type?}
    D -- ScaledSwiGLU --> E[matches_pattern = True]
    D -- ScaledClampedQGeGLU --> F{alpha ≈ 1.702 AND\ncuDNN FE >= 1.23.0?}
    F -- No --> C
    F -- Yes --> G{limit == 7.0?\n⚠️ NOT CHECKED}
    G -- limit ≠ 7.0\nsilently proceeds --> E
    G -- limit == 7.0 --> E
    E --> H{dims / interleave\n/ bias checks pass?}
    H -- No --> C
    H -- Yes --> I[Create FusedOp\nact_func = 'geglu' or 'swiglu']
Loading

Reviews (5): Last reviewed commit: "Merge branch 'main' into fused_mxfp8_mlp..." | Re-trigger Greptile

Comment on lines +3378 to +3380
if quantization == "nvfp4" and activation == "geglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Known numerics failure left unresolved

The skip guard documents a confirmed numerics defect ("Need to debug numerics for this case") for the nvfp4 + geglu + bias combination. That combination is part of the newly enabled activation="geglu" parametrize added in this very PR, so it represents a regression in the test matrix introduced here rather than a pre-existing gap. Merging with an acknowledged-but-undebugged wrong-output path may silently hide the root cause; a follow-up issue or a clear deferral note (e.g., a GitHub issue reference) would make this safer to track.

class ScaledSwiGLU(BasicOperation):
r"""SwiGLU with post-scaling.
class _ScaledGLU(BasicOperation):
"""SwiGLU-family activation with per-row scales (fused grouped MLP middle op)."""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We're being a buit weird and nonstandard by treating SwiGLU as the canonical GLU. Sigmoid GLU was first and torch.nn.GLU is a sigmoid GLU. Also, we shouldn't imply that they are specific to Transformer MLPs, since gated activations were used for RNNs long before Transformers were invented.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 8, 2026

Tip:

Greploop — Automatically fix all review issues by running /greploops in Claude Code. It iterates: fix, push, re-review, repeat until 5/5 confidence.

Use the Greptile plugin for Claude Code to query reviews, search comments, and manage custom context directly from your terminal.

@ksivaman
Copy link
Copy Markdown
Member Author

ksivaman commented Apr 8, 2026

/te-ci pytorch L0

timmoon10
timmoon10 previously approved these changes Apr 8, 2026
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, with minor suggestions.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10
timmoon10 previously approved these changes Apr 8, 2026
@ksivaman
Copy link
Copy Markdown
Member Author

ksivaman commented Apr 8, 2026

/te-ci pytorch L0

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman merged commit 64bb9a2 into NVIDIA:main Apr 9, 2026
10 of 12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants