
[JAX] Fix BF16 tolerance for CGEMM + RS + BF16 test#2860

Open
phu0ngng wants to merge 2 commits into NVIDIA:main from phu0ngng:cgemm_bf16_fix_tols

Collaborator

@phu0ngng phu0ngng commented Apr 9, 2026

Description

atol=1e-5 was too strict for BF16 comparisons between the GEMM with the NONE collective and the Collective GEMM with the RS (reduce-scatter) collective. Both paths split K across TP ranks and produce identical BF16 partial GEMMs, but reduce them in different orders:

  • NONE (NCCL all-reduce): ((p0+p1)+(p2+p3)) — binary tree in FP32 → BF16
  • RS (reduce_bf16 kernel): ((p0+p1)+p2)+p3 — sequential in FP32 → BF16
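The effect of the two orders can be illustrated with a minimal pure-Python sketch. It emulates FP32 accumulation and the final BF16 rounding with `struct`; the helper names (`to_f32`, `to_bf16`, `reduce_tree`, `reduce_sequential`) and the four partial values are hypothetical and hand-picked to expose the effect, not taken from the test or from the NCCL or TE kernels.

```python
import struct

def to_f32(x):
    """Round a Python float (double) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round an FP32 value to BF16 (round-to-nearest-even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias; ties go to even
    bits &= 0xFFFF0000                    # keep the upper 16 bits (the BF16 payload)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def add_f32(a, b):
    """One FP32 addition: exact sum in double, then a single rounding to FP32."""
    return to_f32(a + b)

def reduce_tree(p):
    """((p0+p1)+(p2+p3)) accumulated in FP32, then rounded to BF16."""
    return to_bf16(add_f32(add_f32(p[0], p[1]), add_f32(p[2], p[3])))

def reduce_sequential(p):
    """((p0+p1)+p2)+p3 accumulated in FP32, then rounded to BF16."""
    acc = p[0]
    for x in p[1:]:
        acc = add_f32(acc, x)
    return to_bf16(acc)

# Four BF16-representable partials (all powers of two), chosen by hand.
partials = [1.0, 2.0 ** -8, 2.0 ** -24, 2.0 ** -24]

c_none = reduce_tree(partials)        # tree order
c_rs = reduce_sequential(partials)    # sequential order
print(c_none, c_rs)                   # 1.0078125 1.0
print(abs(c_none - c_rs))             # 0.0078125, i.e. 1 BF16 ULP at magnitude 1
```

In this example the sequential order loses each small partial to FP32 rounding, so the sum stays exactly on a BF16 tie and rounds down, while the tree order first combines the two small partials and pushes the sum just above the tie.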

Different reduction associativity causes rounding differences of up to 1 BF16 ULP of the partial GEMM magnitude. The combined tolerance atol + rtol*|ref| covers this across all output scales:

  • Large outputs (|ref| > atol/rtol = 12.5): rtol=1e-2 dominates and provides sufficient coverage.
  • Near-zero outputs: rtol provides no coverage, so atol=0.125 (2× the worst-case 1-ULP diff at O(8) scale) is needed. atol=1e-5 failed because it is far below 1 ULP at any realistic activation magnitude.
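The arithmetic behind the two regimes can be checked with a short sketch. It assumes the usual `numpy.testing.assert_allclose` criterion, |actual - ref| <= atol + rtol*|ref|; the helper `within_tol` and the sample values are illustrative only.

```python
def within_tol(actual, ref, rtol, atol):
    """Combined tolerance check: |actual - ref| <= atol + rtol * |ref|."""
    return abs(actual - ref) <= atol + rtol * abs(ref)

RTOL = 1e-2
OLD_ATOL = 1e-5
NEW_ATOL = 0.125

# BF16 has 8 bits of significand precision, so the spacing in [8, 16) is 2**3 * 2**-7.
ULP_AT_8 = 2.0 ** -4   # 0.0625

# Near-zero output: partials of magnitude ~8 cancel, leaving a 1-ULP residue.
print(within_tol(ULP_AT_8, 0.0, RTOL, OLD_ATOL))   # False: rtol * 0 contributes nothing
print(within_tol(ULP_AT_8, 0.0, RTOL, NEW_ATOL))   # True: 0.0625 <= 0.125

# Large output: a 1-ULP diff at magnitude 64 (0.5) is already covered by rtol.
print(within_tol(64.5, 64.0, RTOL, OLD_ATOL))      # True: 0.5 <= 1e-5 + 0.64

# Above this magnitude the rtol term exceeds the atol term.
crossover = NEW_ATOL / RTOL                         # approximately 12.5
```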

Reproducer

The mismatch is verified by a standalone test (https://gist.github.com/phu0ngng/9600caf76df6040ecc4b3f3c6ea20882) that mimics the two collective paths on a single GPU:

  1. Generate TP=4 BF16 partial GEMMs matching the per-rank GEMM size from test_gemm.py (M=8192, K_tp=1024, N=16384, seed=PRNGKey(0)).
  2. Reduce via NCCL binary-tree order (C_none) and TE sequential order (C_rs).
  3. Compare element-wise in BF16 ULPs.
$ CUDA_VISIBLE_DEVICES=1 python test_gemm_reduction.py
M=8192 K_tp=1024 N=16384 TP=4: 2 diffs, max=0.5000, max_ulps=1.00 PASS

2 elements differ by exactly 1 BF16 ULP.
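For reference, one way to express a difference in BF16 ULPs is sketched below. How the linked gist computes `max_ulps` is not shown here, so the helpers `bf16_ulp` and `ulp_diff` are assumptions; they ignore zero, subnormals, and non-finite values.

```python
import math

def bf16_ulp(x):
    """Spacing between adjacent BF16 values at the magnitude of x (x must be nonzero)."""
    _, e = math.frexp(abs(x))        # abs(x) = m * 2**e with 0.5 <= m < 1
    return math.ldexp(1.0, e - 8)    # 2**(floor(log2|x|) - 7): 7 explicit mantissa bits

def ulp_diff(a, b):
    """Absolute difference measured in ULPs of the larger-magnitude operand."""
    return abs(a - b) / bf16_ulp(max(abs(a), abs(b)))

print(bf16_ulp(8.0))          # 0.0625
print(bf16_ulp(64.0))         # 0.5
print(ulp_diff(64.0, 64.5))   # 1.0
```

Under this definition, an absolute difference of 0.5 equals exactly 1 ULP for values in [64, 128), which is consistent with the `max=0.5000, max_ulps=1.00` line in the output above.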

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Increase atol from 1e-5 to 0.125 to cover the near-zero regime where rtol provides no coverage. Large-magnitude diffs (the common case) are already handled by rtol=1e-2.
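A rough sketch of the tolerance selection this change implies is shown below. It is not the actual diff: the `CollectiveOp` enum and the `select_tols` helper are hypothetical stand-ins, and returning `(None, None)` assumes the comparison helper falls back to dtype-based defaults when no override is given.

```python
from enum import Enum

class CollectiveOp(Enum):
    """Hypothetical stand-in for the collective op selector used by the test."""
    NONE = "none"
    REDUCE_SCATTER = "reduce_scatter"

def select_tols(collective_op, use_quantization):
    """Return (rtol, atol) overrides, or (None, None) to keep dtype defaults."""
    is_cgemm_rs_bf16 = (
        collective_op is CollectiveOp.REDUCE_SCATTER and not use_quantization
    )
    if is_cgemm_rs_bf16:
        # atol = 2 x (1 BF16 ULP at O(8) scale) = 2 * 0.0625
        return 1e-2, 0.125
    return None, None

print(select_tols(CollectiveOp.REDUCE_SCATTER, False))   # (0.01, 0.125)
print(select_tols(CollectiveOp.NONE, False))             # (None, None)
```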

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng marked this pull request as ready for review April 9, 2026 03:13
@phu0ngng phu0ngng requested a review from ptrendx April 9, 2026 03:13
Collaborator Author

phu0ngng commented Apr 9, 2026

/te-ci JAX L0

Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR fixes a false test failure in the CGEMM + RS + BF16 path by raising atol from 1e-5 to 0.125. The tighter tolerance was incorrect because NCCL (used by the NONE/reference path) reduces via binary-tree order while TE's reduce_bf16 kernel uses sequential left-to-right order; for near-zero outputs this difference in associativity produces absolute errors of up to 1 BF16 ULP (~0.0625 at O(8) scale), which the new atol=0.125 (2×) correctly covers.

Confidence Score: 5/5

Safe to merge — single-line test-tolerance fix with solid mathematical justification and no impact on production code.

The change is a minimal, well-reasoned relaxation of a test tolerance. The tolerance derivation (2× 1 BF16 ULP at O(8) scale = 0.125) is correct and verified by a standalone reproducer. The assert_allclose helper correctly handles the None fallback for all other paths, so no regressions are introduced. No P0 or P1 findings.

No files require special attention.

Vulnerabilities

No security concerns identified.

Important Files Changed

Filename | Overview
examples/jax/collective_gemm/test_gemm.py | Adds a targeted tolerance override (rtol=1e-2, atol=0.125) for the CGEMM+RS+BF16 path and passes both values to assert_allclose, which correctly falls back to dtype defaults when both are None (all other paths). Logic is sound.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[run_gemm_tests] --> B{enable_result_check\n&& process_id == 0?}
    B -- No --> Z[End]
    B -- Yes --> C{collective_op ==\nREDUCE_SCATTER\n&& !use_quantization?}
    C -- Yes\nis_cgemm_rs_bf16=True --> D["rtol=1e-2\natol=0.125\n(covers 1 BF16 ULP at O(8) scale)"]
    C -- No\nis_cgemm_rs_bf16=False --> E["rtol=None\natol=None\n→ dtype_tols fallback"]
    D --> F[assert_allclose\ngathered_ref_output vs gathered_output]
    E --> F

Reviews (1): Last reviewed commit: "Merge branch 'main' into cgemm_bf16_fix_..."
