Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,20 @@ def run_gemm_tests(args, mesh=None):
jax.block_until_ready(gathered_output)

if args.enable_result_check and args.process_id == 0:
# CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32).
# Under catastrophic cancellation the output is near zero, yet the absolute difference
# can reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer
# activations at O(8) scale), which exceeds the previous atol=1e-5. Setting atol to
# 0.125 (a 2x margin over that ULP) covers this worst-case 1-ULP absolute difference.
is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization
rtol = 1e-2 if is_cgemm_rs_bf16 else None
atol = 0.125 if is_cgemm_rs_bf16 else None
assert_allclose(
gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)
gathered_ref_output,
gathered_output,
dtype=get_tolerance_dtype(quantizer_set),
rtol=rtol,
atol=atol,
)


Expand Down
Loading