[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP (#2855)
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR introduces scaled + clamped SwiGLU support in te.ops and a fused MXFP8 grouped-MLP path. Confidence Score: 4/5 — mostly safe to merge; the main concern is the missing limit guard in the fusion-eligibility check (see table below).
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/ops/basic/swiglu.py | Adds ScaledClampedQGeGLU subclassing _ScaledGLU, delegating to ClampedSwiGLU internals; _clamped is created without glu_interleave_size (intentional per PR design) but the private attribute is exposed for fusion eligibility checks in _common.py. |
| transformer_engine/pytorch/ops/_common.py | Adds cuDNN FE version gate and alpha guard for ScaledClampedQGeGLU fusion eligibility; the matching limit guard is absent, meaning a non-default limit would silently run the cuDNN kernel with incorrect clamp boundaries. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Extends type annotation and stores _cudnn_act_func at construction time; passes "geglu" for ScaledClampedQGeGLU and "swiglu" for ScaledSwiGLU to the kernel call. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Mirrors forward changes: stores _cudnn_dact_func ("dgeglu" or "dswiglu") and substitutes it into the kernel call. |
| tests/pytorch/test_fusible_ops.py | Adds unit test for ScaledClampedQGeGLU and parametrizes test_grouped_mlp with activation; includes a pytest.skip for the confirmed nvfp4+geglu+bias numerics defect with a TODO but no issue reference. |
| docs/api/pytorch.rst | Adds autoapiclass entry for ScaledClampedQGeGLU, correctly placed alongside sibling GLU ops. |
| transformer_engine/pytorch/ops/basic/__init__.py | Adds ScaledClampedQGeGLU to the basic package re-export, making it accessible via te_ops.ScaledClampedQGeGLU. |
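The delegation pattern the table describes for swiglu.py can be sketched as follows. This is an illustrative stand-in, not TE's real classes: the class bodies, attribute names other than `_clamped`, and default values are assumptions based on the summary above.

```python
# Illustrative sketch of the delegation pattern: ScaledClampedQGeGLU wraps
# an internal ClampedSwiGLU and exposes the private attribute that the
# fusion-eligibility checks in _common.py inspect.
class ClampedSwiGLU:  # stand-in for TE's real op
    def __init__(self, alpha, limit):
        self.alpha = alpha
        self.limit = limit


class ScaledClampedQGeGLU:  # stand-in for TE's real op
    def __init__(self, alpha=1.702, limit=7.0):
        # Intentionally created without glu_interleave_size (per the PR
        # design); fusion code reads this private attribute directly.
        self._clamped = ClampedSwiGLU(alpha=alpha, limit=limit)


op = ScaledClampedQGeGLU()
```

The design trade-off called out in the summary is that exposing `_clamped` couples the fusion pass to a private attribute of the basic op.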
Flowchart
    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[fuse_grouped_mlp_ops\nscanning window] --> B{window matches\nGroupedLinear + GLU + GroupedLinear?}
        B -- No --> C[advance window, no fusion]
        B -- Yes --> D{GLU type?}
        D -- ScaledSwiGLU --> E[matches_pattern = True]
        D -- ScaledClampedQGeGLU --> F{alpha ≈ 1.702 AND\ncuDNN FE >= 1.23.0?}
        F -- No --> C
        F -- Yes --> G{limit == 7.0?\n⚠️ NOT CHECKED}
        G -- limit ≠ 7.0\nsilently proceeds --> E
        G -- limit == 7.0 --> E
        E --> H{dims / interleave\n/ bias checks pass?}
        H -- No --> C
        H -- Yes --> I[Create FusedOp\nact_func = 'geglu' or 'swiglu']
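The eligibility logic in the flowchart can be sketched as a single predicate. This is a hedged illustration — the function and argument names are mine, not TE's actual code — and it includes the limit guard that the review flags as missing from the PR:

```python
ALPHA_GELU = 1.702  # sigmoid constant so that x * sigmoid(1.702 * x) ~ GELU(x)


def geglu_fusion_eligible(alpha, limit, cudnn_fe_version):
    """Illustrative eligibility check for fusing ScaledClampedQGeGLU."""
    if cudnn_fe_version < (1, 23, 0):  # cuDNN FE version gate (present in PR)
        return False
    if abs(alpha - ALPHA_GELU) > 1e-6:  # alpha guard (present in PR)
        return False
    if limit != 7.0:  # limit guard (absent in PR per the review)
        return False
    return True
```

Without the last check, a non-default limit would pass eligibility and the cuDNN kernel would clamp at the wrong boundary, as the flowchart's "silently proceeds" branch warns.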
Reviews (5): Last reviewed commit: "Merge branch 'main' into fused_mxfp8_mlp..."
tests/pytorch/test_fusible_ops.py (outdated)

    if quantization == "nvfp4" and activation == "geglu" and bias:
        # TODO: ksivaman: Need to debug numerics for this case.
        pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
Known numerics failure left unresolved
The skip guard documents a confirmed numerics defect ("Need to debug numerics for this case") for the nvfp4 + geglu + bias combination. That combination is part of the newly enabled activation="geglu" parametrize added in this very PR, so it represents a regression in the test matrix introduced here rather than a pre-existing gap. Merging with an acknowledged-but-undebugged wrong-output path may silently hide the root cause; a follow-up issue or a clear deferral note (e.g., a GitHub issue reference) would make this safer to track.
    class ScaledSwiGLU(BasicOperation):
        r"""SwiGLU with post-scaling.

    class _ScaledGLU(BasicOperation):
        """SwiGLU-family activation with per-row scales (fused grouped MLP middle op)."""
Nit: We're being a bit weird and nonstandard by treating SwiGLU as the canonical GLU. Sigmoid GLU came first, and torch.nn.GLU is a sigmoid GLU. Also, we shouldn't imply that gated activations are specific to Transformer MLPs, since they were used for RNNs long before Transformers were invented.
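For reference, the distinction the reviewer draws can be written as two scalar formulas — the original sigmoid-gated GLU (the variant torch.nn.GLU implements) versus SwiGLU, which swaps the sigmoid gate for SiLU/Swish. A minimal sketch:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def glu(gate, linear):
    # Original GLU: sigmoid(gate) * linear, as in torch.nn.GLU.
    return sigmoid(gate) * linear


def swiglu(gate, linear):
    # SwiGLU: SiLU/Swish gate, i.e. gate * sigmoid(gate) * linear.
    return gate * sigmoid(gate) * linear
```

At gate = 0 the two differ visibly: sigmoid(0) = 0.5 so GLU passes half the linear input through, while SwiGLU's gate term is zero.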
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
timmoon10
left a comment
Overall LGTM, with minor suggestions.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
This PR adds support for the grouped-MLP fused path via the cuDNN frontend when the scaled clamped SwiGLU activation is used. This activation is misnamed "geglu" in cudnn-frontend.

Type of change
Changes
- ScaledClampedSwiGLU in TE sequential (swiglu.py).

Checklist:
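For context on the "misnamed geglu" point in the description: with alpha = 1.702, x * sigmoid(alpha * x) closely approximates GELU(x), which is presumably why cudnn-frontend labels this activation "geglu". A scalar sketch of a scaled clamped SwiGLU follows; the clamp ranges and the (linear + 1) shift are assumptions based on the GPT-OSS-style formulation, and TE's actual kernel may differ:

```python
import math


def scaled_clamped_swiglu(gate, linear, scale, alpha=1.702, limit=7.0):
    # Assumed GPT-OSS-style formulation (sketch only, not TE's kernel):
    # clamp the gate from above, clamp the linear branch symmetrically,
    # apply a Swish gate with alpha = 1.702, shift, then post-scale.
    gate = min(gate, limit)
    linear = max(-limit, min(linear, limit))
    silu = gate / (1.0 + math.exp(-alpha * gate))
    return scale * silu * (linear + 1.0)
```

The clamp makes any gate input above `limit` behave identically to `limit` itself, which is exactly the boundary that a fused kernel must reproduce for the cuDNN path to match the unfused reference.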