diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index c2db8fc44a..8221d7bbfd 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -151,8 +151,20 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: + # CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32). + # With catastrophic cancellation the output is near zero while the absolute diff can + # reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer + # activations at O(8) scale), which exceeds the previous atol=1e-5. The 2x + # margin (0.125) covers this worst-case 1-ULP absolute difference. + is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization + rtol = 1e-2 if is_cgemm_rs_bf16 else None + atol = 0.125 if is_cgemm_rs_bf16 else None assert_allclose( - gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set) + gathered_ref_output, + gathered_output, + dtype=get_tolerance_dtype(quantizer_set), + rtol=rtol, + atol=atol, )