[JAX] Fix BF16 tolerance for CGEMM + RS + BF16 test #2860
phu0ngng wants to merge 2 commits into NVIDIA:main from
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
/te-ci JAX L0
Greptile Summary

This PR fixes a false test failure in the CGEMM + RS + BF16 path by raising the test's absolute tolerance.

Confidence Score: 5/5. Safe to merge — single-line test-tolerance fix with solid mathematical justification and no impact on production code. The change is a minimal, well-reasoned relaxation of a test tolerance. The tolerance derivation (2× 1 BF16 ULP at O(8) scale = 0.125) is correct and verified by a standalone reproducer. The `assert_allclose` helper correctly handles the `None` fallback for all other paths, so no regressions are introduced. No P0 or P1 findings. No files require special attention.
| Filename | Overview |
|---|---|
| `examples/jax/collective_gemm/test_gemm.py` | Adds a targeted tolerance override (`rtol=1e-2`, `atol=0.125`) for the CGEMM+RS+BF16 path and passes both values to `assert_allclose`, which correctly falls back to dtype defaults when both are `None` (all other paths). Logic is sound. |
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[run_gemm_tests] --> B{enable_result_check\n&& process_id == 0?}
    B -- No --> Z[End]
    B -- Yes --> C{collective_op ==\nREDUCE_SCATTER\n&& !use_quantization?}
    C -- Yes\nis_cgemm_rs_bf16=True --> D["rtol=1e-2\natol=0.125\n(covers 1 BF16 ULP at O(8) scale)"]
    C -- No\nis_cgemm_rs_bf16=False --> E["rtol=None\natol=None\n→ dtype_tols fallback"]
    D --> F[assert_allclose\ngathered_ref_output vs gathered_output]
    E --> F
```
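The branch above can be sketched as plain Python (names like `is_cgemm_rs_bf16` follow the flowchart; the helper's signature and the dtype-default values are illustrative assumptions, not the actual `test_gemm.py` code):

```python
import numpy as np

def check_output(gathered_output, gathered_ref_output,
                 collective_op, use_quantization,
                 dtype_rtol=1e-5, dtype_atol=1e-5):
    """Tolerance selection mirroring the flowchart; dtype defaults are illustrative."""
    is_cgemm_rs_bf16 = collective_op == "REDUCE_SCATTER" and not use_quantization
    if is_cgemm_rs_bf16:
        rtol, atol = 1e-2, 0.125   # covers 1 BF16 ULP at O(8) scale
    else:
        rtol, atol = None, None    # fall back to dtype defaults
    np.testing.assert_allclose(
        gathered_output, gathered_ref_output,
        rtol=dtype_rtol if rtol is None else rtol,
        atol=dtype_atol if atol is None else atol)
```

With this shape, a near-zero reference with a 1-ULP-scale diff passes only on the RS+BF16 branch, which is exactly the regime the PR targets.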
Reviews (1): Last reviewed commit: "Merge branch 'main' into cgemm_bf16_fix_..."
Description
`atol=1e-5` was too strict for BF16 comparisons between the NONE collective GEMM and Collective GEMM with RS collective paths. Both paths split K across TP ranks and produce identical BF16 partial GEMMs, but reduce them in different orders:

- `((p0+p1)+(p2+p3))` — binary tree in FP32 → BF16
- `((p0+p1)+p2)+p3` — sequential in FP32 → BF16

Different reduction associativity causes rounding differences of up to 1 BF16 ULP of the partial GEMM magnitude. The combined tolerance `atol + rtol*|ref|` covers this across all output scales:

- Large magnitudes (`|ref| > atol/rtol = 12.5`): `rtol=1e-2` dominates and provides sufficient coverage.
- Near-zero outputs: `rtol` provides no coverage, so `atol=0.125` (2× the worst-case 1-ULP diff at O(8) scale) is needed.
- The previous `atol=1e-5` failed because it is far below 1 ULP at any realistic activation magnitude.

Reproducer
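The tolerance arithmetic above can be checked numerically (a minimal sketch; `bf16_ulp` is a hypothetical helper derived from BF16's 8-bit significand, not code from this PR):

```python
import numpy as np

# BF16 keeps 8 significand bits (7 stored), so ulp(x) = 2**(floor(log2|x|) - 7).
def bf16_ulp(x):
    return float(2.0 ** (np.floor(np.log2(abs(x))) - 7))

ulp_at_8 = bf16_ulp(8.0)           # 0.0625: worst-case 1-ULP diff at O(8) scale
atol, rtol = 2 * ulp_at_8, 1e-2    # atol = 0.125, the values chosen in this PR
crossover = atol / rtol            # 12.5: above this |ref|, the rtol term dominates

# Near-zero regime: a 1-ULP diff (0.0625) at |ref| ~ 0.01 needs the atol term.
diff, ref = ulp_at_8, 0.01
old_pass = diff <= 1e-5 + rtol * abs(ref)   # old atol=1e-5: fails
new_pass = diff <= atol + rtol * abs(ref)   # new atol=0.125: passes
```

This reproduces the crossover point `atol/rtol = 12.5` stated in the description and shows why only the absolute term rescues near-zero outputs.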
The mismatch is verified by a standalone test (https://gist.github.com/phu0ngng/9600caf76df6040ecc4b3f3c6ea20882) that mimics the two collective paths on a single GPU:

- Same shapes and data as `test_gemm.py` (M=8192, K_tp=1024, N=16384, seed=PRNGKey(0)).
- Computes both the binary-tree order (`C_none`) and the TE sequential order (`C_rs`).
- 2 elements differ by exactly 1 BF16 ULP.
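A contrived toy (not the gist's reproducer, and with deliberately extreme magnitudes) illustrates the underlying mechanism: summing the same four partials in FP32 with different associativity can round to different results.

```python
import numpy as np

# Four partials, each exactly representable in both BF16 and FP32.
p0, p1, p2, p3 = (np.float32(x) for x in (2.0**24, 1.0, 1.0, -2.0**24))

tree = (p0 + p1) + (p2 + p3)   # binary-tree FP32 reduction
seq = ((p0 + p1) + p2) + p3    # sequential FP32 reduction

# tree != seq: 2^24 + 1 rounds to 2^24 in FP32 (ties-to-even), so the
# sequential order loses both +1 terms while the tree order keeps one.
print(tree, seq)   # 1.0 0.0
```

Both results (1.0 and 0.0) are exactly representable in BF16, so the discrepancy survives the final FP32 → BF16 cast. In the real GEMM the partials have similar magnitudes, so such divergences are rare (2 elements out of 8192×16384 in the gist) and bounded by about 1 BF16 ULP.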
Type of change
Changes
Increase `atol` from `1e-5` to `0.125` to cover the near-zero regime where `rtol` provides no coverage. Large-magnitude diffs (the common case) are already handled by `rtol=1e-2`.

Checklist: