[JAX] MXFP8 Grouped Quant+GEMM #2763
Status: Open. jberchtold-nvidia wants to merge 73 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-mxfp8.
Commits (73):
28e5f53 Refactor to group_sizes per tensor (jberchtold-nvidia)
4a57485 Support first_dims and last_dims instead of a single group_sizes per (jberchtold-nvidia)
345d940 Refactor GMM FFIs to store static attrs as structs (jberchtold-nvidia)
ed9c8e4 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
ed0deaf Cleanup C++ v2 FFI (jberchtold-nvidia)
88bb7da Fix int64 workspace usage (jberchtold-nvidia)
60312c8 Address greptile comments (jberchtold-nvidia)
025f598 Refactor wgrad-specific checks to be generic for GMM in gemm.py (jberchtold-nvidia)
089e530 Refactor XLA FFI struct setup (jberchtold-nvidia)
8ad2294 Fix edge case in TE v1 GMM (jberchtold-nvidia)
bac092d Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
4ff5d1d [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
0cb7289 Fix issues on Hopper (jberchtold-nvidia)
37d300a Merge remote-trackint commit --amend -sg branch 'github-upstream/main… (jberchtold-nvidia)
cc236ad Refactor (jberchtold-nvidia)
1d1fec9 MXFP8 grouped quantize V2 (jberchtold-nvidia)
269a518 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
2b84dfd [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b2b3216 MXFP8 quantization working (jberchtold-nvidia)
47218b3 Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
611526f mxfp8 grouped gemm (jberchtold-nvidia)
c97b0b7 te_permutation NaN issue fix (jberchtold-nvidia)
0b9a763 Support GroupedDense quantization checkpointing (jberchtold-nvidia)
6b64cea Temporary commit to assert if V1 grouped quantize is used (jberchtold-nvidia)
2dd69d4 Fix scale shapes for MXFP8 (jberchtold-nvidia)
204b326 Fix MXFP8 scale sharding when FSDP+EP on same axis (jberchtold-nvidia)
5fb585f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
2902eb2 Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
bee7f3b Address comments (jberchtold-nvidia)
d9b9c44 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
ef0d498 Merge branch 'main' into jberchtold/gmm-refactor (jberchtold-nvidia)
9438478 Lint (jberchtold-nvidia)
09dfd9c Fixes for Hopper (jberchtold-nvidia)
e25538e Address review comments (jberchtold-nvidia)
78674e9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
d5229e2 Merge branch 'main' into jberchtold/gmm-refactor (jberchtold-nvidia)
b78435a Merge jberchtold/gmm-refactor into jberchtold/gmm-mxfp8 (jberchtold-nvidia)
06ebb44 Fixes (jberchtold-nvidia)
a3f8042 wip (jberchtold-nvidia)
7e99314 Fix grouped colwise dequantize for transposed ragged tensors and V1 p… (jberchtold-nvidia)
68bcbfc 2D shape fixes for flattened 1D shape from grouped quantization (jberchtold-nvidia)
81cb189 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
75995e4 Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
d7b04cc Fix swizzling (jberchtold-nvidia)
064f314 Remove pre-swizzling from non-grouped quantization (jberchtold-nvidia)
5edef90 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
c55969c Use avg m,n,k heuristics for cuBLASLt Grouped GEMM (jberchtold-nvidia)
427d5b6 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
167c343 Update transformer_engine/jax/cpp_extensions/gemm.py (jberchtold-nvidia)
ae97af1 Use avg m,n,k heuristics for cuBLASLt Grouped GEMM (jberchtold-nvidia)
f1c7582 Fix rhs transpose flag (jberchtold-nvidia)
b3ea76a Fix rhs transpose flag (jberchtold-nvidia)
6387b8a Address comments (jberchtold-nvidia)
7febb9b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7e94996 Merge branch 'main' into jberchtold/gmm-avg-mnk (jberchtold-nvidia)
fbebfea Merge branch 'jberchtold/gmm-avg-mnk' into jberchtold/gmm-mxfp8 (jberchtold-nvidia)
2e1a9f5 Fix merge issue (jberchtold-nvidia)
7769c51 Remove unnecessary changes (jberchtold-nvidia)
6fbe4ca Cleanup tests (jberchtold-nvidia)
7cafd35 Fix tests (jberchtold-nvidia)
49e7a60 Use GroupedTensorWrapper in grouped quantization (jberchtold-nvidia)
644520b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
087bd2e Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
56fce55 Fix merge conflict issue (jberchtold-nvidia)
9ea2482 Address comments (jberchtold-nvidia)
2af15e5 Clean up grouped_gemm function (jberchtold-nvidia)
6535819 Test fixes (jberchtold-nvidia)
bf6377b Fix old var names in V1 python codepath (jberchtold-nvidia)
16a4bf7 Fix lint (jberchtold-nvidia)
513108a Merge remote-tracking branch 'github-upstream/main' into jberchtold/g… (jberchtold-nvidia)
9ced1c5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
4da6b80 Fix Hopper V1 FP8 grouped GEMMs (jberchtold-nvidia)
5cf90d5 Merge branch 'main' into jberchtold/gmm-mxfp8 (jberchtold-nvidia)
Review comment: I have an idea for this in case n_groups ever gets large: have 32 threads compute cumsums over blocks of the input, then use warp shuffles (shfl) to reduce the per-block local sums into a single sum.
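The two-level scan proposed above can be illustrated on the host. This is a minimal NumPy sketch, not the actual CUDA kernel: each 32-entry chunk stands in for one warp computing a local exclusive cumsum, and the pass over chunk totals stands in for the warp-shuffle step that combines the local sums. The function name `two_level_cumsum` is hypothetical.

```python
import numpy as np

WARP_SIZE = 32  # entries handled per "warp" in this simulation

def two_level_cumsum(group_sizes: np.ndarray) -> np.ndarray:
    """Exclusive cumsum of group sizes via a two-level scan.

    Level 1: each chunk of WARP_SIZE entries gets a local exclusive
    cumsum (what one warp would compute cooperatively).
    Level 2: the per-chunk totals are scanned and added back as
    offsets (the role warp shuffles would play on the GPU).
    """
    n = len(group_sizes)
    out = np.zeros(n, dtype=np.int64)
    chunk_totals = []
    for start in range(0, n, WARP_SIZE):
        chunk = group_sizes[start:start + WARP_SIZE]
        # np.cumsum(chunk) - chunk is the exclusive prefix sum of the chunk
        out[start:start + len(chunk)] = np.cumsum(chunk) - chunk
        chunk_totals.append(int(chunk.sum()))
    offset = 0
    for i, start in enumerate(range(0, n, WARP_SIZE)):
        out[start:start + WARP_SIZE] += offset  # shift chunk by preceding totals
        offset += chunk_totals[i]
    return out
```

The exclusive cumsum of `group_sizes` gives each group's starting row offset into the ragged tensor, which is why a fast scan matters once n_groups grows past a single warp.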
Reply: Sounds good! Currently the kernel runtime is pretty small relative to our other kernels, and our n_groups per device is fairly small with EP, but that's a good idea for the future if n_groups per device gets bigger.