From 28e5f5377c3059c40c73b7c30b0ea584c9ea6943 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 9 Mar 2026 15:42:48 -0700 Subject: [PATCH 01/60] Refactor to group_sizes per tensor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 16 +- transformer_engine/jax/cpp_extensions/gemm.py | 202 +++++++++++------- .../jax/csrc/extensions/gemm.cpp | 194 +++++++++++------ transformer_engine/jax/dense.py | 28 ++- 4 files changed, 284 insertions(+), 156 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..02cc05649a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1787,13 +1787,16 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( lhs, rhs, - group_sizes, - contracting_dims, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1825,8 +1828,15 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs, + rhs, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, ) allclose_dtype = jnp.float8_e4m3fn diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ab2be7f799..c298e19bf0 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1446,12 +1446,12 @@ class GroupedGemmPrimitive(BasePrimitive): Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -1462,17 +1462,15 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - group_sizes_aval, + lhs_group_sizes_aval, + rhs_group_sizes_aval, + out_group_sizes_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1480,35 +1478,57 @@ def abstract( Grouped GEMM operation. Args: - lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_data: Left-hand side input matrix data, 2D array [rows, cols] lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_data: Right-hand side input matrix data, 2D array [rows, cols] rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group + lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_group_sizes: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR * alpha: 1D array of shape (G,) containing alpha values for each group * beta: 1D array of shape (G,) containing beta values for each group - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval - del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + del bias_aval + del has_bias, use_async_d2h_group_sizes + + # Determine mode from which group_sizes buffer is non-empty + is_wgrad = rhs_group_sizes_aval.size > 0 + num_groups = ( + lhs_group_sizes_aval.size + or rhs_group_sizes_aval.size + or out_group_sizes_aval.size + or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 + ) - num_groups = group_sizes_aval.size + # lhs_data_aval and rhs_data_aval are now 2D; derive output shape from buffer dims + if is_wgrad: + # lhs shape [K_lhs, M] (lhs_is_trans=True) or [M, K_lhs] (lhs_is_trans=False) + # M is the non-contracting (output) dim + M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] + N = rhs_data_aval.shape[1] + out_shape = (num_groups, M, N) + else: + # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) + # dim[0] is always total M for fwd/dgrad + M = lhs_data_aval.shape[0] + N = ( + rhs_data_aval.shape[1] + if not rhs_is_trans + else rhs_data_aval.shape[0] // num_groups + ) + out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1519,9 +1539,6 @@ def abstract( dtype=jnp.uint8, ) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (num_groups, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) if use_v2_ffi: @@ -1597,15 +1614,11 @@ def outer_abstract(*args, **kwargs): def lowering( ctx, *args, - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1615,26 +1628,18 @@ def lowering( return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @@ -1645,18 +1650,16 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1671,17 +1674,15 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, *additional_args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) @@ -2022,10 +2023,24 @@ def _can_use_v2_grouped_gemm( return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias +def _flatten_to_2d(data, flatten_axis): + """Reshape *data* to 2D by splitting at *flatten_axis*. + + Positive flatten_axis: split before that axis index. + Negative flatten_axis: split before (ndim + flatten_axis). + """ + if data.ndim == 2: + return data # Already 2D, no reshape needed + fa = flatten_axis if flatten_axis >= 0 else data.ndim + flatten_axis + return data.reshape(math.prod(data.shape[:fa]), math.prod(data.shape[fa:])) + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, + lhs_group_sizes: jnp.ndarray = None, # (G,) int32 if lhs first-dim is ragged, else None/(0,) + rhs_group_sizes: jnp.ndarray = None, # (G,) int32 if rhs first-dim is ragged (wgrad), else None/(0,) + out_group_sizes: jnp.ndarray = None, # (G,) int32 if output first-dim is ragged, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2040,7 +2055,9 @@ def grouped_gemm( Args: lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group + lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else None or empty (0,) sentinel + rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad mode), else None/(0,) + out_group_sizes: (G,) int32 if output first-dim is ragged, else None/(0,) contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2060,6 +2077,15 @@ def grouped_gemm( # TODO(Phuong): implement the precision del precision + # Replace None sentinels with empty (0,) int32 arrays. + empty_gs = jnp.empty((0,), jnp.int32) + if lhs_group_sizes is None: + lhs_group_sizes = empty_gs + if rhs_group_sizes is None: + rhs_group_sizes = empty_gs + if out_group_sizes is None: + out_group_sizes = empty_gs + if isinstance(lhs, jnp.ndarray): assert isinstance(rhs, jnp.ndarray) out_dtype = lhs.dtype @@ -2074,8 +2100,14 @@ def grouped_gemm( out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape - lhs_data = lhs.data - rhs_data = rhs.data + lhs_fa = lhs.flatten_axis + rhs_fa = rhs.flatten_axis + lhs_data = lhs.data.reshape( + math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:]) + ) + rhs_data = rhs.data.reshape( + math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:]) + ) lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv assert lhs.scaling_mode == rhs.scaling_mode @@ -2094,14 +2126,9 @@ def grouped_gemm( rhs_is_trans = rhs_contract_dim[0] != 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? + # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - is_grouped_dense_wgrad + rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2110,6 +2137,15 @@ def grouped_gemm( lhs_flatten_axis = 1 rhs_flatten_axis = 1 + # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, + # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. + if ( + rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + and isinstance(lhs, GroupedScaledTensor1x) + and scaling_mode.is_1d_block_scaling() + ): + rhs_is_trans = False + if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2132,16 +2168,30 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + active_group_sizes = lhs_group_sizes if lhs_group_sizes.size > 0 else rhs_group_sizes + lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data + # grouped_quantize returns a 1D flat buffer; reshape to 2D using the + # original_shape and flatten_axis stored in each quantized tensor. + lhs_fa = lhs_q.flatten_axis # positive index (adjusted in create_1x) + rhs_fa = rhs_q.flatten_axis + lhs_data = lhs_q.data.reshape( + math.prod(lhs_q.original_shape[:lhs_fa]), + math.prod(lhs_q.original_shape[lhs_fa:]), + ) + rhs_data = rhs_q.data.reshape( + math.prod(rhs_q.original_shape[:rhs_fa]), + math.prod(rhs_q.original_shape[rhs_fa:]), + ) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape + # Data is already 2D; reset flatten axes so _flatten_to_2d calls below are no-ops. + lhs_flatten_axis = -1 + rhs_flatten_axis = -1 assert not ( lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 @@ -2172,31 +2222,26 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if group_sizes.size == rhs_shape[0]: + if lhs_group_sizes.size > 0: # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + # Reshape inputs to 2D using the already-computed flatten_axes. + lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) + rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) - else: - assert group_sizes.size == rhs_shape[0] + num_gemms = lhs_group_sizes.size or rhs_group_sizes.size or out_group_sizes.size has_bias = bias is not None if has_bias: + N_dim = rhs_data_2d.shape[0] // num_gemms if rhs_is_trans else rhs_data_2d.shape[1] assert bias.shape == ( - group_sizes.size, - N, - ), f"bias shape {bias.shape} does not match expected shape {(group_sizes.size, N)}" + num_gemms, + N_dim, + ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" bias = jnp.empty((), jnp.float32) if bias is None else bias assert group_offset is None, ( @@ -2207,7 +2252,6 @@ def grouped_gemm( use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_v2_ffi: - num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta else: @@ -2215,23 +2259,21 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, + lhs_data_2d, lhs_scale_inv, - rhs_data, + rhs_data_2d, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, additional_arg_0, additional_arg_1, - M=M, - N=N, - K=K_lhs, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4cbec405a4..834a7b9a5f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -562,11 +562,12 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, + Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, + Buffer_Type out_group_sizes, Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, size_t m, - size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { + Result_Type setup_workspace, Result_Type int64_workspace, + bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -581,6 +582,40 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here + // as the output shape is inferred from lhs/rhs dims and passed to nvte_grouped_gemm implicitly. + (void)out_group_sizes; + + // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). + bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; + bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_ragged) + num_gemms = lhs_group_sizes.dimensions()[0]; + else if (is_rhs_ragged) + num_gemms = rhs_group_sizes.dimensions()[0]; + else if (out_group_sizes.element_count() > 0) + num_gemms = out_group_sizes.dimensions()[0]; + else + num_gemms = alpha.element_count(); // batched: no ragged tensor + const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + + // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. + NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); + NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); + size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; + n = rhs_data.dimensions()[1]; + } else { + m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -594,14 +629,15 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - - // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), - int64_sizes_ptr, num_gemms, stream); + if (any_ragged) { + NVTE_CHECK(active_group_sizes.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(active_group_sizes.untyped_data()), int64_sizes_ptr, + num_gemms, stream); + } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -656,14 +692,14 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -703,7 +739,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { NVTE_CHECK(lhs_is_trans && !rhs_is_trans, "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); @@ -732,7 +768,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty return ffi_with_cuda_error_check(); } - // Nominal case for FWD or DGRAD + // Nominal case for FWD, DGRAD, or batched GEMM //// RHS NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; @@ -748,14 +784,18 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::swap(lhsShape.data[0], lhsShape.data[1]); } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + if (any_ragged) { + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + } //// OUTPUT NVTEShape outShape{.data = {m, n}, .ndim = 2}; auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (any_ragged) { + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + } nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -769,33 +809,32 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes (int32) + .Arg() // lhs_group_sizes (G,) or empty (0,) + .Arg() // rhs_group_sizes (G,) or empty (0,) + .Arg() // out_group_sizes (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("M") - .Attr("N") - .Attr("K") .Attr("lhs_is_trans") .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("is_grouped_dense_wgrad"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, + Buffer_Type out_group_sizes, Buffer_Type group_offset, + Result_Type output, Result_Type workspace, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -812,6 +851,37 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); + // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here. + (void)out_group_sizes; + + // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). + bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; + bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_ragged) + num_gemms = lhs_group_sizes.dimensions()[0]; + else if (is_rhs_ragged) + num_gemms = rhs_group_sizes.dimensions()[0]; + else + num_gemms = 1; // degenerate batched; legacy batched not a tested use case + const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + + // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. + NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); + NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); + size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; + n = rhs_data.dimensions()[1]; + } else { + m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -824,9 +894,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -893,14 +960,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -916,25 +983,28 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + if (any_ragged) { + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto active_gs_ptr = + reinterpret_cast(active_group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), active_gs_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_rhs_ragged) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -982,7 +1052,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto lhs_shape_i = std::vector{m_i, k}; auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { size_t k_i = dim_list_host[i]; lhs_shape_i[0] = lhs_is_trans ? k_i : m; lhs_shape_i[1] = lhs_is_trans ? m : k_i; @@ -1172,23 +1242,21 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes + .Arg() // lhs_group_sizes (G,) or empty (0,) + .Arg() // rhs_group_sizes (G,) or empty (0,) + .Arg() // out_group_sizes (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") .Attr("lhs_is_trans") .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad") .Attr("use_async_d2h_group_sizes")); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 268995281c..e2a79fe9c7 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -523,15 +523,18 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout + empty_gs = jnp.empty((0,), jnp.int32) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, + bias=bias, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) ctx = ( @@ -615,11 +618,14 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + empty_gs = jnp.empty((0,), jnp.int32) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, @@ -628,8 +634,10 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - group_sizes, - wgrad_contracting_dims, + lhs_group_sizes=empty_gs, + rhs_group_sizes=group_sizes, + out_group_sizes=empty_gs, + contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, From 4a57485316db5c5f1c6bdf5c991c11e4d374259e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 09:58:07 -0700 Subject: [PATCH 02/60] Support first_dims and last_dims instead of a single group_sizes per tensor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 18 ++- transformer_engine/jax/cpp_extensions/gemm.py | 99 ++++++++----- .../jax/csrc/extensions/gemm.cpp | 138 +++++++++++------- transformer_engine/jax/dense.py | 27 ++-- 4 files changed, 177 insertions(+), 105 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 02cc05649a..2f2d5383b2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1793,9 +1793,12 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): )( lhs, rhs, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1832,9 +1835,12 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs, rhs, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, quantizer_set=quantizer_set, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c298e19bf0..32500c9676 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1446,12 +1446,13 @@ class GroupedGemmPrimitive(BasePrimitive): Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, + # lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, + # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (10, 11, 12, 13, 14, 15, 16) + impl_static_args = (13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -1462,9 +1463,12 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - lhs_group_sizes_aval, - rhs_group_sizes_aval, - out_group_sizes_aval, + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval lhs_is_trans, rhs_is_trans, @@ -1504,11 +1508,11 @@ def abstract( del has_bias, use_async_d2h_group_sizes # Determine mode from which group_sizes buffer is non-empty - is_wgrad = rhs_group_sizes_aval.size > 0 + is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( - lhs_group_sizes_aval.size - or rhs_group_sizes_aval.size - or out_group_sizes_aval.size + lhs_first_dims_aval.size or lhs_last_dims_aval.size + or rhs_first_dims_aval.size or rhs_last_dims_aval.size + or out_first_dims_aval.size or out_last_dims_aval.size or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) @@ -1650,9 +1654,12 @@ def impl( rhs_data, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) lhs_is_trans, @@ -1674,9 +1681,12 @@ def impl( rhs_data, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, *additional_args, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, @@ -2038,9 +2048,12 @@ def _flatten_to_2d(data, flatten_axis): def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_group_sizes: jnp.ndarray = None, # (G,) int32 if lhs first-dim is ragged, else None/(0,) - rhs_group_sizes: jnp.ndarray = None, # (G,) int32 if rhs first-dim is ragged (wgrad), else None/(0,) - out_group_sizes: jnp.ndarray = None, # (G,) int32 if output first-dim is ragged, else None/(0,) + lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) + lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) + rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) + rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) + out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) + out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2055,9 +2068,12 @@ def grouped_gemm( Args: lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else None or empty (0,) sentinel - rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad mode), else None/(0,) - out_group_sizes: (G,) int32 if output first-dim is ragged, else None/(0,) + lhs_first_dims: (G,) int32 if LHS squashed first dim varies per group, else None/(0,) + lhs_last_dims: (G,) int32 if LHS squashed last dim varies per group, else None/(0,) + rhs_first_dims: (G,) int32 if RHS squashed first dim varies per group (wgrad), else None/(0,) + rhs_last_dims: (G,) int32 if RHS squashed last dim varies per group, else None/(0,) + out_first_dims: (G,) int32 if output first dim varies per group, else None/(0,) + out_last_dims: (G,) int32 if output last dim varies per group, else None/(0,) contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2079,12 +2095,12 @@ def grouped_gemm( # Replace None sentinels with empty (0,) int32 arrays. empty_gs = jnp.empty((0,), jnp.int32) - if lhs_group_sizes is None: - lhs_group_sizes = empty_gs - if rhs_group_sizes is None: - rhs_group_sizes = empty_gs - if out_group_sizes is None: - out_group_sizes = empty_gs + lhs_first_dims = empty_gs if lhs_first_dims is None else lhs_first_dims + lhs_last_dims = empty_gs if lhs_last_dims is None else lhs_last_dims + rhs_first_dims = empty_gs if rhs_first_dims is None else rhs_first_dims + rhs_last_dims = empty_gs if rhs_last_dims is None else rhs_last_dims + out_first_dims = empty_gs if out_first_dims is None else out_first_dims + out_last_dims = empty_gs if out_last_dims is None else out_last_dims if isinstance(lhs, jnp.ndarray): assert isinstance(rhs, jnp.ndarray) @@ -2128,7 +2144,7 @@ def grouped_gemm( # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2140,7 +2156,7 @@ def grouped_gemm( # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. if ( - rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and isinstance(lhs, GroupedScaledTensor1x) and scaling_mode.is_1d_block_scaling() ): @@ -2168,7 +2184,11 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - active_group_sizes = lhs_group_sizes if lhs_group_sizes.size > 0 else rhs_group_sizes + active_group_sizes = next( + (gs for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0), + empty_gs, + ) lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis @@ -2222,7 +2242,7 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if lhs_group_sizes.size > 0: # fwd/dgrad: rhs has G as first dim + if lhs_first_dims.size > 0 or lhs_last_dims.size > 0: # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) @@ -2233,7 +2253,11 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - num_gemms = lhs_group_sizes.size or rhs_group_sizes.size or out_group_sizes.size + num_gemms = ( + lhs_first_dims.size or lhs_last_dims.size + or rhs_first_dims.size or rhs_last_dims.size + or out_first_dims.size or out_last_dims.size + ) has_bias = bias is not None if has_bias: @@ -2264,9 +2288,12 @@ def grouped_gemm( rhs_data_2d, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, additional_arg_1, lhs_is_trans=lhs_is_trans, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 834a7b9a5f..4387354f2a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -562,8 +562,10 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, - Buffer_Type out_group_sizes, Buffer_Type alpha, Buffer_Type beta, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, bool lhs_is_trans, bool rhs_is_trans, @@ -582,25 +584,31 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here - // as the output shape is inferred from lhs/rhs dims and passed to nvte_grouped_gemm implicitly. - (void)out_group_sizes; - - // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). - bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; - bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_out_first_ragged = out_first_dims.element_count() > 0; + bool is_out_last_ragged = out_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_ragged) - num_gemms = lhs_group_sizes.dimensions()[0]; - else if (is_rhs_ragged) - num_gemms = rhs_group_sizes.dimensions()[0]; - else if (out_group_sizes.element_count() > 0) - num_gemms = out_group_sizes.dimensions()[0]; - else - num_gemms = alpha.element_count(); // batched: no ragged tensor - const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; + else if (is_out_first_ragged) num_gemms = out_first_dims.dimensions()[0]; + else if (is_out_last_ragged) num_gemms = out_last_dims.dimensions()[0]; + else num_gemms = alpha.element_count(); // batched: no ragged tensor + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -632,10 +640,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); if (any_ragged) { - NVTE_CHECK(active_group_sizes.element_type() == xla::ffi::DataType::S32, + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); nvte_convert_int32_to_int64( - reinterpret_cast(active_group_sizes.untyped_data()), int64_sizes_ptr, + reinterpret_cast(active_gs_ptr->untyped_data()), int64_sizes_ptr, num_gemms, stream); } @@ -746,13 +755,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty //// RHS NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_rhs_first_ragged) + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_rhs_last_ragged) + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// LHS NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; lhs_is_trans = true; auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_first_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_last_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// OUTPUT NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; @@ -784,18 +799,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::swap(lhsShape.data[0], lhsShape.data[1]); } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (any_ragged) { - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); - } + if (is_lhs_first_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_last_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// OUTPUT NVTEShape outShape{.data = {m, n}, .ndim = 2}; auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - if (any_ragged) { + if (is_out_first_ragged) out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - } + if (is_out_last_ragged) + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -814,9 +830,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // lhs_group_sizes (G,) or empty (0,) - .Arg() // rhs_group_sizes (G,) or empty (0,) - .Arg() // out_group_sizes (G,) or empty (0,) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output @@ -830,8 +849,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, - Buffer_Type out_group_sizes, Buffer_Type group_offset, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool use_async_d2h_group_sizes) { @@ -851,22 +872,27 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); - // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here. - (void)out_group_sizes; - - // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). - bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; - bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_ragged) - num_gemms = lhs_group_sizes.dimensions()[0]; - else if (is_rhs_ragged) - num_gemms = rhs_group_sizes.dimensions()[0]; - else - num_gemms = 1; // degenerate batched; legacy batched not a tested use case - const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; + else num_gemms = 1; // degenerate batched; legacy batched not a tested use case + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -990,9 +1016,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); } else { - auto active_gs_ptr = - reinterpret_cast(active_group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), active_gs_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + auto gs_data_ptr = + reinterpret_cast(active_gs_ptr->untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, stream); // Note: This may break cudaGraph. cudaStreamSynchronize(stream); @@ -1247,9 +1274,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // lhs_group_sizes (G,) or empty (0,) - .Arg() // rhs_group_sizes (G,) or empty (0,) - .Arg() // out_group_sizes (G,) or empty (0,) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index e2a79fe9c7..ed4d0aa082 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -527,9 +527,12 @@ def _grouped_dense_fwd_rule( output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, bias=bias, precision=precision, @@ -622,9 +625,12 @@ def _grouped_dense_bwd_rule( dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, @@ -634,9 +640,12 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - lhs_group_sizes=empty_gs, - rhs_group_sizes=group_sizes, - out_group_sizes=empty_gs, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=group_sizes, + rhs_last_dims=empty_gs, + out_first_dims=empty_gs, + out_last_dims=empty_gs, contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, From 345d940869181f9c2e57820cffcc69b277d78903 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 10:22:31 -0700 Subject: [PATCH 03/60] Refactor GMM FFIs to store static attrs as structs Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/gemm.cpp | 131 ++++++++++++++++-- 1 file changed, 118 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4387354f2a..24026b4ad9 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,6 +559,117 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } +// Config structs for grouped GEMM FFI static attributes. +// Consolidating all static attributes into a single dict attribute makes it easy to add new +// attributes in the future with backwards-compatible defaults: if old HLO was generated without a +// newer attribute, DecodeAttrOrDefault leaves the field at its struct default value. +struct GroupedGemmV2Config { + bool lhs_is_trans = false; + bool rhs_is_trans = false; + JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans = false; + bool rhs_is_trans = false; + JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; + bool has_bias = false; + bool use_async_d2h_group_sizes = false; +}; + +} // namespace jax +} // namespace transformer_engine + +// Register AttrsBinding and AttrDecoding for grouped GEMM config structs. +// Uses a custom AttrDecoding (instead of XLA_FFI_REGISTER_STRUCT_ATTR_DECODING) that supports +// optional struct fields with default values, so old HLO without newer attributes still decodes. +namespace xla::ffi { + +namespace { + +// Finds an attribute by name. Returns its index or std::nullopt if absent. +std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_view name) { + for (int64_t i = 0; i < attrs->size; ++i) { + if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; + } + return std::nullopt; +} + +// Decodes a named attribute into `field` if present; leaves `field` at its default if absent. +// Returns false only when the attribute is present but fails to decode. +template +bool DecodeAttrOrDefault(const XLA_FFI_Attrs* attrs, std::string_view name, T& field, + DiagnosticEngine& diagnostic) { + auto idx = FindAttrByName(attrs, name); + if (!idx.has_value()) return true; // absent → keep default + auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); + if (!decoded.has_value()) return false; + field = *decoded; + return true; +} + +} // namespace + +template <> +struct AttrsBinding { + using Attrs = transformer_engine::jax::GroupedGemmV2Config; +}; + +template <> +struct AttrDecoding { + using Type = transformer_engine::jax::GroupedGemmV2Config; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); + } + auto* attrs = reinterpret_cast(attr); + Type config; + if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) + return std::nullopt; + return config; + } +}; + +template <> +struct AttrsBinding { + using Attrs = transformer_engine::jax::GroupedGemmConfig; +}; + +template <> +struct AttrDecoding { + using Type = transformer_engine::jax::GroupedGemmConfig; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); + } + auto* attrs = reinterpret_cast(attr); + Type config; + if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", + config.use_async_d2h_group_sizes, diagnostic)) + return std::nullopt; + return config; + } +}; + +} // namespace xla::ffi + +namespace transformer_engine { +namespace jax { + // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, @@ -568,8 +679,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, - bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode) { + GroupedGemmV2Config config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -842,9 +953,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode"), + .Attrs(), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, @@ -853,9 +962,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, Buffer_Type group_offset, - Result_Type output, Result_Type workspace, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool use_async_d2h_group_sizes) { + Result_Type output, Result_Type workspace, + GroupedGemmConfig config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -1283,11 +1392,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("has_bias") - .Attr("use_async_d2h_group_sizes")); + .Attrs()); } // namespace jax } // namespace transformer_engine From ed9c8e4e275b563a1a35a5eeafc5cf46052c58c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:25:19 +0000 Subject: [PATCH 04/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 59 ++++----- .../jax/csrc/extensions/gemm.cpp | 117 ++++++++++-------- 2 files changed, 97 insertions(+), 79 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 32500c9676..d79be04983 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1510,9 +1510,12 @@ def abstract( # Determine mode from which group_sizes buffer is non-empty is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( - lhs_first_dims_aval.size or lhs_last_dims_aval.size - or rhs_first_dims_aval.size or rhs_last_dims_aval.size - or out_first_dims_aval.size or out_last_dims_aval.size + lhs_first_dims_aval.size + or lhs_last_dims_aval.size + or rhs_first_dims_aval.size + or rhs_last_dims_aval.size + or out_first_dims_aval.size + or out_last_dims_aval.size or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) @@ -1527,11 +1530,7 @@ def abstract( # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) # dim[0] is always total M for fwd/dgrad M = lhs_data_aval.shape[0] - N = ( - rhs_data_aval.shape[1] - if not rhs_is_trans - else rhs_data_aval.shape[0] // num_groups - ) + N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -2048,12 +2047,12 @@ def _flatten_to_2d(data, flatten_axis): def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) - lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) - rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) - rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) - out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) - out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) + lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) + lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) + rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) + rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) + out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) + out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2118,12 +2117,8 @@ def grouped_gemm( rhs_shape = rhs.original_shape lhs_fa = lhs.flatten_axis rhs_fa = rhs.flatten_axis - lhs_data = lhs.data.reshape( - math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:]) - ) - rhs_data = rhs.data.reshape( - math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:]) - ) + lhs_data = lhs.data.reshape(math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:])) + rhs_data = rhs.data.reshape(math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:])) lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv assert lhs.scaling_mode == rhs.scaling_mode @@ -2144,7 +2139,7 @@ def grouped_gemm( # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2156,7 +2151,7 @@ def grouped_gemm( # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and isinstance(lhs, GroupedScaledTensor1x) and scaling_mode.is_1d_block_scaling() ): @@ -2185,8 +2180,11 @@ def grouped_gemm( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) active_group_sizes = next( - (gs for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] - if gs.size > 0), + ( + gs + for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0 + ), empty_gs, ) lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) @@ -2242,7 +2240,9 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if lhs_first_dims.size > 0 or lhs_last_dims.size > 0: # fwd/dgrad: rhs has G as first dim + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) @@ -2254,9 +2254,12 @@ def grouped_gemm( rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) num_gemms = ( - lhs_first_dims.size or lhs_last_dims.size - or rhs_first_dims.size or rhs_last_dims.size - or out_first_dims.size or out_last_dims.size + lhs_first_dims.size + or lhs_last_dims.size + or rhs_first_dims.size + or rhs_last_dims.size + or out_first_dims.size + or out_last_dims.size ) has_bias = bias is not None diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 24026b4ad9..e85c520916 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -588,7 +588,7 @@ namespace xla::ffi { namespace { // Finds an attribute by name. Returns its index or std::nullopt if absent. -std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_view name) { +std::optional FindAttrByName(const XLA_FFI_Attrs *attrs, std::string_view name) { for (int64_t i = 0; i < attrs->size; ++i) { if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; } @@ -598,8 +598,8 @@ std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_vi // Decodes a named attribute into `field` if present; leaves `field` at its default if absent. // Returns false only when the attribute is present but fails to decode. template -bool DecodeAttrOrDefault(const XLA_FFI_Attrs* attrs, std::string_view name, T& field, - DiagnosticEngine& diagnostic) { +bool DecodeAttrOrDefault(const XLA_FFI_Attrs *attrs, std::string_view name, T &field, + DiagnosticEngine &diagnostic) { auto idx = FindAttrByName(attrs, name); if (!idx.has_value()) return true; // absent → keep default auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); @@ -618,12 +618,12 @@ struct AttrsBinding { template <> struct AttrDecoding { using Type = transformer_engine::jax::GroupedGemmV2Config; - static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine& diagnostic) { + static std::optional Decode(XLA_FFI_AttrType type, void *attr, + DiagnosticEngine &diagnostic) { if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); } - auto* attrs = reinterpret_cast(attr); + auto *attrs = reinterpret_cast(attr); Type config; if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) return std::nullopt; @@ -643,12 +643,12 @@ struct AttrsBinding { template <> struct AttrDecoding { using Type = transformer_engine::jax::GroupedGemmConfig; - static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine& diagnostic) { + static std::optional Decode(XLA_FFI_AttrType type, void *attr, + DiagnosticEngine &diagnostic) { if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); } - auto* attrs = reinterpret_cast(attr); + auto *attrs = reinterpret_cast(attr); Type config; if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) return std::nullopt; @@ -656,10 +656,9 @@ struct AttrDecoding { return std::nullopt; if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", - config.use_async_d2h_group_sizes, diagnostic)) + if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", config.use_async_d2h_group_sizes, + diagnostic)) return std::nullopt; return config; } @@ -676,10 +675,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, - Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, - GroupedGemmV2Config config) { + Buffer_Type alpha, Buffer_Type beta, Result_Type output, + Result_Type cublas_workspace, Result_Type setup_workspace, + Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: @@ -697,29 +695,40 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // Determine which group_sizes buffers are active (non-empty = ragged dimension). bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; bool is_out_first_ragged = out_first_dims.element_count() > 0; - bool is_out_last_ragged = out_last_dims.element_count() > 0; + bool is_out_last_ragged = out_last_dims.element_count() > 0; bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; - else if (is_out_first_ragged) num_gemms = out_first_dims.dimensions()[0]; - else if (is_out_last_ragged) num_gemms = out_last_dims.dimensions()[0]; - else num_gemms = alpha.element_count(); // batched: no ragged tensor + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else if (is_out_first_ragged) + num_gemms = out_first_dims.dimensions()[0]; + else if (is_out_last_ragged) + num_gemms = out_last_dims.dimensions()[0]; + else + num_gemms = alpha.element_count(); // batched: no ragged tensor const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -754,9 +763,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); - nvte_convert_int32_to_int64( - reinterpret_cast(active_gs_ptr->untyped_data()), int64_sizes_ptr, - num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(active_gs_ptr->untyped_data()), + int64_sizes_ptr, num_gemms, stream); } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, @@ -961,8 +969,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, - Buffer_Type group_offset, - Result_Type output, Result_Type workspace, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; // Notes on matrix layouts and transpose: @@ -983,25 +990,34 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // Determine which group_sizes buffers are active (non-empty = ragged dimension). bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; - else num_gemms = 1; // degenerate batched; legacy batched not a tested use case + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else + num_gemms = 1; // degenerate batched; legacy batched not a tested use case const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -1126,8 +1142,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); } else { NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); - auto gs_data_ptr = - reinterpret_cast(active_gs_ptr->untyped_data()); + auto gs_data_ptr = reinterpret_cast(active_gs_ptr->untyped_data()); cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, stream); // Note: This may break cudaGraph. From ed0deaf08a01da3357fe09cbe92ca31fcd6beac0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 11:35:29 -0700 Subject: [PATCH 05/60] Cleanup C++ v2 FFI Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/gemm.cpp | 308 +++++------------- 1 file changed, 77 insertions(+), 231 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e85c520916..288770281a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,6 +559,70 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } +// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims. +// Only NO_SCALING is supported. +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + Buffer_Type const &first_dims, + Buffer_Type const &last_dims, + Result_Type int64_workspace, + size_t num_gemms, + cudaStream_t stream) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. + // Input buffers (lhs, rhs) are already 2D from the Python side. Output buffers may be ND + // (e.g. [G, K, N] for wgrad), so we collapse dims[0..N-2] → rows and keep dims[N-1] → cols. + NVTEShape dataShape{.data = {product(dims, 0, dims.size() - 1), dims[dims.size() - 1]}, + .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); + wrapper.set_rowwise(data, std::nullopt); + auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(first_dims.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(last_dims.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + } + return wrapper; +} + +// Returns num_gemms from the first non-empty per-tensor group_sizes buffer, +// falling back to the element count of alpha for the uniform-batch case. +size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, + Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, + Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, + Buffer_Type const &out_last_dims, + Buffer_Type const &alpha) { + if (lhs_first_dims.element_count() > 0) { + return lhs_first_dims.dimensions()[0]; + } else if (lhs_last_dims.element_count() > 0) { + return lhs_last_dims.dimensions()[0]; + } else if (rhs_first_dims.element_count() > 0) { + return rhs_first_dims.dimensions()[0]; + } else if (rhs_last_dims.element_count() > 0) { + return rhs_last_dims.dimensions()[0]; + } else if (out_first_dims.element_count() > 0) { + return out_first_dims.dimensions()[0]; + } else if (out_last_dims.element_count() > 0) { + return out_last_dims.dimensions()[0]; + } else { + return alpha.element_count(); // uniform batch: no ragged tensor + } +} + // Config structs for grouped GEMM FFI static attributes. // Consolidating all static attributes into a single dict attribute makes it easy to add new // attributes in the future with backwards-compatible defaults: if old HLO was generated without a @@ -679,181 +743,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - - // Determine which group_sizes buffers are active (non-empty = ragged dimension). - bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; - bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; - bool is_out_first_ragged = out_first_dims.element_count() > 0; - bool is_out_last_ragged = out_last_dims.element_count() > 0; - bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; - bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; - - size_t num_gemms; - if (is_lhs_first_ragged) - num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) - num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) - num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) - num_gemms = rhs_last_dims.dimensions()[0]; - else if (is_out_first_ragged) - num_gemms = out_first_dims.dimensions()[0]; - else if (is_out_last_ragged) - num_gemms = out_last_dims.dimensions()[0]; - else - num_gemms = alpha.element_count(); // batched: no ragged tensor - - const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) - active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) - active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) - active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) - active_gs_ptr = &rhs_last_dims; - - // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. - NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); - NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); - size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; - size_t m, n; - if (is_rhs_ragged) { - // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - n = rhs_data.dimensions()[1]; - } else { - m = lhs_data.dimensions()[0]; // total M (sum of group sizes) - n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; - } - - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); - bool has_bias = product(bias.dimensions()) > 0; - auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - - // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - if (any_ragged) { - NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); - NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); - nvte_convert_int32_to_int64(reinterpret_cast(active_gs_ptr->untyped_data()), - int64_sizes_ptr, num_gemms, stream); - } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, + rhs_first_dims, rhs_last_dims, + out_first_dims, out_last_dims, alpha); - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Workspaces. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); - // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); - auto workspace_total_size = product(cublas_workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - const size_t workspace_alignment_padding = 256; - const size_t tensor_scaling_sinv_aligment = 16; - const size_t mxfp8_scaling_sinv_alignment_padding = 256; - auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { - // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned - // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. - workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); - } - auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned - auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; - - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_rhs_ragged) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); - } - - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? n : 0}; - const int arch = cuda::sm_arch(); - - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } - + auto workspace_size = product(cublas_workspace->dimensions()) - 256; TensorWrapper workspace_setup(setup_workspace_ptr, std::vector{product(setup_workspace->dimensions())}, DType::kByte); @@ -867,70 +769,14 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_rhs_ragged) { - NVTE_CHECK(lhs_is_trans && !rhs_is_trans, - "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); - - //// RHS - NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - if (is_rhs_first_ragged) - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_rhs_last_ragged) - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// LHS - NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; - lhs_is_trans = true; - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (is_lhs_first_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_lhs_last_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// OUTPUT - NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, - alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - nullptr, // config (use defaults) - stream); - - return ffi_with_cuda_error_check(); - } - - // Nominal case for FWD, DGRAD, or batched GEMM - - //// RHS - NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; - if (rhs_is_trans) { - rhsShape.data[0] = num_gemms * n; - rhsShape.data[1] = k; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - - //// LHS - NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; - if (lhs_is_trans) { - std::swap(lhsShape.data[0], lhsShape.data[1]); - } - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (is_lhs_first_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_lhs_last_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// OUTPUT - NVTEShape outShape{.data = {m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - if (is_out_first_ragged) - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_out_last_ragged) - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. + // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor. + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, + int64_workspace, num_gemms, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, + int64_workspace, num_gemms, stream); + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, + int64_workspace, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From 88bb7daaa6ff1d0877b2227c2cd29efcf29cd555 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 11:55:59 -0700 Subject: [PATCH 06/60] Fix int64 workspace usage Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 19 +++++++- .../jax/csrc/extensions/gemm.cpp | 47 +++++++++++++------ 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d79be04983..979bed9577 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1549,7 +1549,24 @@ def abstract( shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 ) # Temporary buffer for int32 -> int64 conversion of group_sizes on device. - int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that + # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate + # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size). + num_ragged_dim_buffers = sum( + 1 + for aval in [ + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + ] + if aval.size > 0 + ) + int64_workspace_size = ( + max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize + ) int64_workspace_aval = jax.core.ShapedArray( shape=(int64_workspace_size,), dtype=jnp.uint8 ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 288770281a..e122e8f909 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,13 +559,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes -// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims. -// Only NO_SCALING is supported. +// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. +// int64_offset (in int64 elements) is updated on return to the next available slot so callers can +// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked +// before each slot is used. Only NO_SCALING is supported. JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, - Result_Type int64_workspace, + int64_t *int64_workspace_base, + size_t int64_workspace_capacity, + size_t &int64_offset, size_t num_gemms, cudaStream_t stream) { auto dims = data.dimensions(); @@ -577,22 +581,27 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); if (first_dims.element_count() > 0) { NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; nvte_convert_int32_to_int64( - reinterpret_cast(first_dims.untyped_data()), - int64_sizes_ptr, num_gemms, stream); - wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + reinterpret_cast(first_dims.untyped_data()), slot, num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; } if (last_dims.element_count() > 0) { NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; nvte_convert_int32_to_int64( - reinterpret_cast(last_dims.untyped_data()), - int64_sizes_ptr, num_gemms, stream); - wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + reinterpret_cast(last_dims.untyped_data()), slot, num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; } return wrapper; } @@ -770,13 +779,21 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty convert_ffi_datatype_to_te_dtype(beta.element_type())); // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. - // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor. + // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each. + // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims + // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor. + auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); + size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); + size_t int64_offset = 0; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From 60312c85374c2ce4cf1d51e09a0e18aa59b157ee Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 15:59:12 -0700 Subject: [PATCH 07/60] Address greptile comments Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 979bed9577..de0ef1c522 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1528,8 +1528,8 @@ def abstract( out_shape = (num_groups, M, N) else: # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) - # dim[0] is always total M for fwd/dgrad - M = lhs_data_aval.shape[0] + # M is the non-contracting (output) dim + M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) @@ -2270,6 +2270,13 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) + # Validate contracting dim size + k_lhs = lhs_data_2d.shape[0] if lhs_is_trans else lhs_data_2d.shape[1] + k_rhs = rhs_data_2d.shape[1] if rhs_is_trans else rhs_data_2d.shape[0] // num_gemms + assert k_lhs == k_rhs, ( + f"Contracting dimension mismatch: LHS K={k_lhs}, RHS K={k_rhs}" + ) + num_gemms = ( lhs_first_dims.size or lhs_last_dims.size @@ -2278,6 +2285,12 @@ def grouped_gemm( or out_first_dims.size or out_last_dims.size ) + if num_gemms == 0: + raise ValueError( + "grouped_gemm requires at least one non-empty dimension array " + "(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, " + "out_first_dims, or out_last_dims)." + ) has_bias = bias is not None if has_bias: From 025f598ab65fd5c06d9d0f167504fde748514a89 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 16:25:20 -0700 Subject: [PATCH 08/60] Refactor wgrad-specific checks to be generic for GMM in gemm.py Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 104 +++++++++++------- .../jax/csrc/extensions/gemm.cpp | 2 +- 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index de0ef1c522..b97b66066f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1441,6 +1441,45 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int: + """Non-contracting output size M from the 2-D LHS buffer.""" + return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0] + + +def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int: + """Non-contracting output size N from the 2-D RHS buffer.""" + return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1] + + +def _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups: int, +) -> None: + """Assert that all non-empty *_dims arrays have exactly num_groups elements. + + rhs_first_dims / rhs_last_dims describe the ragged contracting K dimension. + K totals need not fill the entire buffer (padding is allowed), so only the + array length is checked, not the per-group sum. + """ + for name, aval in [ + ("lhs_first_dims", lhs_first_dims_aval), + ("lhs_last_dims", lhs_last_dims_aval), + ("out_first_dims", out_first_dims_aval), + ("out_last_dims", out_last_dims_aval), + ("rhs_first_dims", rhs_first_dims_aval), + ("rhs_last_dims", rhs_last_dims_aval), + ]: + if aval.size > 0: + assert aval.size == num_groups, ( + f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). @@ -1507,8 +1546,6 @@ def abstract( del bias_aval del has_bias, use_async_d2h_group_sizes - # Determine mode from which group_sizes buffer is non-empty - is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( lhs_first_dims_aval.size or lhs_last_dims_aval.size @@ -1519,18 +1556,28 @@ def abstract( or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) - # lhs_data_aval and rhs_data_aval are now 2D; derive output shape from buffer dims - if is_wgrad: - # lhs shape [K_lhs, M] (lhs_is_trans=True) or [M, K_lhs] (lhs_is_trans=False) - # M is the non-contracting (output) dim - M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] - N = rhs_data_aval.shape[1] + _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups, + ) + + # lhs_data_aval and rhs_data_aval are 2D; derive output shape from buffer dims. + # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) + # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) + M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) + N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) + # When rhs has a ragged (contracting) K dimension, M and N are fixed per group + # and the output has a leading group axis. + # K validation is intentionally skipped: per-group K values may not fill the + # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. + if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: out_shape = (num_groups, M, N) else: - # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) - # M is the non-contracting (output) dim - M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] - N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -2150,30 +2197,12 @@ def grouped_gemm( lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - # rhs_shape [G, K, N] - rhs_is_trans = rhs_contract_dim[0] != 1 + # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). + # This formula handles both standard rhs [G, K, N] (G-prefixed) and wgrad + # rhs [K_total, N] (no G prefix) without needing a separate wgrad override. + rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? - if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - - # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, - # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. - if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged - and isinstance(lhs, GroupedScaledTensor1x) - and scaling_mode.is_1d_block_scaling() - ): - rhs_is_trans = False - if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2270,13 +2299,6 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - # Validate contracting dim size - k_lhs = lhs_data_2d.shape[0] if lhs_is_trans else lhs_data_2d.shape[1] - k_rhs = rhs_data_2d.shape[1] if rhs_is_trans else rhs_data_2d.shape[0] // num_gemms - assert k_lhs == k_rhs, ( - f"Contracting dimension mismatch: LHS K={k_lhs}, RHS K={k_rhs}" - ) - num_gemms = ( lhs_first_dims.size or lhs_last_dims.size diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e122e8f909..9e996c8f3a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -892,7 +892,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; n = rhs_data.dimensions()[1]; } else { - m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; } From 089e530d2f5b9a3732f1d2bff08bd9650ff62a73 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 16:58:35 -0700 Subject: [PATCH 09/60] Refactor XLA FFI struct setup Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions.h | 28 +++++ .../jax/csrc/extensions/gemm.cpp | 104 ------------------ 2 files changed, 28 insertions(+), 104 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 93c85aaacc..98a97084a1 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -45,6 +45,20 @@ struct ActivationConfig { ClampedSwigluConfig clamped_swiglu; }; +struct GroupedGemmV2Config { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + bool has_bias; + bool use_async_d2h_group_sizes; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -170,6 +184,20 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::ActivationConfig, ::xla::ffi::StructMember("clamped_swiglu")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmV2Config, + ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmConfig, + ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("has_bias"), + ::xla::ffi::StructMember("use_async_d2h_group_sizes")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9e996c8f3a..f28deaaea7 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -632,113 +632,9 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, } } -// Config structs for grouped GEMM FFI static attributes. -// Consolidating all static attributes into a single dict attribute makes it easy to add new -// attributes in the future with backwards-compatible defaults: if old HLO was generated without a -// newer attribute, DecodeAttrOrDefault leaves the field at its struct default value. -struct GroupedGemmV2Config { - bool lhs_is_trans = false; - bool rhs_is_trans = false; - JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; -}; - -struct GroupedGemmConfig { - bool lhs_is_trans = false; - bool rhs_is_trans = false; - JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; - bool has_bias = false; - bool use_async_d2h_group_sizes = false; -}; - } // namespace jax } // namespace transformer_engine -// Register AttrsBinding and AttrDecoding for grouped GEMM config structs. -// Uses a custom AttrDecoding (instead of XLA_FFI_REGISTER_STRUCT_ATTR_DECODING) that supports -// optional struct fields with default values, so old HLO without newer attributes still decodes. -namespace xla::ffi { - -namespace { - -// Finds an attribute by name. Returns its index or std::nullopt if absent. -std::optional FindAttrByName(const XLA_FFI_Attrs *attrs, std::string_view name) { - for (int64_t i = 0; i < attrs->size; ++i) { - if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; - } - return std::nullopt; -} - -// Decodes a named attribute into `field` if present; leaves `field` at its default if absent. -// Returns false only when the attribute is present but fails to decode. -template -bool DecodeAttrOrDefault(const XLA_FFI_Attrs *attrs, std::string_view name, T &field, - DiagnosticEngine &diagnostic) { - auto idx = FindAttrByName(attrs, name); - if (!idx.has_value()) return true; // absent → keep default - auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); - if (!decoded.has_value()) return false; - field = *decoded; - return true; -} - -} // namespace - -template <> -struct AttrsBinding { - using Attrs = transformer_engine::jax::GroupedGemmV2Config; -}; - -template <> -struct AttrDecoding { - using Type = transformer_engine::jax::GroupedGemmV2Config; - static std::optional Decode(XLA_FFI_AttrType type, void *attr, - DiagnosticEngine &diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { - return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); - } - auto *attrs = reinterpret_cast(attr); - Type config; - if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) - return std::nullopt; - return config; - } -}; - -template <> -struct AttrsBinding { - using Attrs = transformer_engine::jax::GroupedGemmConfig; -}; - -template <> -struct AttrDecoding { - using Type = transformer_engine::jax::GroupedGemmConfig; - static std::optional Decode(XLA_FFI_AttrType type, void *attr, - DiagnosticEngine &diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { - return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); - } - auto *attrs = reinterpret_cast(attr); - Type config; - if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", config.use_async_d2h_group_sizes, - diagnostic)) - return std::nullopt; - return config; - } -}; - -} // namespace xla::ffi - namespace transformer_engine { namespace jax { From 8ad229483323db437bb3426d600a42c019aa3f67 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 17:04:21 -0700 Subject: [PATCH 10/60] Fix edge case in TE v1 GMM Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f28deaaea7..7795d4c18a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -766,7 +766,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; else - num_gemms = 1; // degenerate batched; legacy batched not a tested use case + NVTE_CHECK(false, "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to determine num_gemms."); const Buffer_Type *active_gs_ptr = nullptr; if (is_lhs_first_ragged) From 4ff5d1d9ffd9e7c43410de5456f42d0c1b43921f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:26:02 +0000 Subject: [PATCH 11/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 6 +- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/gemm.cpp | 56 ++++++++----------- 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1ef7da9cbc..1f15c27d97 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1313,9 +1313,9 @@ def _assert_grouped_gemm_dims_shapes( ("rhs_last_dims", rhs_last_dims_aval), ]: if aval.size > 0: - assert aval.size == num_groups, ( - f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" - ) + assert ( + aval.size == num_groups + ), f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" class GroupedGemmPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 9a5647f7c7..bd429a7db6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -207,14 +207,12 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("use_split_accumulator")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::GroupedGemmV2Config, - ::xla::ffi::StructMember("lhs_is_trans"), + transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::GroupedGemmConfig, - ::xla::ffi::StructMember("lhs_is_trans"), + transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode"), ::xla::ffi::StructMember("has_bias"), diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index dec72f809e..dd6c0a59f2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -626,10 +626,8 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, - size_t int64_workspace_capacity, - size_t &int64_offset, - size_t num_gemms, - cudaStream_t stream) { + size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. @@ -640,24 +638,22 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { - NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, "int64_workspace overflow: not enough space for first_dims conversion."); auto *slot = int64_workspace_base + int64_offset; - nvte_convert_int32_to_int64( - reinterpret_cast(first_dims.untyped_data()), slot, num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); int64_offset += num_gemms; } if (last_dims.element_count() > 0) { - NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, "int64_workspace overflow: not enough space for last_dims conversion."); auto *slot = int64_workspace_base + int64_offset; - nvte_convert_int32_to_int64( - reinterpret_cast(last_dims.untyped_data()), slot, num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); int64_offset += num_gemms; } @@ -666,12 +662,9 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. -size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, - Buffer_Type const &lhs_last_dims, - Buffer_Type const &rhs_first_dims, - Buffer_Type const &rhs_last_dims, - Buffer_Type const &out_first_dims, - Buffer_Type const &out_last_dims, +size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, Buffer_Type const &alpha) { if (lhs_first_dims.element_count() > 0) { return lhs_first_dims.dimensions()[0]; @@ -710,9 +703,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, - rhs_first_dims, rhs_last_dims, - out_first_dims, out_last_dims, alpha); + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, + rhs_last_dims, out_first_dims, out_last_dims, alpha); // Workspaces. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); @@ -739,15 +731,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); - auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -824,7 +813,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; else - NVTE_CHECK(false, "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to determine num_gemms."); + NVTE_CHECK(false, + "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to " + "determine num_gemms."); const Buffer_Type *active_gs_ptr = nullptr; if (is_lhs_first_ragged) @@ -846,7 +837,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; n = rhs_data.dimensions()[1]; } else { - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) + m = lhs_is_trans ? lhs_data.dimensions()[1] + : lhs_data.dimensions()[0]; // total M (sum of group sizes) n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; } From 0cb7289643a8349817916be637dc5f247df6696b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 11 Mar 2026 10:13:34 -0700 Subject: [PATCH 12/60] Fix issues on Hopper Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 13 ++++++++++--- transformer_engine/jax/csrc/extensions/gemm.cpp | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1f15c27d97..06af064c9f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1408,14 +1408,16 @@ def abstract( # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) - N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) - # When rhs has a ragged (contracting) K dimension, M and N are fixed per group - # and the output has a leading group axis. # K validation is intentionally skipped: per-group K values may not fill the # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: + # Wgrad case: rhs has ragged contracting K dimension with no G-prefix. + # T-layout rhs shape is (N, K_total); N-layout rhs shape is (K_total, N). + N = rhs_data_aval.shape[0] if rhs_is_trans else rhs_data_aval.shape[1] out_shape = (num_groups, M, N) else: + # When rhs has a leading group axis, _grouped_gemm_rhs_N divides by num_groups. + N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -1889,6 +1891,11 @@ def _can_use_v2_grouped_gemm( if not _v2_grouped_gemm_available: return False + # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). + # Fall back to the v1 path on SM90 (Hopper) and older architectures. + if get_device_compute_capability(0) < 100: + return False + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index dd6c0a59f2..2d73390d33 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -835,7 +835,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type if (is_rhs_ragged) { // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - n = rhs_data.dimensions()[1]; + // T-layout rhs: (N, K_total) -> n = dim[0]; N-layout rhs: (K_total, N) -> n = dim[1] + n = rhs_is_trans ? rhs_data.dimensions()[0] : rhs_data.dimensions()[1]; } else { m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) From cc236ad10ccea4edc7b740cf236776c3586c3381 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 12 Mar 2026 14:50:43 -0700 Subject: [PATCH 13/60] Refactor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 35 ++- transformer_engine/jax/cpp_extensions/gemm.py | 254 +++++++++--------- .../jax/cpp_extensions/quantization.py | 3 +- transformer_engine/jax/csrc/extensions.h | 12 +- .../jax/csrc/extensions/gemm.cpp | 51 ++-- transformer_engine/jax/dense.py | 71 +++-- .../jax/quantize/dequantizer.py | 11 +- transformer_engine/jax/quantize/quantizer.py | 2 +- transformer_engine/jax/quantize/tensor.py | 159 +++++++++-- 9 files changed, 389 insertions(+), 209 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2f2d5383b2..9fddbc435c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -36,6 +36,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, QuantizerFactory, QuantizeLayout, @@ -1787,18 +1788,17 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - empty_gs = jnp.empty((0,), jnp.int32) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( - lhs, - rhs, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, + lhs_tensor, + rhs_tensor, contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1831,16 +1831,15 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - empty_gs = jnp.empty((0,), jnp.int32) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, - rhs, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, + lhs_tensor, + rhs_tensor, contracting_dims=contracting_dims, quantizer_set=quantizer_set, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 6a41cfc94e..ff9194bdd9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -37,6 +37,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, Quantizer, GroupedQuantizer, @@ -1331,15 +1332,6 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) -def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int: - """Non-contracting output size M from the 2-D LHS buffer.""" - return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0] - - -def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int: - """Non-contracting output size N from the 2-D RHS buffer.""" - return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1] - def _assert_grouped_gemm_dims_shapes( lhs_first_dims_aval, @@ -1381,7 +1373,7 @@ class GroupedGemmPrimitive(BasePrimitive): # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (13, 14, 15, 16, 17, 18, 19) + impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22) inner_primitive = None outer_primitive = None @@ -1406,19 +1398,22 @@ def abstract( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): """ Grouped GEMM operation. Args: - lhs_data: Left-hand side input matrix data, 2D array [rows, cols] + lhs_data: Left-hand side input matrix data, N-D array lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 2D array [rows, cols] + rhs_data: Right-hand side input matrix data, N-D array rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel - rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel - out_group_sizes: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel + lhs_first_dims: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_first_dims: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_first_dims: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR @@ -1429,6 +1424,9 @@ def abstract( scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided + lhs_axis_boundary: Axis split point for lhs N-D → 2D flattening + rhs_axis_boundary: Axis split point for rhs N-D → 2D flattening + rhs_group_axis: Batch-group axis of rhs to exclude from output non-contracting dims Returns: A jnp.ndarray containing the result of the grouped GEMM operation @@ -1456,21 +1454,33 @@ def abstract( num_groups, ) - # lhs_data_aval and rhs_data_aval are 2D; derive output shape from buffer dims. - # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) - # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) - M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) + # Derive output shape from N-D buffer shapes using axis_boundary. + lhs_shape = lhs_data_aval.shape + rhs_shape = rhs_data_aval.shape + + # Non-contracting dims for lhs + if lhs_is_trans: + lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + else: + lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + + # Non-contracting dims for rhs (excluding batch-group axis where applicable) + if rhs_is_trans: + rhs_non_contracting = tuple( + rhs_shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + rhs_non_contracting = rhs_shape[rhs_axis_boundary:] + # K validation is intentionally skipped: per-group K values may not fill the # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: - # Wgrad case: rhs has ragged contracting K dimension with no G-prefix. - # T-layout rhs shape is (N, K_total); N-layout rhs shape is (K_total, N). - N = rhs_data_aval.shape[0] if rhs_is_trans else rhs_data_aval.shape[1] - out_shape = (num_groups, M, N) + # Wgrad case: rhs has ragged contracting K dimension → output gets G prefix. + out_shape = (num_groups, *lhs_non_contracting, *rhs_non_contracting) else: - # When rhs has a leading group axis, _grouped_gemm_rhs_N divides by num_groups. - N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) - out_shape = (M, N) + out_shape = (*lhs_non_contracting, *rhs_non_contracting) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1577,8 +1587,11 @@ def lowering( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): - del out_dtype + del out_dtype, rhs_group_axis # Python-only; not forwarded to C++ if use_v2_ffi: ffi_name = GroupedGemmPrimitive.name_graph_safe return jax.ffi.ffi_lowering(ffi_name)( @@ -1587,6 +1600,8 @@ def lowering( lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( @@ -1597,6 +1612,8 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) @staticmethod @@ -1621,6 +1638,9 @@ def impl( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): if GroupedGemmPrimitive.inner_primitive is None: raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") @@ -1648,6 +1668,9 @@ def impl( has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return (out,) @@ -1959,27 +1982,9 @@ def _can_use_v2_grouped_gemm( return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias -def _flatten_to_2d(data, flatten_axis): - """Reshape *data* to 2D by splitting at *flatten_axis*. - - Positive flatten_axis: split before that axis index. - Negative flatten_axis: split before (ndim + flatten_axis). - """ - if data.ndim == 2: - return data # Already 2D, no reshape needed - fa = flatten_axis if flatten_axis >= 0 else data.ndim + flatten_axis - return data.reshape(math.prod(data.shape[:fa]), math.prod(data.shape[fa:])) - - def grouped_gemm( - lhs: Union[jnp.ndarray, GroupedScaledTensor1x], - rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) - lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) - rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) - rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) - out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) - out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -1992,14 +1997,8 @@ def grouped_gemm( Grouped GEMM operation. Args: - lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - lhs_first_dims: (G,) int32 if LHS squashed first dim varies per group, else None/(0,) - lhs_last_dims: (G,) int32 if LHS squashed last dim varies per group, else None/(0,) - rhs_first_dims: (G,) int32 if RHS squashed first dim varies per group (wgrad), else None/(0,) - rhs_last_dims: (G,) int32 if RHS squashed last dim varies per group, else None/(0,) - out_first_dims: (G,) int32 if output first dim varies per group, else None/(0,) - out_last_dims: (G,) int32 if output last dim varies per group, else None/(0,) + lhs: Left-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x + rhs: Right-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2009,60 +2008,76 @@ def grouped_gemm( Returns: A jnp.ndarray containing the result of the grouped GEMM operation - - Note: - Tested shapes: - lhs: [M, K] or [K, N] - rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ # TODO(Phuong): implement the precision del precision - # Replace None sentinels with empty (0,) int32 arrays. empty_gs = jnp.empty((0,), jnp.int32) - lhs_first_dims = empty_gs if lhs_first_dims is None else lhs_first_dims - lhs_last_dims = empty_gs if lhs_last_dims is None else lhs_last_dims - rhs_first_dims = empty_gs if rhs_first_dims is None else rhs_first_dims - rhs_last_dims = empty_gs if rhs_last_dims is None else rhs_last_dims - out_first_dims = empty_gs if out_first_dims is None else out_first_dims - out_last_dims = empty_gs if out_last_dims is None else out_last_dims - - if isinstance(lhs, jnp.ndarray): - if not isinstance(rhs, jnp.ndarray): - raise TypeError( - f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" - ) - out_dtype = lhs.dtype - lhs_shape = lhs.shape - rhs_shape = rhs.shape - lhs_data = lhs - rhs_data = rhs - lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + + # Extract data, dims, and metadata from tensor objects. + if isinstance(lhs, GroupedNoScaleTensor): + lhs_data = lhs.data + lhs_shape = lhs.original_shape + lhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING + out_dtype = lhs.data.dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) elif isinstance(lhs, GroupedScaledTensor1x): - if not isinstance(rhs, GroupedScaledTensor1x): - raise TypeError( - "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" - f" got type={type(rhs)}" - ) - out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape - rhs_shape = rhs.original_shape - lhs_fa = lhs.flatten_axis - rhs_fa = rhs.flatten_axis - lhs_data = lhs.data.reshape(math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:])) - rhs_data = rhs.data.reshape(math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:])) + lhs_data = lhs.data.reshape(lhs_shape) lhs_scale_inv = lhs.scale_inv + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) + else: + raise TypeError( + "lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " + f"got type={type(lhs)}" + ) + + if isinstance(rhs, GroupedNoScaleTensor): + rhs_data = rhs.data + rhs_shape = rhs.original_shape + rhs_scale_inv = jnp.empty((0,), jnp.float32) + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + elif isinstance(rhs, GroupedScaledTensor1x): + rhs_shape = rhs.original_shape + rhs_data = rhs.data.reshape(rhs_shape) rhs_scale_inv = rhs.scale_inv - if lhs.scaling_mode != rhs.scaling_mode: + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: raise ValueError( f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," f" rhs.scaling_mode={rhs.scaling_mode}" ) - scaling_mode = lhs.scaling_mode + if isinstance(lhs, GroupedScaledTensor1x): + scaling_mode = lhs.scaling_mode else: - raise TypeError("Unsupported lhs type object!") + raise TypeError( + "rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " + f"got type={type(rhs)}" + ) + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = empty_gs + out_last_dims = empty_gs + elif lhs_first_dims.size > 0: + out_first_dims = lhs_first_dims + out_last_dims = empty_gs + elif lhs_last_dims.size > 0: + out_first_dims = empty_gs + out_last_dims = lhs_last_dims + else: + out_first_dims = out_last_dims = empty_gs out_dtype = preferred_element_type or out_dtype @@ -2072,8 +2087,6 @@ def grouped_gemm( lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). - # This formula handles both standard rhs [G, K, N] (G-prefixed) and wgrad - # rhs [K_total, N] (no G prefix) without needing a separate wgrad override. rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) @@ -2115,29 +2128,18 @@ def grouped_gemm( ), empty_gs, ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data + rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data + lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( - rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis - ) - # grouped_quantize returns a 1D flat buffer; reshape to 2D using the - # original_shape and flatten_axis stored in each quantized tensor. - lhs_fa = lhs_q.flatten_axis # positive index (adjusted in create_1x) - rhs_fa = rhs_q.flatten_axis - lhs_data = lhs_q.data.reshape( - math.prod(lhs_q.original_shape[:lhs_fa]), - math.prod(lhs_q.original_shape[lhs_fa:]), - ) - rhs_data = rhs_q.data.reshape( - math.prod(rhs_q.original_shape[:rhs_fa]), - math.prod(rhs_q.original_shape[rhs_fa:]), + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) + lhs_data = lhs_q.data.reshape(lhs_q.original_shape) + rhs_data = rhs_q.data.reshape(rhs_q.original_shape) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape - # Data is already 2D; reset flatten axes so _flatten_to_2d calls below are no-ops. - lhs_flatten_axis = -1 - rhs_flatten_axis = -1 if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") @@ -2174,9 +2176,9 @@ def grouped_gemm( else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Reshape inputs to 2D using the already-computed flatten_axes. - lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) - rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) + # Compute N-D axis boundaries from final (post-adjustment) contracting dims. + lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) num_gemms = ( lhs_first_dims.size @@ -2188,14 +2190,21 @@ def grouped_gemm( ) if num_gemms == 0: raise ValueError( - "grouped_gemm requires at least one non-empty dimension array " - "(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, " - "out_first_dims, or out_last_dims)." + "grouped_gemm requires at least one non-empty dimension array. " + "Ensure lhs or rhs tensor objects carry first_dims or last_dims." ) has_bias = bias is not None if has_bias: - N_dim = rhs_data_2d.shape[0] // num_gemms if rhs_is_trans else rhs_data_2d.shape[1] + # Compute N from rhs non-contracting dims. + if rhs_is_trans: + N_dim = math.prod( + rhs_data.shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + N_dim = math.prod(rhs_data.shape[rhs_axis_boundary:]) assert bias.shape == ( num_gemms, N_dim, @@ -2218,9 +2227,9 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data_2d, + lhs_data, lhs_scale_inv, - rhs_data_2d, + rhs_data, rhs_scale_inv, bias, lhs_first_dims, @@ -2238,5 +2247,8 @@ def grouped_gemm( has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index bf4e833c89..c8578d48b8 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1203,6 +1203,7 @@ def grouped_quantize( ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" group_axis = 0 + ragged_first_dims = group_sizes # None if no explicit group_sizes (kernel case) if group_sizes is None: group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) @@ -1280,7 +1281,7 @@ def grouped_quantize( q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=ragged_first_dims, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index bd429a7db6..616209709b 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -59,6 +59,8 @@ struct GroupedGemmV2Config { bool lhs_is_trans; bool rhs_is_trans; JAXX_Scaling_Mode scaling_mode; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; }; struct GroupedGemmConfig { @@ -67,6 +69,8 @@ struct GroupedGemmConfig { JAXX_Scaling_Mode scaling_mode; bool has_bias; bool use_async_d2h_group_sizes; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -209,14 +213,18 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), - ::xla::ffi::StructMember("scaling_mode")); + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode"), ::xla::ffi::StructMember("has_bias"), - ::xla::ffi::StructMember("use_async_d2h_group_sizes")); + ::xla::ffi::StructMember("use_async_d2h_group_sizes"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2d73390d33..50bf43f349 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -627,13 +627,15 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream) { + size_t num_gemms, cudaStream_t stream, + int64_t axis_boundary = -1) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. - // Input buffers (lhs, rhs) are already 2D from the Python side. Output buffers may be ND - // (e.g. [G, K, N] for wgrad), so we collapse dims[0..N-2] → rows and keep dims[N-1] → cols. - NVTEShape dataShape{.data = {product(dims, 0, dims.size() - 1), dims[dims.size() - 1]}, + // Flatten dims at axis_boundary to produce a 2D NVTE shape. + // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, + // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). + size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); @@ -698,7 +700,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, GroupedGemmV2Config config) { - auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; + auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -732,9 +734,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + int64_capacity, int64_offset, num_gemms, stream, + rhs_axis_boundary); auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + int64_capacity, int64_offset, num_gemms, stream, + lhs_axis_boundary); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); @@ -777,7 +781,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type out_first_dims, Buffer_Type out_last_dims, Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { - auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, + lhs_axis_boundary, rhs_axis_boundary] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -827,20 +832,26 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; - // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. - NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); - NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); - size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + // Derive m, n, k from N-D buffer dimensions using axis_boundary. + // axis_boundary splits contracting dims from non-contracting dims. + auto lhs_dims = lhs_data.dimensions(); + auto rhs_dims = rhs_data.dimensions(); + NVTE_CHECK(lhs_dims.size() >= 2, "lhs_data must be at least 2D."); + NVTE_CHECK(rhs_dims.size() >= 2, "rhs_data must be at least 2D."); + size_t lab = static_cast(lhs_axis_boundary); + size_t rab = static_cast(rhs_axis_boundary); + // k = product of contracting dims of lhs + size_t k = lhs_is_trans ? product(lhs_dims, 0, lab) : product(lhs_dims, lab, lhs_dims.size()); size_t m, n; if (is_rhs_ragged) { - // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - // T-layout rhs: (N, K_total) -> n = dim[0]; N-layout rhs: (K_total, N) -> n = dim[1] - n = rhs_is_trans ? rhs_data.dimensions()[0] : rhs_data.dimensions()[1]; + // wgrad: non-contracting lhs dims form M; non-contracting rhs dims form N + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) : product(lhs_dims, 0, lab); + n = rhs_is_trans ? product(rhs_dims, 0, rab) : product(rhs_dims, rab, rhs_dims.size()); } else { - m = lhs_is_trans ? lhs_data.dimensions()[1] - : lhs_data.dimensions()[0]; // total M (sum of group sizes) - n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) + : product(lhs_dims, 0, lab); // total M (sum of group sizes) + n = rhs_is_trans ? product(rhs_dims, 0, rab) / num_gemms + : product(rhs_dims, rab, rhs_dims.size()); } // Inputs diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8b397520f2..76c984486f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -27,6 +27,7 @@ is_fp8_gemm_with_all_layouts_supported, TensorUsage, QuantizeLayout, + GroupedNoScaleTensor, ) @@ -490,7 +491,8 @@ def _grouped_dense_fwd_rule( is_colwise=False, data_layout="N", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ -507,7 +509,8 @@ def _grouped_dense_fwd_rule( is_colwise=True, data_layout="T", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ -518,16 +521,24 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout - empty_gs = jnp.empty((0,), jnp.int32) + if is_noop_quantizer_set: + grouped_gemm_x = GroupedNoScaleTensor( + data=grouped_gemm_x, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_x.shape, + ) + grouped_gemm_kernel = GroupedNoScaleTensor( + data=grouped_gemm_kernel, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_kernel.shape, + ) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, contracting_dims=contracting_dims, bias=bias, precision=precision, @@ -616,16 +627,38 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) - empty_gs = jnp.empty((0,), jnp.int32) + if is_noop_quantizer_set: + dgrad_grad = GroupedNoScaleTensor( + data=dgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=dgrad_grad.shape, + ) + dgrad_kernel_T = GroupedNoScaleTensor( + data=dgrad_kernel_T, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=dgrad_kernel_T.shape, + ) + wgrad_x_T = GroupedNoScaleTensor( + data=wgrad_x_T, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_x_T.shape, + ) + wgrad_grad = GroupedNoScaleTensor( + data=wgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_grad.shape, + ) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, @@ -635,12 +668,6 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=group_sizes, - rhs_last_dims=empty_gs, - out_first_dims=empty_gs, - out_last_dims=empty_gs, contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 74787b9308..8fd54a3a63 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -275,7 +275,16 @@ def _grouped_dequantize(grouped_scaled_tensor): """ data = grouped_scaled_tensor.data scale_inv = grouped_scaled_tensor.scale_inv - group_sizes = grouped_scaled_tensor.group_sizes + group_sizes = ( + grouped_scaled_tensor.first_dims + if grouped_scaled_tensor.first_dims is not None and grouped_scaled_tensor.first_dims.size > 0 + else grouped_scaled_tensor.last_dims + ) + # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape + if group_sizes is None: + group_sizes = jnp.ones( + grouped_scaled_tensor.original_shape[grouped_scaled_tensor.group_axis], dtype=jnp.int32 + ) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..55dd7f5618 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -948,7 +948,7 @@ def _create_grouped_tensor_from_tensor_list( is_colwise=tensor_list[0].is_colwise, data_layout=tensor_list[0].data_layout, flatten_axis=tensor_list[0].flatten_axis, - group_sizes=group_sizes, + first_dims=group_sizes, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c26cb8a531..38433e95ae 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -9,7 +9,7 @@ rowwise and colwise quantization modes with proper scaling and dequantization. """ from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple from abc import ABC, abstractmethod import jax.numpy as jnp @@ -32,6 +32,7 @@ "ScaledTensor1x", "ScaledTensor2x", "GroupedScaledTensor1x", + "GroupedNoScaleTensor", "ScaledTensorFactory", "with_sharding_constraint_by_logical_axes", ] @@ -365,12 +366,14 @@ class GroupedScaledTensor1x(ScaledTensor1x): where elements are grouped along a specified axis. Attributes: - group_sizes: Array containing the size of each group + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping group_axis: The axis along which grouping is performed (default: 0) """ - group_sizes: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] original_shape: Tuple group_axis: int @@ -379,7 +382,7 @@ def __init__( data, scale_inv, amax, - group_sizes, + first_dims, scaling_mode, dq_dtype, _dq_func, @@ -388,9 +391,11 @@ def __init__( flatten_axis, original_shape, group_axis=0, + last_dims=None, ): self.flatten_axis = flatten_axis - self.group_sizes = group_sizes + self.first_dims = first_dims + self.last_dims = last_dims self.original_shape = original_shape self.group_axis = group_axis # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 @@ -422,9 +427,19 @@ def __post_init__(self): 0 <= self.group_axis < data_ndim ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" + active_dims = ( + self.first_dims + if self.first_dims is not None and self.first_dims.size > 0 + else self.last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + num_groups = self.original_shape[self.group_axis] + expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, - self.group_sizes.size, + num_groups, self.group_axis, self.is_colwise, is_padded=True, @@ -442,7 +457,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.first_dims, self.last_dims) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -455,6 +470,36 @@ def tree_flatten(self): ) return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + data, scale_inv, amax, first_dims, last_dims = children + ( + scaling_mode, + dq_dtype, + _dq_func, + is_colwise, + data_layout, + flatten_axis, + original_shape, + group_axis, + ) = aux_data + return cls( + data=data, + scale_inv=scale_inv, + amax=amax, + first_dims=first_dims, + last_dims=last_dims, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, + original_shape=original_shape, + group_axis=group_axis, + ) + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError @@ -473,6 +518,52 @@ def checkpoint(self, quantizer): return jax_checkpoint_name(self, name=quantizer.checkpoint_name) +@register_pytree_node_class +@dataclass +class GroupedNoScaleTensor: + """Unquantized grouped tensor. + + Stores N-D data with per-group dimension sizes so that grouped_gemm() + can extract first/last dims automatically without explicit parameters. + + Attributes: + data: The raw (unquantized) tensor data in N-D layout + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged + group_axis: Which axis of original_shape is the group batch prefix + original_shape: Shape of data (same as data.shape for N-D unquantized) + """ + + data: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] + group_axis: int + original_shape: Tuple + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations.""" + children = (self.data, self.first_dims, self.last_dims) + aux_data = (self.group_axis, self.original_shape) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + group_axis, original_shape = aux_data + data, first_dims, last_dims = children + return cls( + data=data, + first_dims=first_dims, + last_dims=last_dims, + group_axis=group_axis, + original_shape=original_shape, + ) + + def dequantize(self): + """No-op dequantization — returns the raw data.""" + return self.data + + @register_pytree_node_class @dataclass class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): @@ -570,7 +661,8 @@ def create_1x( is_colwise=False, data_layout="N", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, has_rht_applied=False, @@ -586,29 +678,40 @@ def create_1x( is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Array of ints containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: - A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided + A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether first_dims or last_dims is provided """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if group_sizes is not None: - flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + if first_dims is not None or last_dims is not None or ( + original_shape is not None and group_axis is not None + ): assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + + # Determine num_groups from whichever dims array is provided, or from original_shape + active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + if active_dims is not None: + num_groups = active_dims.size + else: + norm_group_axis = (len(original_shape) + group_axis) % len(original_shape) + num_groups = original_shape[norm_group_axis] # Handling attrs of transposed tensors group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": - if original_shape[0] == group_sizes.size: + if original_shape[0] == num_groups: original_shape = ( original_shape[0], *original_shape[flatten_axis:], @@ -633,7 +736,8 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, ) @@ -668,7 +772,8 @@ def create_2x( dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, rowwise_has_rht_applied=False, @@ -686,7 +791,8 @@ def create_2x( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") flatten_axis: The quantization axis for the tensor - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -710,7 +816,8 @@ def create_2x( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, @@ -724,7 +831,8 @@ def create_2x( is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ -744,7 +852,8 @@ def create( data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, flatten_axis: int = -1, - group_sizes: jnp.ndarray = None, + first_dims: jnp.ndarray = None, + last_dims: jnp.ndarray = None, original_shape: Tuple[int] = None, group_axis: int = 0, rowwise_has_rht_applied: bool = False, @@ -762,7 +871,8 @@ def create( data_layout: The data_layout specification (default: "NN") q_layout: The quantization axis (default: ROWWISE) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -785,7 +895,8 @@ def create( dq_dtype, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, rowwise_has_rht_applied=rowwise_has_rht_applied, @@ -802,7 +913,8 @@ def create( is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ -817,7 +929,8 @@ def create( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, From 1d1fec90cfd3e05e195010512652a8c78e72c19e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 13 Mar 2026 08:42:34 -0700 Subject: [PATCH 14/60] MXFP8 grouped quantize V2 Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 46 ++++++++++++++++--- transformer_engine/jax/csrc/extensions.h | 2 + .../jax/csrc/extensions/pybind.cpp | 1 + 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c8578d48b8..f7c0e796cc 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,8 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" + name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( 3, @@ -1050,7 +1051,17 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # V2 path: 5th output is int64_workspace (n_groups * sizeof(int64_t) bytes as uint8) + fifth_out_aval = jax.core.ShapedArray( + shape=(group_sizes_aval.size * 8,), dtype=jnp.uint8 + ) + else: + # V1 path: 5th output is amax + fifth_out_aval = jax.core.ShapedArray( + shape=(group_sizes_aval.size,), dtype=jnp.float32 + ) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1070,7 +1081,7 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - amax_aval, + fifth_out_aval, ) @staticmethod @@ -1084,9 +1095,15 @@ def outer_abstract(*args, **kwargs): colwise_out, scale_inv, colwise_scale_inv, - updated_amax, + fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + # For MXFP8, the inner abstract returns int64_workspace as the 5th output. + # The outer interface always presents amax (float32, n_groups) for a consistent API. + scaling_mode = kwargs.get("scaling_mode") + group_sizes_aval = args[2] + if ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING: + fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @staticmethod def lowering( @@ -1111,6 +1128,17 @@ def lowering( assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 assert group_axis == 0 + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + ctx, + x, + scale, + group_sizes, + q_layout=q_layout.value.value, + flatten_axis=flatten_axis, + ) return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1142,7 +1170,7 @@ def impl( colwise_out, rowwise_scale_inv, colwise_scale_inv, - updated_amax, + fifth, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1154,6 +1182,12 @@ def impl( group_axis=group_axis, scale_dtype=scale_dtype, ) + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # fifth is int64_workspace; return a dummy zero amax for interface compatibility + updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) + else: + updated_amax = fifth return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 616209709b..c832b4ebb2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -111,6 +111,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeV2Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..e3bc122403 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -33,6 +33,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); + dict["te_grouped_quantize_v2_ffi"] = EncapsulateFFI(GroupedQuantizeV2Handler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax From 269a5186715b3f63aa21df7912080da3ca0284d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:46:34 +0000 Subject: [PATCH 15/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 11 ++++---- .../jax/cpp_extensions/quantization.py | 6 ++--- .../jax/csrc/extensions/gemm.cpp | 26 ++++++++----------- .../jax/quantize/dequantizer.py | 3 ++- transformer_engine/jax/quantize/tensor.py | 10 ++++--- 5 files changed, 27 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ff9194bdd9..c86cb1db55 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1332,7 +1332,6 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) - def _assert_grouped_gemm_dims_shapes( lhs_first_dims_aval, lhs_last_dims_aval, @@ -2036,8 +2035,7 @@ def grouped_gemm( rhs_group_axis = getattr(rhs, "group_axis", 0) else: raise TypeError( - "lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(lhs)}" + f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" ) if isinstance(rhs, GroupedNoScaleTensor): @@ -2061,8 +2059,7 @@ def grouped_gemm( scaling_mode = lhs.scaling_mode else: raise TypeError( - "rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(rhs)}" + f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" ) # Infer output dims from which operand has the ragged non-contracting dim. @@ -2130,7 +2127,9 @@ def grouped_gemm( ) lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data - lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) rhs_q = grouped_quantize( rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f7c0e796cc..cb506160bf 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,7 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name = "te_grouped_quantize_ffi" # V1: non-MXFP8 name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( @@ -1059,9 +1059,7 @@ def abstract( ) else: # V1 path: 5th output is amax - fifth_out_aval = jax.core.ShapedArray( - shape=(group_sizes_aval.size,), dtype=jnp.float32 - ) + fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) if q_layout.has_colwise: colwise_out_shape = out_shape diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 50bf43f349..07adf55577 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -622,21 +622,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked // before each slot is used. Only NO_SCALING is supported. -JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, - Buffer_Type const &first_dims, - Buffer_Type const &last_dims, - int64_t *int64_workspace_base, - size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, - int64_t axis_boundary = -1) { +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); // Flatten dims at axis_boundary to produce a 2D NVTE shape. // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, - .ndim = 2}; + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -733,12 +729,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - rhs_axis_boundary); - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - lhs_axis_boundary); + auto rhs_tensor = + make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 8fd54a3a63..5075f1a664 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -277,7 +277,8 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv = grouped_scaled_tensor.scale_inv group_sizes = ( grouped_scaled_tensor.first_dims - if grouped_scaled_tensor.first_dims is not None and grouped_scaled_tensor.first_dims.size > 0 + if grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 else grouped_scaled_tensor.last_dims ) # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 38433e95ae..316e4f3139 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -692,8 +692,10 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if first_dims is not None or last_dims is not None or ( - original_shape is not None and group_axis is not None + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) ): assert ( original_shape is not None @@ -701,7 +703,9 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) if active_dims is not None: num_groups = active_dims.size else: From 2b84dfd0bc0281f826547f51f559c8625f7cfb64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:47:20 +0000 Subject: [PATCH 16/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 11 ++++---- .../jax/csrc/extensions/gemm.cpp | 26 ++++++++----------- .../jax/quantize/dequantizer.py | 3 ++- transformer_engine/jax/quantize/tensor.py | 10 ++++--- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ff9194bdd9..c86cb1db55 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1332,7 +1332,6 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) - def _assert_grouped_gemm_dims_shapes( lhs_first_dims_aval, lhs_last_dims_aval, @@ -2036,8 +2035,7 @@ def grouped_gemm( rhs_group_axis = getattr(rhs, "group_axis", 0) else: raise TypeError( - "lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(lhs)}" + f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" ) if isinstance(rhs, GroupedNoScaleTensor): @@ -2061,8 +2059,7 @@ def grouped_gemm( scaling_mode = lhs.scaling_mode else: raise TypeError( - "rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(rhs)}" + f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" ) # Infer output dims from which operand has the ragged non-contracting dim. @@ -2130,7 +2127,9 @@ def grouped_gemm( ) lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data - lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) rhs_q = grouped_quantize( rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 50bf43f349..07adf55577 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -622,21 +622,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked // before each slot is used. Only NO_SCALING is supported. -JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, - Buffer_Type const &first_dims, - Buffer_Type const &last_dims, - int64_t *int64_workspace_base, - size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, - int64_t axis_boundary = -1) { +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); // Flatten dims at axis_boundary to produce a 2D NVTE shape. // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, - .ndim = 2}; + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -733,12 +729,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - rhs_axis_boundary); - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - lhs_axis_boundary); + auto rhs_tensor = + make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 8fd54a3a63..5075f1a664 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -277,7 +277,8 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv = grouped_scaled_tensor.scale_inv group_sizes = ( grouped_scaled_tensor.first_dims - if grouped_scaled_tensor.first_dims is not None and grouped_scaled_tensor.first_dims.size > 0 + if grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 else grouped_scaled_tensor.last_dims ) # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 38433e95ae..316e4f3139 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -692,8 +692,10 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if first_dims is not None or last_dims is not None or ( - original_shape is not None and group_axis is not None + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) ): assert ( original_shape is not None @@ -701,7 +703,9 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) if active_dims is not None: num_groups = active_dims.size else: From b2b3216fcb5f188b01bc36ca89e9c4be381f6093 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 13 Mar 2026 19:02:35 -0700 Subject: [PATCH 17/60] MXFP8 quantization working Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 148 +++++++++++++++- .../cast/mxfp8/group_quantize_mxfp8.cuh | 4 + .../common/gemm/cublaslt_grouped_gemm.cu | 39 +++++ .../common/include/transformer_engine/gemm.h | 30 ++++ .../jax/cpp_extensions/quantization.py | 65 +++++-- .../jax/csrc/extensions/gemm.cpp | 16 +- .../jax/csrc/extensions/quantization.cpp | 164 ++++++++++++++++++ transformer_engine/jax/flax/module.py | 9 +- transformer_engine/jax/quantize/tensor.py | 22 ++- 9 files changed, 464 insertions(+), 33 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9fddbc435c..db44621d80 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1737,7 +1737,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): ref_out.append(jnp.squeeze(out_i)) return ref_out - def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): + def _generate_grouped_dense_input( + self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32 + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape @@ -1750,9 +1752,12 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # Scale group sizes by the multiplier. + # Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count + # is divisible by 128, satisfying the V2 kernel's per-group alignment requirement. + # Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests. + group_sizes = group_sizes * group_size_multiplier + m = m * group_size_multiplier lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1826,8 +1831,10 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout quantizer.q_dtype = bwd_dtype out_dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout + out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32 ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) @@ -1901,10 +1908,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, with_bias=True, + group_size_multiplier=128 if is_mxfp8 else 32, ) quantizer_set = QuantizerFactory.create_set( @@ -1938,6 +1948,134 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dbias, ref_dbias, dtype=dtype) +# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are +# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel. +GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 32; rhs total_rows = n_groups * k + (5, 6, 128, 64), # lhs: 6*32=192 (not 128-aligned); rhs: 5*64=320 (not 128-aligned) +] + +# MXFP8 V2 shapes: lhs total_rows = m * 128 and rhs total_rows = n_groups * k are +# divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used. +# These shapes must be paired with group_size_multiplier=128 so that each group's +# row count is also divisible by 128 (the V2 per-group alignment requirement). +GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 128; rhs total_rows = n_groups * k + (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) + (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) +] + + +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) +class TestGroupedDenseMXFP8KernelSelection: + """Tests that explicitly verify V1 and V2 MXFP8 grouped quantize kernel selection. + + V2 is the CUDA-graph-safe kernel and requires: + - total_first_dim (= product of input shape up to flatten_axis) % 128 == 0 + - each individual group_size % 128 == 0 (enforced by the kernel at runtime) + V1 is the fallback that supports arbitrary shapes but performs a D2H copy of + group_sizes (not CUDA-graph safe). + """ + + def _generate_mxfp8_input(self, input_shape, group_size_multiplier): + """Generate inputs with the given group_size_multiplier for MXFP8 tests.""" + key = jax.random.PRNGKey(42) + subkeys = jax.random.split(key, 3) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) + group_sizes = group_sizes * group_size_multiplier + m_total = m * group_size_multiplier + + lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) + rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) + return lhs, rhs, group_sizes + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES, + ids=[f"v1_{s}" for s in GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape): + """MXFP8 grouped GEMM with V1-only shapes (total_first_dim not 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=32) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + # Reference: unquantized grouped GEMM + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + # Check output has correct shape and dtype; numerical precision is expected to be lower + # due to FP8 quantization but the result should be finite. + assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): + """MXFP8 grouped GEMM with V2-eligible shapes (total_first_dim 128-aligned, + group_sizes also 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + + class TestDebugInspectFFI: @pytest_parametrize_wrapper("shape", [(256, 128)]) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..d7eaf028e0 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -192,6 +192,10 @@ __global__ void update_tma_descriptors( const size_t offset_elts = offsets_ptr[tensor_id]; if (leading_thread && (tensor_id < num_tensors)) { + // Zero-sized groups: skip TMA descriptor update. The main kernel already returns + // early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension + // is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS. + if (rows == 0 || cols == 0) return; { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index dc4757ab90..01f5361481 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -682,6 +682,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, if (idx < n) dst[idx] = static_cast(src[idx]); } +// Like convert_int32_to_int64_kernel but scales each element by multiplier. +// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. +__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, + size_t n, int64_t multiplier) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; +} + +// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim). +// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small. +__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim) { + offsets[0] = 0; + for (size_t i = 0; i < n_groups; i++) { + offsets[i + 1] = offsets[i] + first_dims[i] * last_dim; + } +} + } // namespace void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { @@ -692,3 +710,24 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud convert_int32_to_int64_kernel<<>>(src, dst, n); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, + multiplier); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); + // Always write at least offsets[0]=0 (needed even for n_groups==0). + compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, + last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0f3b0ebd6b..28490653bc 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -354,6 +354,36 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); */ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); +/*! \brief Convert int32 array to int64 while scaling each element by a multiplier. + * + * Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n). + * CUDA-graph safe (no host-device synchronization). + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] multiplier Scale factor applied to each element. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream); + +/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. + * + * Writes n_groups+1 values to offsets: offsets[0]=0, + * offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups]. + * This is CUDA-graph safe (no host-device synchronization). + * + * \param[in] first_dims Device pointer to int64 array of length n_groups. + * \param[out] offsets Device pointer to int64 array of length n_groups+1. + * \param[in] n_groups Number of groups. + * \param[in] last_dim Common last dimension (number of columns). + * \param[in] stream CUDA stream. + */ +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, + cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index cb506160bf..02639cfa2d 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,7 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name = "te_grouped_quantize_ffi" # V1: fallback path (supports all shapes, not CUDA-graph safe) name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( @@ -1007,6 +1007,38 @@ class GroupedQuantizePrimitive(BasePrimitive): inner_primitive = None outer_primitive = None + @staticmethod + def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): + """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. + + V2 requires: + 1. The total first logical dimension (product of x_shape up to flatten_axis) + is divisible by 128. + 2. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + per-group row count non_group_m = prod(x_shape[1:eff]) must also be + divisible by 128 (because group_sizes[i] counts slices, not rows, and + actual rows per group = group_sizes[i] * non_group_m). + 3. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned -- this is a dynamic constraint assumed by the caller. + + Falls back to V1 when constraints are not met. V1 supports arbitrary shapes + but performs a D2H copy of group_sizes (not CUDA-graph safe). + """ + if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + return False + ndim = len(x_shape) + eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim + total_first_dim = math.prod(x_shape[:eff]) + if total_first_dim % 128 != 0: + return False + # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), + # non_group_m = K must also be 128-aligned. + if eff > 1: + non_group_m = math.prod(x_shape[1:eff]) + if non_group_m % 128 != 0: + return False + return True + @staticmethod def abstract( x_aval, @@ -1051,11 +1083,14 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: - # V2 path: 5th output is int64_workspace (n_groups * sizeof(int64_t) bytes as uint8) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2 path: 5th output is int64_workspace laid out as: + # [n_groups int64 group_sizes | n_groups+1 int64 offsets] + # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. + n_groups = group_sizes_aval.size fifth_out_aval = jax.core.ShapedArray( - shape=(group_sizes_aval.size * 8,), dtype=jnp.uint8 + shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 ) else: # V1 path: 5th output is amax @@ -1095,11 +1130,13 @@ def outer_abstract(*args, **kwargs): colwise_scale_inv, fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - # For MXFP8, the inner abstract returns int64_workspace as the 5th output. + # When V2 is used, the inner abstract returns int64_workspace as the 5th output. # The outer interface always presents amax (float32, n_groups) for a consistent API. scaling_mode = kwargs.get("scaling_mode") + x_aval = args[0] group_sizes_aval = args[2] - if ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING: + flatten_axis = kwargs.get("flatten_axis") + if GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis): fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @@ -1126,9 +1163,11 @@ def lowering( assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 assert group_axis == 0 - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: - # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. + # Requires total_first_dim % 128 == 0 (checked above) and all individual + # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( ctx, x, @@ -1137,6 +1176,8 @@ def lowering( q_layout=q_layout.value.value, flatten_axis=flatten_axis, ) + # V1: supports arbitrary shapes but not CUDA-graph safe (performs D2H copy of group_sizes). + # Used for non-MXFP8 scaling modes and for MXFP8 when total_first_dim % 128 != 0. return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1180,8 +1221,8 @@ def impl( group_axis=group_axis, scale_dtype=scale_dtype, ) - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x.shape, flatten_axis) + if use_v2: # fifth is int64_workspace; return a dummy zero amax for interface compatibility updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) else: diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 07adf55577..495d4cc4bb 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -965,14 +965,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // Note: This may break cudaGraph. cudaStreamSynchronize(stream); } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_rhs_ragged) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } + // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // if (!is_rhs_ragged) { + // NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + // ", got sum(group_sizes)=", sum_group_sizes); + // } else { + // NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + // ", got sum(group_sizes)=", sum_group_sizes); + // } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c5a766f7f2..fb083a958f 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -9,6 +9,7 @@ #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/gemm.h" #include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" @@ -494,5 +495,168 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, + Buffer_Type scale_unused, Buffer_Type group_sizes, + Result_Type rowwise_out, Result_Type colwise_out, + Result_Type rowwise_sinv, Result_Type colwise_sinv, + Result_Type int64_workspace, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { + (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(rowwise_sinv->element_type()); + + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for GroupedQuantizeV2."); + NVTE_CHECK(sinv_dtype == DType::kFloat8E8M0, + "scale_inv must be E8M0 for MXFP8 grouped quantize."); + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + size_t n_groups = group_sizes.dimensions()[0]; + + // Workspace layout (CUDA-graph safe, all device-side): + // int64_ptr[0 .. n_groups-1] : per-group ROW counts (int64) + // int64_ptr[n_groups .. 2*n_groups] : exclusive prefix-sum offsets (n_groups+1 values) + auto *int64_ptr = reinterpret_cast(int64_workspace->untyped_data()); + auto *offsets_ptr_out = int64_ptr + n_groups; // n_groups+1 values follow group_sizes + + // non_group_m handles multi-dim tensors (e.g., kernel shape G×K×N with flatten_axis=2): + // group_sizes[i] counts "slices" along the outermost group axis (e.g., 1 per expert), + // while the kernel expects actual ROW counts (e.g., K rows per expert). + // non_group_m = product(input_dims[1..flatten_axis)) converts slice→row count. + // For the lhs case (shape M×K, flatten_axis=1), non_group_m=1 (no-op). + int64_t non_group_m = + (flatten_axis > 1) ? product(input_dims, 1, static_cast(flatten_axis)) : 1; + + // Convert int32 group_sizes to int64 row counts on device (CUDA-graph safe, no D2H). + nvte_convert_int32_to_int64_with_multiplier( + reinterpret_cast(group_sizes.untyped_data()), int64_ptr, n_groups, + non_group_m, stream); + + // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, + static_cast(n), stream); + + NVTEShape data_shape{}; + data_shape.data[0] = m; + data_shape.data[1] = n; + data_shape.ndim = 2; + + NVTEShape sz_shape{}; + sz_shape.ndim = 1; + sz_shape.data[0] = n_groups; + + // Offsets tensor has n_groups+1 elements (exclusive prefix sums with sentinel). + NVTEShape offsets_shape{}; + offsets_shape.ndim = 1; + offsets_shape.data[0] = n_groups + 1; + + // Build input grouped tensor (plain float data, no quantization on the input side). + NVTEGroupedTensor in_grouped = + nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), + n_groups, data_shape); + { + NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), + static_cast(in_dtype), data_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedRowwiseData, &in_data, sizeof(in_data)); + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedFirstDims, &sz_tensor, sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Build output grouped tensor. + NVTEGroupedTensor out_grouped = + nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), + n_groups, data_shape); + + // Set group sizes and offsets on output tensor (same device pointers). + { + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedFirstDims, &sz_tensor, + sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Rowwise output data + scale_inv. + if (is_quantize_rowwise(quantize_layout)) { + NVTEBasicTensor rw_data{reinterpret_cast(rowwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseData, &rw_data, sizeof(rw_data)); + + auto sinv_dims = rowwise_sinv->dimensions(); + NVTEShape rw_sinv_shape{}; + rw_sinv_shape.ndim = 2; + rw_sinv_shape.data[0] = product(sinv_dims, 0, sinv_dims.size() - 1); + rw_sinv_shape.data[1] = sinv_dims.back(); + NVTEBasicTensor rw_sinv{reinterpret_cast(rowwise_sinv->untyped_data()), + static_cast(sinv_dtype), rw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseScaleInv, &rw_sinv, + sizeof(rw_sinv)); + } + + // Colwise output data + scale_inv. + if (is_quantize_colwise(quantize_layout)) { + NVTEBasicTensor cw_data{reinterpret_cast(colwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseData, &cw_data, + sizeof(cw_data)); + + auto cw_sinv_dims = colwise_sinv->dimensions(); + NVTEShape cw_sinv_shape{}; + cw_sinv_shape.ndim = 2; + cw_sinv_shape.data[0] = product(cw_sinv_dims, 0, cw_sinv_dims.size() - 1); + cw_sinv_shape.data[1] = cw_sinv_dims.back(); + NVTEBasicTensor cw_sinv{reinterpret_cast(colwise_sinv->untyped_data()), + static_cast(sinv_dtype), cw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseScaleInv, &cw_sinv, + sizeof(cw_sinv)); + } + + // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). + size_t total_rowwise_sinv_size = + is_quantize_rowwise(quantize_layout) ? product(rowwise_sinv->dimensions()) : 0; + size_t total_colwise_sinv_size = + is_quantize_colwise(quantize_layout) ? product(colwise_sinv->dimensions()) : 0; + if (total_rowwise_sinv_size > 0) + nvte_memset(rowwise_sinv->untyped_data(), 0, total_rowwise_sinv_size, stream); + if (total_colwise_sinv_size > 0) + nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + + nvte_group_quantize(in_grouped, out_grouped, stream); + + nvte_destroy_grouped_tensor(in_grouped); + nvte_destroy_grouped_tensor(out_grouped); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // inputs + .Arg() // scale (unused, for input arity match) + .Arg() // group_sizes (int32) + .Ret() // rowwise_out + .Ret() // colwise_out + .Ret() // rowwise_sinv + .Ret() // colwise_sinv + .Ret() // int64_workspace + .Attr("q_layout") + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 31ce6e72e9..4f19f449bc 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,6 +16,9 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, +) from ..dense import dense, grouped_dense @@ -1446,7 +1449,11 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_grouped_dense_cls(quantization_recipe): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: - raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + allowed_grouped_gemm_recipes = [MXFP8BlockScaling] + assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( + f"Only the following quantization recipes are supported for grouped GEMM or `None` for BF16 without quantization: {allowed_grouped_gemm_recipes}. " + f"Got {type(quantization_recipe)}." + ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 316e4f3139..4511b63a64 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -412,6 +412,18 @@ def __init__( has_rht_applied=False, ) + @property + def group_sizes(self) -> jnp.ndarray: + """Per-group sizes along the group axis. + + When first_dims is set (ragged groups), returns first_dims. + When first_dims is None (equal-sized groups), returns an array of ones with + length equal to the number of groups. + """ + if self.first_dims is not None and self.first_dims.size > 0: + return self.first_dims + return jnp.ones((self.original_shape[self.group_axis],), dtype=jnp.int32) + def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" @@ -692,10 +704,8 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if ( - first_dims is not None - or last_dims is not None - or (original_shape is not None and group_axis is not None) + if first_dims is not None or last_dims is not None or ( + original_shape is not None and group_axis is not None ): assert ( original_shape is not None @@ -703,9 +713,7 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = ( - first_dims if first_dims is not None and first_dims.size > 0 else last_dims - ) + active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims if active_dims is not None: num_groups = active_dims.size else: From 611526ff545aadb532c714c0f713d0ab6a9ccd6d Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 00:00:02 -0700 Subject: [PATCH 18/60] mxfp8 grouped gemm Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 52 +++++++ transformer_engine/jax/cpp_extensions/gemm.py | 107 ++++++++++++-- .../jax/csrc/extensions/gemm.cpp | 132 ++++++++++++++++-- 3 files changed, 273 insertions(+), 18 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index db44621d80..1f1286a399 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2074,6 +2074,58 @@ def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): ) assert prim_out.shape == ref_out.shape assert jnp.all(jnp.isfinite(prim_out)) + # Numerical check within FP8 tolerance + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_grad_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_dense_grad_mxfp8_v2(self, input_shape): + """MXFP8 V2 grouped GEMM gradient test (fwd + dgrad + wgrad).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + n_groups = input_shape[0] + fwd_dtype = jnp.float8_e4m3fn + bwd_dtype = jnp.float8_e4m3fn + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=n_groups, + ) + + contracting_dims = ((1,), (1,)) + + def _ref_sum(x, kernel, group_sizes): + lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(kernel, n_groups, axis=0) + out = jnp.concatenate( + [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 + ) + return jnp.sum(out) / jnp.sqrt(x.size) + + def _prim_sum(x, kernel, group_sizes): + out = grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias=None, + quantizer_set=quantizer_set, + ) + return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) + + ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit( + value_and_grad(_prim_sum, (0, 1)), static_argnums=() + )(lhs, rhs, group_sizes) + + assert_allclose(prim_val, ref_val, dtype=fwd_dtype) + assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) + assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype) class TestDebugInspectFFI: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1f69535b9b..e4fd9c99a6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -379,6 +379,25 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) +def _swizzle_grouped_scale(scale_inv, scale_2d_shape, is_colwise): + """Swizzle a 1D grouped scale_inv buffer using full-tensor swizzle. + + The grouped scale_inv is 1D (worst-case padded). The meaningful prefix has size + equal to prod(scale_2d_shape). We reshape that prefix to 2D, swizzle it, and + write it back, leaving any trailing padding untouched. + """ + useful_size = math.prod(scale_2d_shape) + if useful_size == scale_inv.shape[0]: + # No trailing padding — reshape, swizzle, flatten. + return swizzled_scale(scale_inv.reshape(scale_2d_shape), 1, is_colwise).reshape( + scale_inv.shape + ) + # Split meaningful prefix from trailing padding, swizzle prefix only. + prefix = scale_inv[:useful_size].reshape(scale_2d_shape) + swizzled = swizzled_scale(prefix, 1, is_colwise).reshape((useful_size,)) + return jnp.concatenate([swizzled, scale_inv[useful_size:]]) + + def get_lhs_axis_boundary(lhs_cdims, is_transposed): """Get the axis boundary for the LHS operand.""" return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) @@ -1626,9 +1645,11 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # We also pad scale_inv swizzle buffers size for 256 bytes alignment. - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + if not use_v2_ffi: + # V1 needs workspace for per-group swizzle output buffers. + # V2: scales are pre-swizzled in JAX, no extra workspace needed. + workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding return workspace_size @staticmethod @@ -2024,11 +2045,14 @@ def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, ) -> bool: """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. + # Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). # Bias can be supported in a kernel or in pure-JAX in the future. if not _v2_grouped_gemm_available: @@ -2039,7 +2063,42 @@ def _can_use_v2_grouped_gemm( if get_device_compute_capability(0) < 100: return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + if has_bias: + return False + + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: + return True + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # V2 MXFP8 requires that the total first dimension of both operands (up to + # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. + # Individual group sizes must also be 128-aligned (dynamic constraint). + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) + if lhs_first_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) + if rhs_first_dim % 128 != 0: + return False + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both + # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group + # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. + # The quantize kernel, however, pads the colwise scale stride to + # ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than + # ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes + # adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which + # swizzles per-group scales individually) when the condition is not met. + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) + if lhs_last_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) + if rhs_last_dim % 128 != 0: + return False + return True + + return False def grouped_gemm( @@ -2278,7 +2337,39 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi = _can_use_v2_grouped_gemm( + scaling_mode, lhs_data.dtype, has_bias, + lhs_shape=lhs_data.shape, rhs_shape=rhs_data.shape, + lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + ) + if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). + # Grouped scale_inv is 1D (flat, worst-case padded). When all group sizes are + # multiples of 128 (V2 requirement), the per-group scales are contiguous with no + # inter-group padding gaps. We reshape the meaningful prefix to 2D, swizzle, and + # write it back into the original 1D buffer (extra trailing zeros stay untouched). + lhs_is_colwise = lhs_is_trans + rhs_is_colwise = not rhs_is_trans + lhs_scale_shape = scaling_mode.get_scale_shape( + lhs_data.shape, is_colwise=lhs_is_colwise, is_padded=True, + flatten_axis=lhs_axis_boundary, + ) + rhs_scale_shape = scaling_mode.get_scale_shape( + rhs_data.shape, is_colwise=rhs_is_colwise, is_padded=True, + flatten_axis=rhs_axis_boundary, + ) + # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D + # input), but _swizzle_grouped_scale needs a flat 2D shape (rows, cols) where + # cols = n_block_y (last dim) and rows = prod(all other dims). This correctly + # flattens the group/K-block axes into a single row dimension so the swizzle + # pattern operates on the full (K-blocks-across-groups × N-blocks) matrix. + lhs_n_block_y = lhs_scale_shape[-1] + rhs_n_block_y = rhs_scale_shape[-1] + lhs_scale_2d = (math.prod(lhs_scale_shape) // lhs_n_block_y, lhs_n_block_y) + rhs_scale_2d = (math.prod(rhs_scale_shape) // rhs_n_block_y, rhs_n_block_y) + lhs_scale_inv = _swizzle_grouped_scale(lhs_scale_inv, lhs_scale_2d, lhs_is_colwise) + rhs_scale_inv = _swizzle_grouped_scale(rhs_scale_inv, rhs_scale_2d, rhs_is_colwise) + if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index b26b261996..786fac5dcd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -481,6 +481,8 @@ class JAXX_GroupedTensorWrapper { m_grouped_tensor(other.m_grouped_tensor), m_data_tensor(other.m_data_tensor), m_scale_inv_tensor(other.m_scale_inv_tensor), + m_colwise_data_tensor(other.m_colwise_data_tensor), + m_colwise_scale_inv_tensor(other.m_colwise_scale_inv_tensor), m_sizes_tensor(other.m_sizes_tensor), m_offsets_tensor(other.m_offsets_tensor) { other.m_grouped_tensor = nullptr; @@ -489,6 +491,8 @@ class JAXX_GroupedTensorWrapper { ~JAXX_GroupedTensorWrapper(); void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_with_gemm_swizzled_scales(bool val); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -505,6 +509,8 @@ class JAXX_GroupedTensorWrapper { // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. NVTEBasicTensor m_data_tensor{}; NVTEBasicTensor m_scale_inv_tensor{}; + NVTEBasicTensor m_colwise_data_tensor{}; + NVTEBasicTensor m_colwise_scale_inv_tensor{}; NVTEBasicTensor m_sizes_tensor{}; NVTEBasicTensor m_offsets_tensor{}; @@ -556,6 +562,46 @@ void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, } } +void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_colwise_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseData, + &m_colwise_data_tensor, sizeof(m_colwise_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM columnwise scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_colwise_scale_inv_tensor = + NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, + logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, + sizeof(m_colwise_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { + auto v = static_cast(val); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedWithGEMMSwizzledScales, &v, + sizeof(v)); +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -619,11 +665,11 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// V2 variant (NO_SCALING): derives data shape from the XLA buffer directly, converts group_sizes // int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked -// before each slot is used. Only NO_SCALING is supported. +// before each slot is used. Only NO_SCALING is supported by this overload. JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, @@ -660,6 +706,56 @@ JAXX_GroupedTensorWrapper make_grouped_tensor( return wrapper; } +// V2 variant with scaling support (MXFP8 or NO_SCALING). Accepts scale_inv buffer and +// use_colwise flag to wire rowwise or columnwise data+scales for the grouped tensor. +// Pre-swizzled scales are indicated via set_with_gemm_swizzled_scales(true). +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, + bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + if (is_mxfp8 && use_colwise) { + wrapper.set_columnwise(data, scale_inv); + } else if (is_mxfp8) { + wrapper.set_rowwise(data, scale_inv); + } else { + // NO_SCALING: no scale_inv needed + wrapper.set_rowwise(data, std::nullopt); + } + if (is_mxfp8) { + wrapper.set_with_gemm_swizzled_scales(true); + } + + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} + // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, @@ -700,8 +796,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Only non-quantized grouped GEMM is supported in current implementation."); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING, + "Only NO_SCALING and MXFP8_1D_SCALING are supported in the V2 grouped GEMM."); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims, alpha); @@ -731,12 +830,25 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = - make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); - auto lhs_tensor = - make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); + + // For MXFP8: in JAX, rhs=cuBLAS_A, lhs=cuBLAS_B (swapped). + // Colwise is needed when the operand's contracting dim is NOT the last dim in its layout. + const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + + auto rhs_tensor = is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, + rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = is_mxfp8 + ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, + lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); + // Output stays NO_SCALING auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); From c97b0b70548378da17f9657293d1871ea0d293fe Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:21:23 -0700 Subject: [PATCH 19/60] te_permutation NaN issue fix Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/permutation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 6a0a3229d9..2732c4acc5 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -497,8 +497,9 @@ def _token_combine_bwd_rule( hidden_size, ) # The backward kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. + inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( output_grad, @@ -527,8 +528,9 @@ def _token_combine_bwd_rule( align_size=128, # Default, sizes already computed in forward ) # The permute kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. + inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, _ = permute_with_mask_map( output_grad, From 0b9a7637d7cf5f4bd73c78286ba330a1037c470b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:22:22 -0700 Subject: [PATCH 20/60] Support GroupedDense quantization checkpointing Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/flax/module.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 4f19f449bc..dbd6fb1fef 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1361,7 +1361,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None): +def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None, quantization_checkpoint_name: Optional[str] = None): """Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1389,6 +1389,7 @@ def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name=quantization_checkpoint_name, fp8_recipe=quantization_recipe, n_groups=n_groups, ) @@ -1446,7 +1447,7 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") -def make_grouped_dense_cls(quantization_recipe): +def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Optional[str] = None): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: allowed_grouped_gemm_recipes = [MXFP8BlockScaling] @@ -1470,5 +1471,6 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot" + te_grouped_dot_general, quantization_recipe, "ragged_dot", + quantization_checkpoint_name=quantization_checkpoint_name, )() From 6b64cea01c9a31993d83c9a7bf9a0cd97f0bb024 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:23:01 -0700 Subject: [PATCH 21/60] Temporary commit to assert if V1 grouped quantize is used Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/quantization.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 02639cfa2d..936b48cae0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1025,17 +1025,26 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + assert False, "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got scaling_mode {}".format( + scaling_mode + ) return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: + assert False, "V2 grouped quantize kernel requires total first logical dimension (product of x_shape up to flatten_axis) to be divisible by 128, but got shape {} and flatten_axis {} with total_first_dim {}".format( + x_shape, flatten_axis, total_first_dim + ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), # non_group_m = K must also be 128-aligned. if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: + assert False, "V2 grouped quantize kernel requires non-group dimension (product of x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors, but got shape {} and flatten_axis {} with non_group_m {}".format( + x_shape, flatten_axis, non_group_m + ) return False return True From 2dd69d4fa4a4ae5c88d9e79404951b0e528f52e4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 12:08:05 -0700 Subject: [PATCH 22/60] Fix scale shapes for MXFP8 Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 936b48cae0..054bf79431 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -51,6 +51,36 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _build_scale_spec(x_spec, scale_shape, mesh): + """Build a PartitionSpec for the MXFP8 scale tensor compatible with its shape. + + The scale tensor has smaller dimensions than the data tensor (each dimension + divided by the MXFP8 block size). This function ensures that we only shard a + scale dimension by a mesh axis if scale_shape[i] is divisible by that axis's + size. If not, a ValueError is raised with a helpful diagnostic message. + """ + result = [] + for axis, scale_dim in zip(x_spec, scale_shape): + if axis is None: + result.append(None) + elif isinstance(axis, str): + axis_size = mesh.shape.get(axis, 1) + if scale_dim % axis_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axis '{axis}' of size {axis_size}: " + f"scale dim {scale_dim} is not divisible by {axis_size}. " + f"The data tensor's sharding is incompatible with the MXFP8 block " + f"size along this axis. Try reducing expert parallelism (EP) so that " + f"EP divides the scale dimension, or increase the tensor size." + ) + else: + result.append(None) # tuple axes: conservatively leave unsharded + return tuple(result) + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -446,7 +476,13 @@ def infer_sharding_from_operands( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( @@ -528,7 +564,13 @@ def partition( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( From 204b3260da8fdd44304992d993a9726d1a8dece6 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 12:35:35 -0700 Subject: [PATCH 23/60] Fix MXFP8 scale sharding when FSDP+EP on same axis Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 054bf79431..29a0a047cf 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -56,8 +56,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): The scale tensor has smaller dimensions than the data tensor (each dimension divided by the MXFP8 block size). This function ensures that we only shard a - scale dimension by a mesh axis if scale_shape[i] is divisible by that axis's - size. If not, a ValueError is raised with a helpful diagnostic message. + scale dimension by a mesh axis (or tuple of axes) if scale_shape[i] is + divisible by the total axis size. If not, a ValueError is raised with a + helpful diagnostic message. """ result = [] for axis, scale_dim in zip(x_spec, scale_shape): @@ -76,8 +77,24 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"size along this axis. Try reducing expert parallelism (EP) so that " f"EP divides the scale dimension, or increase the tensor size." ) + elif isinstance(axis, (tuple, list)): + # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. + total_size = 1 + for a in axis: + total_size *= mesh.shape.get(a, 1) + if scale_dim % total_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axes {tuple(axis)} of combined size {total_size}: " + f"scale dim {scale_dim} is not divisible by {total_size}. " + f"The data tensor's sharding is incompatible with the MXFP8 block " + f"size along this axis. Try reducing parallelism or increasing the " + f"tensor size." + ) else: - result.append(None) # tuple axes: conservatively leave unsharded + result.append(None) return tuple(result) From 5fb585f907487034fae73629c9913c1b19bbd534 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Mar 2026 19:36:48 +0000 Subject: [PATCH 24/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 8 ++-- .../common/gemm/cublaslt_grouped_gemm.cu | 11 +++--- .../common/include/transformer_engine/gemm.h | 5 +-- transformer_engine/jax/cpp_extensions/gemm.py | 18 ++++++--- .../jax/cpp_extensions/quantization.py | 37 +++++++++++-------- .../jax/csrc/extensions/gemm.cpp | 29 ++++++++------- .../jax/csrc/extensions/quantization.cpp | 23 +++++------- transformer_engine/jax/flax/module.py | 16 ++++++-- transformer_engine/jax/quantize/tensor.py | 10 +++-- 9 files changed, 90 insertions(+), 67 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1f1286a399..0a76794003 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1964,7 +1964,7 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): # (n_groups, m, n, k) # lhs total_rows = m * 128; rhs total_rows = n_groups * k (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) - (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) + (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) ] @@ -2119,9 +2119,9 @@ def _prim_sum(x, kernel, group_sizes): return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) - prim_val, (prim_dx, prim_dk) = jit( - value_and_grad(_prim_sum, (0, 1)), static_argnums=() - )(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( + lhs, rhs, group_sizes + ) assert_allclose(prim_val, ref_val, dtype=fwd_dtype) assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index c57a662073..5dd1fe9c06 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -823,7 +823,7 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, // Like convert_int32_to_int64_kernel but scales each element by multiplier. // Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. __global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, - size_t n, int64_t multiplier) { + size_t n, int64_t multiplier) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; } @@ -850,22 +850,21 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud } void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, - int64_t multiplier, cudaStream_t stream) { + int64_t multiplier, cudaStream_t stream) { NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); if (n == 0) return; const int threads = 256; const int blocks = static_cast((n + threads - 1) / threads); convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, - multiplier); + multiplier); NVTE_CHECK_CUDA(cudaGetLastError()); } void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, - size_t n_groups, int64_t last_dim, - cudaStream_t stream) { + size_t n_groups, int64_t last_dim, cudaStream_t stream) { NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); // Always write at least offsets[0]=0 (needed even for n_groups==0). compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, - last_dim); + last_dim); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index dc1a104abb..5ee15613b0 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -368,7 +368,7 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud * \param[in] stream CUDA stream. */ void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, - int64_t multiplier, cudaStream_t stream); + int64_t multiplier, cudaStream_t stream); /*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. * @@ -383,8 +383,7 @@ void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *ds * \param[in] stream CUDA stream. */ void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, - size_t n_groups, int64_t last_dim, - cudaStream_t stream); + size_t n_groups, int64_t last_dim, cudaStream_t stream); void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e4fd9c99a6..1029600389 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2338,9 +2338,13 @@ def grouped_gemm( ) use_v2_ffi = _can_use_v2_grouped_gemm( - scaling_mode, lhs_data.dtype, has_bias, - lhs_shape=lhs_data.shape, rhs_shape=rhs_data.shape, - lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + scaling_mode, + lhs_data.dtype, + has_bias, + lhs_shape=lhs_data.shape, + rhs_shape=rhs_data.shape, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). @@ -2351,11 +2355,15 @@ def grouped_gemm( lhs_is_colwise = lhs_is_trans rhs_is_colwise = not rhs_is_trans lhs_scale_shape = scaling_mode.get_scale_shape( - lhs_data.shape, is_colwise=lhs_is_colwise, is_padded=True, + lhs_data.shape, + is_colwise=lhs_is_colwise, + is_padded=True, flatten_axis=lhs_axis_boundary, ) rhs_scale_shape = scaling_mode.get_scale_shape( - rhs_data.shape, is_colwise=rhs_is_colwise, is_padded=True, + rhs_data.shape, + is_colwise=rhs_is_colwise, + is_padded=True, flatten_axis=rhs_axis_boundary, ) # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 29a0a047cf..a2da6b8830 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -73,9 +73,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " f"by mesh axis '{axis}' of size {axis_size}: " f"scale dim {scale_dim} is not divisible by {axis_size}. " - f"The data tensor's sharding is incompatible with the MXFP8 block " - f"size along this axis. Try reducing expert parallelism (EP) so that " - f"EP divides the scale dimension, or increase the tensor size." + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing expert parallelism (EP) so that " + "EP divides the scale dimension, or increase the tensor size." ) elif isinstance(axis, (tuple, list)): # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. @@ -89,9 +89,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " f"by mesh axes {tuple(axis)} of combined size {total_size}: " f"scale dim {scale_dim} is not divisible by {total_size}. " - f"The data tensor's sharding is incompatible with the MXFP8 block " - f"size along this axis. Try reducing parallelism or increasing the " - f"tensor size." + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing parallelism or increasing the " + "tensor size." ) else: result.append(None) @@ -1084,16 +1084,21 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: - assert False, "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got scaling_mode {}".format( - scaling_mode + assert False, ( + "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" + " scaling_mode {}".format(scaling_mode) ) return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: - assert False, "V2 grouped quantize kernel requires total first logical dimension (product of x_shape up to flatten_axis) to be divisible by 128, but got shape {} and flatten_axis {} with total_first_dim {}".format( - x_shape, flatten_axis, total_first_dim + assert False, ( + "V2 grouped quantize kernel requires total first logical dimension (product of" + " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" + " flatten_axis {} with total_first_dim {}".format( + x_shape, flatten_axis, total_first_dim + ) ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), @@ -1101,8 +1106,12 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: - assert False, "V2 grouped quantize kernel requires non-group dimension (product of x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors, but got shape {} and flatten_axis {} with non_group_m {}".format( - x_shape, flatten_axis, non_group_m + assert False, ( + "V2 grouped quantize kernel requires non-group dimension (product of" + " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," + " but got shape {} and flatten_axis {} with non_group_m {}".format( + x_shape, flatten_axis, non_group_m + ) ) return False return True @@ -1157,9 +1166,7 @@ def abstract( # [n_groups int64 group_sizes | n_groups+1 int64 offsets] # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. n_groups = group_sizes_aval.size - fifth_out_aval = jax.core.ShapedArray( - shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 - ) + fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) else: # V1 path: 5th output is amax fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 786fac5dcd..45625120fd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -591,8 +591,7 @@ void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, logical_scale_shape}; nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, - &m_colwise_scale_inv_tensor, - sizeof(m_colwise_scale_inv_tensor)); + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); } } @@ -836,18 +835,20 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; - auto rhs_tensor = is_mxfp8 - ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, - rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary) - : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); - auto lhs_tensor = is_mxfp8 - ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, - lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary) - : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); + auto rhs_tensor = + is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, + rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, rhs_axis_boundary) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + is_mxfp8 + ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, + lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, lhs_axis_boundary) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); // Output stays NO_SCALING auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index fb083a958f..06f5906edf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -495,11 +495,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); -Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, - Buffer_Type scale_unused, Buffer_Type group_sizes, - Result_Type rowwise_out, Result_Type colwise_out, - Result_Type rowwise_sinv, Result_Type colwise_sinv, - Result_Type int64_workspace, +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, + Buffer_Type group_sizes, Result_Type rowwise_out, + Result_Type colwise_out, Result_Type rowwise_sinv, + Result_Type colwise_sinv, Result_Type int64_workspace, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); @@ -539,8 +538,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, non_group_m, stream); // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). - nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, - static_cast(n), stream); + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, static_cast(n), + stream); NVTEShape data_shape{}; data_shape.data[0] = m; @@ -557,9 +556,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, offsets_shape.data[0] = n_groups + 1; // Build input grouped tensor (plain float data, no quantization on the input side). - NVTEGroupedTensor in_grouped = - nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), - n_groups, data_shape); + NVTEGroupedTensor in_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), n_groups, data_shape); { NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), static_cast(in_dtype), data_shape}; @@ -574,9 +572,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, } // Build output grouped tensor. - NVTEGroupedTensor out_grouped = - nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), - n_groups, data_shape); + NVTEGroupedTensor out_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), n_groups, data_shape); // Set group sizes and offsets on output tensor (same device pointers). { diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dbd6fb1fef..17c9a242f0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1361,7 +1361,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None, quantization_checkpoint_name: Optional[str] = None): +def wrap_function_in_te_state_module( + f, + quantization_recipe, + name: Optional[str] = None, + quantization_checkpoint_name: Optional[str] = None, +): """Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1452,8 +1457,9 @@ def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Op if quantization_recipe is not None: allowed_grouped_gemm_recipes = [MXFP8BlockScaling] assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( - f"Only the following quantization recipes are supported for grouped GEMM or `None` for BF16 without quantization: {allowed_grouped_gemm_recipes}. " - f"Got {type(quantization_recipe)}." + "Only the following quantization recipes are supported for grouped GEMM or `None` for" + f" BF16 without quantization: {allowed_grouped_gemm_recipes}. Got" + f" {type(quantization_recipe)}." ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): @@ -1471,6 +1477,8 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot", + te_grouped_dot_general, + quantization_recipe, + "ragged_dot", quantization_checkpoint_name=quantization_checkpoint_name, )() diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4511b63a64..4b604502b0 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -704,8 +704,10 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if first_dims is not None or last_dims is not None or ( - original_shape is not None and group_axis is not None + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) ): assert ( original_shape is not None @@ -713,7 +715,9 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) if active_dims is not None: num_groups = active_dims.size else: From bee7f3b5d0f8f0675e1836d5adcdfcfaf137f622 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 23 Mar 2026 14:21:49 -0700 Subject: [PATCH 25/60] Address comments Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 8 +- transformer_engine/jax/cpp_extensions/gemm.py | 147 ++++++---- .../jax/cpp_extensions/quantization.py | 37 ++- transformer_engine/jax/csrc/extensions.h | 20 +- .../jax/csrc/extensions/gemm.cpp | 31 +-- transformer_engine/jax/dense.py | 258 ++++-------------- .../jax/quantize/dequantizer.py | 9 +- transformer_engine/jax/quantize/quantizer.py | 16 +- .../jax/quantize/scaling_modes.py | 21 +- transformer_engine/jax/quantize/tensor.py | 131 ++++----- 10 files changed, 270 insertions(+), 408 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9fddbc435c..e429a303fe 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1789,10 +1789,10 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): # jitting grouped_gemm lhs_tensor = GroupedNoScaleTensor( - data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape ) rhs_tensor = GroupedNoScaleTensor( - data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape ) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") @@ -1832,10 +1832,10 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) lhs_tensor = GroupedNoScaleTensor( - data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape ) rhs_tensor = GroupedNoScaleTensor( - data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape ) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs_tensor, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1f69535b9b..a5e275d663 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -80,6 +80,7 @@ except RuntimeError as e: if "cublas" in str(e).lower(): _v2_grouped_gemm_available = False + _v2_grouped_gemm_available_reason = str(e) else: raise @@ -1433,7 +1434,7 @@ class GroupedGemmPrimitive(BasePrimitive): # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22) + impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26) inner_primitive = None outer_primitive = None @@ -1460,15 +1461,19 @@ def abstract( use_v2_ffi, lhs_axis_boundary, rhs_axis_boundary, - rhs_group_axis, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): """ Grouped GEMM operation. Args: - lhs_data: Left-hand side input matrix data, N-D array + lhs_data: Left-hand side input matrix data (may be 1D for quantized) lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, N-D array + rhs_data: Right-hand side input matrix data (may be 1D for quantized) rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) lhs_first_dims: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel @@ -1484,9 +1489,11 @@ def abstract( scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - lhs_axis_boundary: Axis split point for lhs N-D → 2D flattening - rhs_axis_boundary: Axis split point for rhs N-D → 2D flattening - rhs_group_axis: Batch-group axis of rhs to exclude from output non-contracting dims + out_shape: Pre-computed output shape tuple + lhs_left_size: Product of lhs dims before axis_boundary + lhs_right_size: Product of lhs dims after axis_boundary + rhs_left_size: Product of rhs dims before axis_boundary + rhs_right_size: Product of rhs dims after axis_boundary Returns: A jnp.ndarray containing the result of the grouped GEMM operation @@ -1514,34 +1521,6 @@ def abstract( num_groups, ) - # Derive output shape from N-D buffer shapes using axis_boundary. - lhs_shape = lhs_data_aval.shape - rhs_shape = rhs_data_aval.shape - - # Non-contracting dims for lhs - if lhs_is_trans: - lhs_non_contracting = lhs_shape[lhs_axis_boundary:] - else: - lhs_non_contracting = lhs_shape[:lhs_axis_boundary] - - # Non-contracting dims for rhs (excluding batch-group axis where applicable) - if rhs_is_trans: - rhs_non_contracting = tuple( - rhs_shape[d] - for d in range(rhs_axis_boundary) - if rhs_group_axis is None or d != rhs_group_axis - ) - else: - rhs_non_contracting = rhs_shape[rhs_axis_boundary:] - - # K validation is intentionally skipped: per-group K values may not fill the - # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. - if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: - # Wgrad case: rhs has ragged contracting K dimension → output gets G prefix. - out_shape = (num_groups, *lhs_non_contracting, *rhs_non_contracting) - else: - out_shape = (*lhs_non_contracting, *rhs_non_contracting) - cublas_workspace_aval = jax.core.ShapedArray( shape=( GroupedGemmPrimitive._compute_cublas_workspace_size( @@ -1649,9 +1628,13 @@ def lowering( use_v2_ffi, lhs_axis_boundary, rhs_axis_boundary, - rhs_group_axis, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): - del out_dtype, rhs_group_axis # Python-only; not forwarded to C++ + del out_dtype, out_shape # Python-only; not forwarded to C++ if use_v2_ffi: ffi_name = GroupedGemmPrimitive.name_graph_safe return jax.ffi.ffi_lowering(ffi_name)( @@ -1662,6 +1645,10 @@ def lowering( scaling_mode=scaling_mode.value, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( @@ -1674,6 +1661,10 @@ def lowering( use_async_d2h_group_sizes=use_async_d2h_group_sizes, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) @staticmethod @@ -1700,7 +1691,11 @@ def impl( use_v2_ffi, lhs_axis_boundary, rhs_axis_boundary, - rhs_group_axis, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): if GroupedGemmPrimitive.inner_primitive is None: raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") @@ -1730,7 +1725,11 @@ def impl( use_v2_ffi=use_v2_ffi, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, - rhs_group_axis=rhs_group_axis, + out_shape=out_shape, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) return (out,) @@ -2019,6 +2018,10 @@ def grouped_gemm_copy_group_sizes( ) return out +@cache +def _should_enforce_v2_grouped_gemm() -> bool: + """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" + return os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") == "1" def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, @@ -2031,15 +2034,26 @@ def _can_use_v2_grouped_gemm( # feature-compatible with the main branch. # Bias can be supported in a kernel or in pure-JAX in the future. + enforce_v2_gmm = _should_enforce_v2_grouped_gemm() + if not _v2_grouped_gemm_available: + if enforce_v2_gmm: + raise RuntimeError(f"The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled. The reason for V2 grouped GEMM not being available: {_v2_grouped_gemm_available_reason}") return False # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. if get_device_compute_capability(0) < 100: + if enforce_v2_gmm: + raise RuntimeError(f"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device compute capability of GPU 0 is {get_device_compute_capability(0)} and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled.") return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias: + return True + + if enforce_v2_gmm: + raise RuntimeError(f"The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}") + return False def grouped_gemm( @@ -2076,6 +2090,8 @@ def grouped_gemm( empty_gs = jnp.empty((0,), jnp.int32) # Extract data, dims, and metadata from tensor objects. + # Keep data in its original layout (may be 1D for quantized tensors) to preserve + # JAX sharding; the C++ side uses original_shape to derive m/n/k. if isinstance(lhs, GroupedNoScaleTensor): lhs_data = lhs.data lhs_shape = lhs.original_shape @@ -2084,16 +2100,14 @@ def grouped_gemm( out_dtype = lhs.data.dtype lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - rhs_group_axis = getattr(rhs, "group_axis", 0) elif isinstance(lhs, GroupedScaledTensor1x): lhs_shape = lhs.original_shape - lhs_data = lhs.data.reshape(lhs_shape) + lhs_data = lhs.data lhs_scale_inv = lhs.scale_inv scaling_mode = lhs.scaling_mode out_dtype = lhs.dq_dtype lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - rhs_group_axis = getattr(rhs, "group_axis", 0) else: raise TypeError( f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" @@ -2107,7 +2121,7 @@ def grouped_gemm( rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs elif isinstance(rhs, GroupedScaledTensor1x): rhs_shape = rhs.original_shape - rhs_data = rhs.data.reshape(rhs_shape) + rhs_data = rhs.data rhs_scale_inv = rhs.scale_inv rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs @@ -2194,8 +2208,8 @@ def grouped_gemm( rhs_q = grouped_quantize( rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) - lhs_data = lhs_q.data.reshape(lhs_q.original_shape) - rhs_data = rhs_q.data.reshape(rhs_q.original_shape) + lhs_data = lhs_q.data + rhs_data = rhs_q.data lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape @@ -2254,17 +2268,34 @@ def grouped_gemm( "Ensure lhs or rhs tensor objects carry first_dims or last_dims." ) + # Pre-compute collapsed 2D sizes from original N-D shapes. + # These are static Python ints passed as primitive parameters (must be hashable). + lhs_left_size = math.prod(lhs_shape[:lhs_axis_boundary]) + lhs_right_size = math.prod(lhs_shape[lhs_axis_boundary:]) + rhs_left_size = math.prod(rhs_shape[:rhs_axis_boundary]) + rhs_right_size = math.prod(rhs_shape[rhs_axis_boundary:]) + + # Pre-compute output shape from N-D input shapes (static Python ints). + if lhs_is_trans: + lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + else: + lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + if rhs_is_trans: + rhs_non_contracting = tuple( + rhs_shape[d] + for d in range(rhs_axis_boundary) + if d != 0 + ) + else: + rhs_non_contracting = rhs_shape[rhs_axis_boundary:] + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + out_shape = (num_gemms, *lhs_non_contracting, *rhs_non_contracting) + else: + out_shape = (*lhs_non_contracting, *rhs_non_contracting) + has_bias = bias is not None if has_bias: - # Compute N from rhs non-contracting dims. - if rhs_is_trans: - N_dim = math.prod( - rhs_data.shape[d] - for d in range(rhs_axis_boundary) - if rhs_group_axis is None or d != rhs_group_axis - ) - else: - N_dim = math.prod(rhs_data.shape[rhs_axis_boundary:]) + N_dim = math.prod(rhs_non_contracting) assert bias.shape == ( num_gemms, N_dim, @@ -2309,6 +2340,10 @@ def grouped_gemm( use_v2_ffi=use_v2_ffi, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, - rhs_group_axis=rhs_group_axis, + out_shape=tuple(int(d) for d in out_shape), + lhs_left_size=int(lhs_left_size), + lhs_right_size=int(lhs_right_size), + rhs_left_size=int(rhs_left_size), + rhs_right_size=int(rhs_right_size), ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c8578d48b8..ce0f186aeb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -43,6 +43,7 @@ ScalingMode, compute_scale_from_amax, NoScaleTensor, + GroupedNoScaleTensor, get_rht_matrix, QuantizeLayout, ) @@ -1001,7 +1002,6 @@ class GroupedQuantizePrimitive(BasePrimitive): 5, 6, 7, - 8, ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype inner_primitive = None outer_primitive = None @@ -1016,7 +1016,6 @@ def abstract( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1038,7 +1037,6 @@ def abstract( ).get_grouped_scale_shape_2x( x_aval.shape, group_sizes_aval.size, - group_axis, is_padded=True, flatten_axis=flatten_axis, ) @@ -1099,7 +1097,6 @@ def lowering( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1110,7 +1107,6 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 - assert group_axis == 0 return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1130,7 +1126,6 @@ def impl( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1151,7 +1146,6 @@ def impl( scaling_mode=scaling_mode, q_layout=q_layout, flatten_axis=flatten_axis, - group_axis=group_axis, scale_dtype=scale_dtype, ) return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) @@ -1166,12 +1160,12 @@ def grouped_quantize( group_sizes: jnp.ndarray = None, amax: jnp.ndarray = None, flatten_axis: int = -1, -) -> GroupedScaledTensor1x: +) -> Union[GroupedScaledTensor1x, GroupedNoScaleTensor]: """Quantize a tensor in grouped manner. This function quantizes a tensor by splitting it into groups along a specified axis and applying quantization to each group separately. The groups can be either specified - explicitly through group_sizes or automatically split along the group_axis. + explicitly through group_sizes or automatically split along axis 0. Args: x: Input tensor to quantize @@ -1185,31 +1179,36 @@ def grouped_quantize( Note: - If group_sizes is not provided, the tensor will be split into equal-sized groups - along the group_axis - - The group_axis is currently fixed to 0 + along axis 0 - The quantizer's q_layout determines whether row-wise, column-wise, or both quantization is applied """ if quantizer is None: - if isinstance(x, NoScaleTensor): + if isinstance(x, GroupedNoScaleTensor): + assert amax is None, "If the input to grouped_quantize is already a GroupedNoScaleTensor, providing an amax could be ambiguous. Please set amax to None and set the amax on your GroupedNoScaleTensor directly, if needed. Alternatively, please call grouped_quantize with a raw jnp.ndarray along with an amax value if you'd like this function to handle amax for you." return x - return NoScaleTensor(data=x, amax=None) + return GroupedNoScaleTensor( + data=x, + amax=amax, + first_dims=group_sizes, + last_dims=None, + original_shape=x.shape, + ) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( -1, x.ndim - 1, ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" - group_axis = 0 ragged_first_dims = group_sizes # None if no explicit group_sizes (kernel case) if group_sizes is None: - group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) + group_sizes = jnp.ones(x.shape[0], dtype=jnp.int32) if not GroupedQuantizePrimitive.enabled(): return quantizer.quantize( - x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis + x, flatten_axis=flatten_axis, group_sizes=group_sizes ) n_groups = group_sizes.size original_shape = x.shape @@ -1226,9 +1225,9 @@ def grouped_quantize( if amax is not None: row_amax = amax else: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + row_amax = jnp.max(jnp.abs(x), axis=range(1, x.ndim)) segment_ids = jnp.repeat( - jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[0] ) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): @@ -1257,7 +1256,6 @@ def grouped_quantize( scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, - group_axis=group_axis, scale_dtype=quantizer.get_scale_dtype(), ) @@ -1283,7 +1281,6 @@ def grouped_quantize( flatten_axis=flatten_axis, first_dims=ragged_first_dims, original_shape=original_shape, - group_axis=group_axis, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 616209709b..a74b209e4f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -61,6 +61,10 @@ struct GroupedGemmV2Config { JAXX_Scaling_Mode scaling_mode; int64_t lhs_axis_boundary; int64_t rhs_axis_boundary; + int64_t lhs_left_size; + int64_t lhs_right_size; + int64_t rhs_left_size; + int64_t rhs_right_size; }; struct GroupedGemmConfig { @@ -71,6 +75,10 @@ struct GroupedGemmConfig { bool use_async_d2h_group_sizes; int64_t lhs_axis_boundary; int64_t rhs_axis_boundary; + int64_t lhs_left_size; + int64_t lhs_right_size; + int64_t rhs_left_size; + int64_t rhs_right_size; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -215,7 +223,11 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode"), ::xla::ffi::StructMember("lhs_axis_boundary"), - ::xla::ffi::StructMember("rhs_axis_boundary")); + ::xla::ffi::StructMember("rhs_axis_boundary"), + ::xla::ffi::StructMember("lhs_left_size"), + ::xla::ffi::StructMember("lhs_right_size"), + ::xla::ffi::StructMember("rhs_left_size"), + ::xla::ffi::StructMember("rhs_right_size")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), @@ -224,7 +236,11 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("has_bias"), ::xla::ffi::StructMember("use_async_d2h_group_sizes"), ::xla::ffi::StructMember("lhs_axis_boundary"), - ::xla::ffi::StructMember("rhs_axis_boundary")); + ::xla::ffi::StructMember("rhs_axis_boundary"), + ::xla::ffi::StructMember("lhs_left_size"), + ::xla::ffi::StructMember("lhs_right_size"), + ::xla::ffi::StructMember("rhs_left_size"), + ::xla::ffi::StructMember("rhs_right_size")); // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 433fd38197..fb76b7052e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -698,7 +698,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, GroupedGemmV2Config config) { - auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; + auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary, + lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -780,7 +781,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, - lhs_axis_boundary, rhs_axis_boundary] = config; + lhs_axis_boundary, rhs_axis_boundary, + lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -830,26 +832,19 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; - // Derive m, n, k from N-D buffer dimensions using axis_boundary. - // axis_boundary splits contracting dims from non-contracting dims. - auto lhs_dims = lhs_data.dimensions(); - auto rhs_dims = rhs_data.dimensions(); - NVTE_CHECK(lhs_dims.size() >= 2, "lhs_data must be at least 2D."); - NVTE_CHECK(rhs_dims.size() >= 2, "rhs_data must be at least 2D."); - size_t lab = static_cast(lhs_axis_boundary); - size_t rab = static_cast(rhs_axis_boundary); - // k = product of contracting dims of lhs - size_t k = lhs_is_trans ? product(lhs_dims, 0, lab) : product(lhs_dims, lab, lhs_dims.size()); + // Derive m, n, k from pre-computed original shape sizes (passed from Python). + // lhs_left_size = product of original lhs dims before axis_boundary + // lhs_right_size = product of original lhs dims after axis_boundary + // Same pattern for rhs. + size_t k = lhs_is_trans ? lhs_left_size : lhs_right_size; size_t m, n; if (is_rhs_ragged) { // wgrad: non-contracting lhs dims form M; non-contracting rhs dims form N - m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) : product(lhs_dims, 0, lab); - n = rhs_is_trans ? product(rhs_dims, 0, rab) : product(rhs_dims, rab, rhs_dims.size()); + m = lhs_is_trans ? lhs_right_size : lhs_left_size; + n = rhs_is_trans ? rhs_left_size : rhs_right_size; } else { - m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) - : product(lhs_dims, 0, lab); // total M (sum of group sizes) - n = rhs_is_trans ? product(rhs_dims, 0, rab) / num_gemms - : product(rhs_dims, rab, rhs_dims.size()); + m = lhs_is_trans ? lhs_right_size : lhs_left_size; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_left_size / num_gemms : rhs_right_size; } // Inputs diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 76c984486f..aea957145f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -18,16 +18,11 @@ from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope from .quantize import ( - ScaledTensorFactory, ScaledTensor, - ScalingMode, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, - is_fp8_gemm_with_all_layouts_supported, TensorUsage, - QuantizeLayout, - GroupedNoScaleTensor, ) @@ -416,126 +411,40 @@ def _grouped_dense_fwd_rule( kernel_fsdp_info, ): use_bias = bias is not None - is_noop_quantizer_set = quantizer_set == noop_quantizer_set kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." + del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled - if is_noop_quantizer_set: - grouped_gemm_x = x - grouped_gemm_kernel = kernel - ctx_x = x - ctx_kernel = kernel - flatten_axis_k = None - - if kernel_fsdp_enabled: - kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) - else: - original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout - - x_contracting_dims, k_contracting_dims = contracting_dims - flatten_axis_x = -len(x_contracting_dims) - flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis - - assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" - assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" - # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( - "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " - "and k_contracting_dims=(1,) for now, " - f"got {x_contracting_dims=} and {k_contracting_dims=}" - ) - casted_x = tex.grouped_quantize( - x, - quantizer_set.x, - group_sizes, - flatten_axis=flatten_axis_x, - ) + - ctx_kernel_usage = TensorUsage.RHS_TRANS - if kernel_fsdp_enabled: - assert quantizer_set.kernel.scaling_mode in [ - ScalingMode.CURRENT_TENSOR_SCALING, - ScalingMode.DELAYED_TENSOR_SCALING, - ] - # Perform `cast` only - ctx_kernel_usage = TensorUsage.LHS - quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE - - casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k - ) - contracting_dims = (x_contracting_dims, k_contracting_dims) - - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have - # rowwise_casted_x.original_shape == (M, K) - # colwise_casted_kernel.original_shape == (G, N, K) - grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) - ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) - ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage) - - if kernel_fsdp_enabled: - ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape) - global_ctx_kernel_data = _all_gather_kernel( - ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) - kernel_shape = global_ctx_kernel_data.shape - - ctx_kernel = ScaledTensorFactory.create_1x( - global_ctx_kernel_data.reshape(-1), - ctx_kernel.scale_inv, - scaling_mode=ctx_kernel.scaling_mode, - dq_dtype=ctx_kernel.dq_dtype, - is_colwise=False, - data_layout="N", - flatten_axis=ctx_kernel.flatten_axis, - first_dims=ctx_kernel.first_dims, - last_dims=ctx_kernel.last_dims, - original_shape=kernel_shape, - group_axis=ctx_kernel.group_axis, - ) - - if is_fp8_gemm_with_all_layouts_supported(): - grouped_gemm_kernel = ctx_kernel - else: - grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1) - grouped_gemm_kernel = ScaledTensorFactory.create_1x( - grouped_gemm_kernel_data.reshape(-1), - ctx_kernel.scale_inv, - scaling_mode=ctx_kernel.scaling_mode, - dq_dtype=ctx_kernel.dq_dtype, - is_colwise=True, - data_layout="T", - flatten_axis=ctx_kernel.flatten_axis, - first_dims=ctx_kernel.first_dims, - last_dims=ctx_kernel.last_dims, - original_shape=kernel_shape, - group_axis=ctx_kernel.group_axis, - ) - else: - grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) - - # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one. - # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. - quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout - - if is_noop_quantizer_set: - grouped_gemm_x = GroupedNoScaleTensor( - data=grouped_gemm_x, - first_dims=group_sizes, - last_dims=None, - group_axis=0, - original_shape=grouped_gemm_x.shape, - ) - grouped_gemm_kernel = GroupedNoScaleTensor( - data=grouped_gemm_kernel, - first_dims=None, - last_dims=None, - group_axis=0, - original_shape=grouped_gemm_kernel.shape, - ) + + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis + + casted_x = tex.grouped_quantize( + x, + quantizer_set.x, + group_sizes, + flatten_axis=flatten_axis_x, + ) + + casted_kernel = tex.grouped_quantize( + kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k + ) + contracting_dims = (x_contracting_dims, k_contracting_dims) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) + ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, @@ -557,7 +466,6 @@ def _grouped_dense_fwd_rule( x.shape, kernel.shape, use_bias, - is_noop_quantizer_set, quantizer_set, flatten_axis_k, ) @@ -567,6 +475,10 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims ( @@ -576,86 +488,37 @@ def _grouped_dense_bwd_rule( x_shape, kernel_shape, use_bias, - is_noop_quantizer_set, quantizer_set, flatten_axis_k, ) = ctx - if is_noop_quantizer_set: - # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) - # g_contracting_dim = (1, ) - # k_contracting_dim = (2, ) - g_contracting_dim = tuple( - range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - k_contracting_dim = tuple( - dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) - dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = grad - dgrad_kernel_T = ctx_kernel - - # g_contracting_dim = (0, ) - # x_contracting_dim = (0, ) - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = grad - else: - casted_grad = tex.grouped_quantize( - grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k - ) + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) + # g_contracting_dim = (1, ) + # k_contracting_dim = (2, ) + g_contracting_dim = tuple( + range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + k_contracting_dim = tuple( + dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use - # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the - # extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - g_contracting_dim = (1,) - k_contracting_dim = (2,) - dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) - dgrad_kernel_T = ctx_kernel - - # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work - # after the extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - g_contracting_dim = (0,) - x_contracting_dim = (0,) - wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) - - if is_noop_quantizer_set: - dgrad_grad = GroupedNoScaleTensor( - data=dgrad_grad, - first_dims=group_sizes, - last_dims=None, - group_axis=0, - original_shape=dgrad_grad.shape, - ) - dgrad_kernel_T = GroupedNoScaleTensor( - data=dgrad_kernel_T, - first_dims=None, - last_dims=None, - group_axis=0, - original_shape=dgrad_kernel_T.shape, - ) - wgrad_x_T = GroupedNoScaleTensor( - data=wgrad_x_T, - first_dims=group_sizes, - last_dims=None, - group_axis=0, - original_shape=wgrad_x_T.shape, - ) - wgrad_grad = GroupedNoScaleTensor( - data=wgrad_grad, - first_dims=group_sizes, - last_dims=None, - group_axis=0, - original_shape=wgrad_grad.shape, - ) + casted_grad = tex.grouped_quantize( + grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k + ) + + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) + dgrad_kernel_T = ctx_kernel + + # g_contracting_dim = (0, ) + # x_contracting_dim = (0, ) + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + + wgrad_x_T = ctx_x + wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, @@ -673,11 +536,6 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - if kernel_fsdp_mesh_axis is not None: - wgrad = _psum_scatter_kernel( - wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 5075f1a664..de97fb7318 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -284,18 +284,16 @@ def _grouped_dequantize(grouped_scaled_tensor): # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape if group_sizes is None: group_sizes = jnp.ones( - grouped_scaled_tensor.original_shape[grouped_scaled_tensor.group_axis], dtype=jnp.int32 + grouped_scaled_tensor.original_shape[0], dtype=jnp.int32 ) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape - group_axis = grouped_scaled_tensor.group_axis - flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] non_group_shape = tuple( - original_shape[i] for i in range(len(original_shape)) if i != group_axis + original_shape[i] for i in range(len(original_shape)) if i != 0 ) matrix_sizes = group_sizes * math.prod(non_group_shape) @@ -304,9 +302,8 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv_ptr = 0 for i, data_i in enumerate(data): data_shape_i = ( - *original_shape[:group_axis], group_sizes[i], - *original_shape[group_axis + 1 :], + *original_shape[1:], ) assert math.prod(data_shape_i) == data_i.size, ( f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 55dd7f5618..db56db935d 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -920,7 +920,7 @@ def __post_init__(self): self.data_layout = self.quantizers[0].data_layout def _create_grouped_tensor_from_tensor_list( - self, tensor_list, group_sizes, original_shape, group_axis, mode + self, tensor_list, group_sizes, original_shape, mode ): # mode 0 = concate, mode 1 = add # TODO(Ming Huang): Consider to apply Enum for mode. @@ -950,7 +950,6 @@ def _create_grouped_tensor_from_tensor_list( flatten_axis=tensor_list[0].flatten_axis, first_dims=group_sizes, original_shape=original_shape, - group_axis=group_axis, ) def _quantize_func(self, *args, **kwargs): @@ -964,12 +963,11 @@ def quantize( dq_dtype=None, flatten_axis=-1, group_sizes=None, - group_axis=0, ): """Quantize a tensor in grouped manner. Expected input shape: [M, K] or [G, K, N] - Split to x.shape[group_axis] number of groups if group_sizes is not given + Split to x.shape[0] number of groups if group_sizes is not given Args: x: Input tensor to quantize @@ -978,12 +976,10 @@ def quantize( dq_dtype: Data type for dequantized values flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) group_sizes: Array of ints containing the size of each group (default: None) - group_axis: The axis along which grouping is performed (default: 0) Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - assert group_axis == 0, "Only group_axis == 0 is supported now!" dq_dtype = dq_dtype if dq_dtype is not None else x.dtype if flatten_axis < 0: @@ -1023,8 +1019,8 @@ def quantize( tensor_list.append(tensor) combine_mode = 1 # Add else: - group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) - x = jnp.split(x, x.shape[group_axis], axis=group_axis) + group_sizes = jnp.ones(x.shape[0], dtype=jnp.int32) + x = jnp.split(x, x.shape[0], axis=0) tensor_list = [] for i in range(len(group_sizes)): @@ -1038,12 +1034,12 @@ def quantize( if is_rowwise: rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list] grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list( - rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + rowwise_tensor_list, group_sizes, original_shape, combine_mode ) if is_colwise: colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list] grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list( - colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + colwise_tensor_list, group_sizes, original_shape, combine_mode ) if is_colwise and is_rowwise: diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 61c3af178c..26b998ba90 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -135,14 +135,13 @@ def get_scale_shape( @abstractmethod def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Original shape of the data tensor n_groups: Number of groups in grouped quantization - group_axis: The axis along which grouping is performed is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -253,7 +252,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.ROWWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. @@ -266,7 +265,7 @@ def get_grouped_scale_shape( Returns: The shape for scale tensors """ - del data_shape, group_axis, is_colwise + del data_shape, is_colwise assert isinstance(n_groups, int) return (n_groups,) @@ -370,7 +369,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.COLWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. @@ -383,7 +382,7 @@ def get_grouped_scale_shape( Returns: The shape for scale tensors """ - del data_shape, group_axis, is_colwise + del data_shape, is_colwise assert isinstance(n_groups, int) return (n_groups,) @@ -613,7 +612,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.COLWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for grouped scale tensors in this mode. If padded: The estimiated maximal possible shape for grouped scale tensor is return instead. @@ -937,14 +936,13 @@ def get_shardy_sharding_rules( ) def get_grouped_scale_shape_2x( - self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_padded=True, flatten_axis=-1 ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor n_groups: Number of groups for grouped quantization - group_axis: The axis along which grouping is performed is_padded: Whether to use padded shapes flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -954,7 +952,6 @@ def get_grouped_scale_shape_2x( rowwise_scale_shape = self.get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis, @@ -962,7 +959,6 @@ def get_grouped_scale_shape_2x( colwise_scale_shape = self.get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis, @@ -970,7 +966,7 @@ def get_grouped_scale_shape_2x( return (rowwise_scale_shape, colwise_scale_shape) def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. @@ -985,7 +981,6 @@ def get_grouped_scale_shape( return self._get_impl().get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=is_colwise, is_padded=is_padded, flatten_axis=flatten_axis, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 316e4f3139..6c9062b4f1 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -369,13 +369,11 @@ class GroupedScaledTensor1x(ScaledTensor1x): first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping - group_axis: The axis along which grouping is performed (default: 0) """ first_dims: Optional[jnp.ndarray] last_dims: Optional[jnp.ndarray] original_shape: Tuple - group_axis: int def __init__( self, @@ -383,6 +381,7 @@ def __init__( scale_inv, amax, first_dims, + last_dims, scaling_mode, dq_dtype, _dq_func, @@ -390,14 +389,11 @@ def __init__( data_layout, flatten_axis, original_shape, - group_axis=0, - last_dims=None, ): self.flatten_axis = flatten_axis self.first_dims = first_dims self.last_dims = last_dims self.original_shape = original_shape - self.group_axis = group_axis # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, @@ -415,7 +411,6 @@ def __init__( def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" - assert self.group_axis >= 0 assert self.flatten_axis > 0 data_ndim = len(self.original_shape) @@ -423,10 +418,6 @@ def __post_init__(self): 0 < self.flatten_axis < data_ndim ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}" - assert ( - 0 <= self.group_axis < data_ndim - ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" - active_dims = ( self.first_dims if self.first_dims is not None and self.first_dims.size > 0 @@ -435,12 +426,11 @@ def __post_init__(self): if active_dims is not None: num_groups = active_dims.size else: - num_groups = self.original_shape[self.group_axis] + num_groups = self.original_shape[0] expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, num_groups, - self.group_axis, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis, @@ -466,40 +456,9 @@ def tree_flatten(self): self.data_layout, self.flatten_axis, self.original_shape, - self.group_axis, ) return (children, aux_data) - @classmethod - def tree_unflatten(cls, aux_data, children): - """Reconstructs the tensor from its flattened representation.""" - data, scale_inv, amax, first_dims, last_dims = children - ( - scaling_mode, - dq_dtype, - _dq_func, - is_colwise, - data_layout, - flatten_axis, - original_shape, - group_axis, - ) = aux_data - return cls( - data=data, - scale_inv=scale_inv, - amax=amax, - first_dims=first_dims, - last_dims=last_dims, - scaling_mode=scaling_mode, - dq_dtype=dq_dtype, - _dq_func=_dq_func, - is_colwise=is_colwise, - data_layout=data_layout, - flatten_axis=flatten_axis, - original_shape=original_shape, - group_axis=group_axis, - ) - def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError @@ -520,7 +479,7 @@ def checkpoint(self, quantizer): @register_pytree_node_class @dataclass -class GroupedNoScaleTensor: +class GroupedNoScaleTensor(AbstractBaseTensor1x): """Unquantized grouped tensor. Stores N-D data with per-group dimension sizes so that grouped_gemm() @@ -530,39 +489,68 @@ class GroupedNoScaleTensor: data: The raw (unquantized) tensor data in N-D layout first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged - group_axis: Which axis of original_shape is the group batch prefix original_shape: Shape of data (same as data.shape for N-D unquantized) """ - data: jnp.ndarray first_dims: Optional[jnp.ndarray] last_dims: Optional[jnp.ndarray] - group_axis: int original_shape: Tuple def tree_flatten(self): """Flattens the tensor for JAX tree operations.""" - children = (self.data, self.first_dims, self.last_dims) - aux_data = (self.group_axis, self.original_shape) + children = (self.data, self.amax, self.first_dims, self.last_dims) + aux_data = (self.original_shape,) return (children, aux_data) - @classmethod - def tree_unflatten(cls, aux_data, children): - """Reconstructs the tensor from its flattened representation.""" - group_axis, original_shape = aux_data - data, first_dims, last_dims = children - return cls( - data=data, - first_dims=first_dims, - last_dims=last_dims, - group_axis=group_axis, - original_shape=original_shape, - ) + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim def dequantize(self): - """No-op dequantization — returns the raw data.""" + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" return self.data + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return GroupedNoScaleTensor( + data=data, + amax=self.amax, + first_dims=self.first_dims, + last_dims=self.last_dims, + original_shape=self.original_shape, + ) + + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + assert quantizer is None, "NoScaleTensor does not support quantization." + return self + @register_pytree_node_class @dataclass @@ -664,7 +652,6 @@ def create_1x( first_dims=None, last_dims=None, original_shape=None, - group_axis=0, has_rht_applied=False, ): """Creates a single-scale quantized tensor. @@ -681,7 +668,6 @@ def create_1x( first_dims: Per-group sizes of the first (row) 2D dim (default: None) last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: @@ -695,7 +681,7 @@ def create_1x( if ( first_dims is not None or last_dims is not None - or (original_shape is not None and group_axis is not None) + or original_shape is not None ): assert ( original_shape is not None @@ -709,11 +695,9 @@ def create_1x( if active_dims is not None: num_groups = active_dims.size else: - norm_group_axis = (len(original_shape) + group_axis) % len(original_shape) - num_groups = original_shape[norm_group_axis] + num_groups = original_shape[0] # Handling attrs of transposed tensors - group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": if original_shape[0] == num_groups: original_shape = ( @@ -727,7 +711,6 @@ def create_1x( *original_shape[flatten_axis:], *original_shape[:flatten_axis], ) - group_axis = flatten_axis flatten_axis = len(original_shape) - flatten_axis return GroupedScaledTensor1x( @@ -743,7 +726,6 @@ def create_1x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, ) # Handling attrs of transposed tensors @@ -779,7 +761,6 @@ def create_2x( first_dims=None, last_dims=None, original_shape=None, - group_axis=0, rowwise_has_rht_applied=False, colwise_has_rht_applied=False, ): @@ -798,7 +779,6 @@ def create_2x( first_dims: Per-group sizes of the first (row) 2D dim (default: None) last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -823,7 +803,6 @@ def create_2x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, ) colwise_tensor = ScaledTensorFactory.create_1x( @@ -838,7 +817,6 @@ def create_2x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -859,7 +837,6 @@ def create( first_dims: jnp.ndarray = None, last_dims: jnp.ndarray = None, original_shape: Tuple[int] = None, - group_axis: int = 0, rowwise_has_rht_applied: bool = False, colwise_has_rht_applied: bool = False, ): @@ -878,7 +855,6 @@ def create( first_dims: Per-group sizes of the first (row) 2D dim (default: None) last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -902,7 +878,6 @@ def create( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, rowwise_has_rht_applied=rowwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied, ) @@ -920,7 +895,6 @@ def create( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, ) @@ -936,7 +910,6 @@ def create( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, ) From d9b9c446a97e2a92dbca652bb9652379caa2fa36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:23:01 +0000 Subject: [PATCH 26/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 25 +++++++++++++------ .../jax/cpp_extensions/quantization.py | 16 ++++++------ .../jax/csrc/extensions/gemm.cpp | 4 +-- transformer_engine/jax/dense.py | 6 +---- .../jax/quantize/dequantizer.py | 8 ++---- transformer_engine/jax/quantize/tensor.py | 6 +---- 6 files changed, 32 insertions(+), 33 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a5e275d663..04b615269a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2018,11 +2018,13 @@ def grouped_gemm_copy_group_sizes( ) return out + @cache def _should_enforce_v2_grouped_gemm() -> bool: """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" return os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") == "1" + def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, @@ -2038,21 +2040,32 @@ def _can_use_v2_grouped_gemm( if not _v2_grouped_gemm_available: if enforce_v2_gmm: - raise RuntimeError(f"The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled. The reason for V2 grouped GEMM not being available: {_v2_grouped_gemm_available_reason}") + raise RuntimeError( + "The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is" + " enabled. The reason for V2 grouped GEMM not being available:" + f" {_v2_grouped_gemm_available_reason}" + ) return False # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. if get_device_compute_capability(0) < 100: if enforce_v2_gmm: - raise RuntimeError(f"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device compute capability of GPU 0 is {get_device_compute_capability(0)} and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled.") + raise RuntimeError( + "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device" + f" compute capability of GPU 0 is {get_device_compute_capability(0)} and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." + ) return False if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias: return True if enforce_v2_gmm: - raise RuntimeError(f"The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}") + raise RuntimeError( + "The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and" + f" without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}" + ) return False @@ -2281,11 +2294,7 @@ def grouped_gemm( else: lhs_non_contracting = lhs_shape[:lhs_axis_boundary] if rhs_is_trans: - rhs_non_contracting = tuple( - rhs_shape[d] - for d in range(rhs_axis_boundary) - if d != 0 - ) + rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) else: rhs_non_contracting = rhs_shape[rhs_axis_boundary:] if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index ce0f186aeb..94a23251d1 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1186,7 +1186,13 @@ def grouped_quantize( if quantizer is None: if isinstance(x, GroupedNoScaleTensor): - assert amax is None, "If the input to grouped_quantize is already a GroupedNoScaleTensor, providing an amax could be ambiguous. Please set amax to None and set the amax on your GroupedNoScaleTensor directly, if needed. Alternatively, please call grouped_quantize with a raw jnp.ndarray along with an amax value if you'd like this function to handle amax for you." + assert amax is None, ( + "If the input to grouped_quantize is already a GroupedNoScaleTensor, providing an" + " amax could be ambiguous. Please set amax to None and set the amax on your" + " GroupedNoScaleTensor directly, if needed. Alternatively, please call" + " grouped_quantize with a raw jnp.ndarray along with an amax value if you'd like" + " this function to handle amax for you." + ) return x return GroupedNoScaleTensor( data=x, @@ -1207,9 +1213,7 @@ def grouped_quantize( group_sizes = jnp.ones(x.shape[0], dtype=jnp.int32) if not GroupedQuantizePrimitive.enabled(): - return quantizer.quantize( - x, flatten_axis=flatten_axis, group_sizes=group_sizes - ) + return quantizer.quantize(x, flatten_axis=flatten_axis, group_sizes=group_sizes) n_groups = group_sizes.size original_shape = x.shape assert n_groups == len( @@ -1226,9 +1230,7 @@ def grouped_quantize( row_amax = amax else: row_amax = jnp.max(jnp.abs(x), axis=range(1, x.ndim)) - segment_ids = jnp.repeat( - jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[0] - ) + segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[0]) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index fb76b7052e..fb42197e58 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -781,8 +781,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, - lhs_axis_boundary, rhs_axis_boundary, - lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; + lhs_axis_boundary, rhs_axis_boundary, lhs_left_size, lhs_right_size, rhs_left_size, + rhs_right_size] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index aea957145f..6eed21b30f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -417,10 +417,6 @@ def _grouped_dense_fwd_rule( assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled - - - - x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -516,7 +512,7 @@ def _grouped_dense_bwd_rule( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - + wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) dgrad = tex.grouped_gemm( diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index de97fb7318..2501412ab1 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -283,18 +283,14 @@ def _grouped_dequantize(grouped_scaled_tensor): ) # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape if group_sizes is None: - group_sizes = jnp.ones( - grouped_scaled_tensor.original_shape[0], dtype=jnp.int32 - ) + group_sizes = jnp.ones(grouped_scaled_tensor.original_shape[0], dtype=jnp.int32) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] - non_group_shape = tuple( - original_shape[i] for i in range(len(original_shape)) if i != 0 - ) + non_group_shape = tuple(original_shape[i] for i in range(len(original_shape)) if i != 0) matrix_sizes = group_sizes * math.prod(non_group_shape) data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1]) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 6c9062b4f1..b1f49dacdc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -678,11 +678,7 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if ( - first_dims is not None - or last_dims is not None - or original_shape is not None - ): + if first_dims is not None or last_dims is not None or original_shape is not None: assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" From 94384787fbe503454664c19ee1d615fd8f0fc533 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 23 Mar 2026 14:49:25 -0700 Subject: [PATCH 27/60] Lint Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 5 +++++ transformer_engine/jax/dense.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 04b615269a..3b88a711f4 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -74,6 +74,7 @@ # Cache whether the CUDA-graphable grouped GEMM implementation is available at import time. # Calling get_grouped_gemm_setup_workspace_size raises a RuntimeError mentioning "cublas" when # compiled against cuBLAS < 13.2, in which case the cuda-graphable path is unavailable. +_v2_grouped_gemm_available_reason = "" try: get_grouped_gemm_setup_workspace_size(1) _v2_grouped_gemm_available = True @@ -1498,6 +1499,10 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ + del lhs_data_aval, rhs_data_aval + del lhs_is_trans, rhs_is_trans + del lhs_axis_boundary, rhs_axis_boundary + del lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size del bias_aval del has_bias, use_async_d2h_group_sizes diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 6eed21b30f..056a9655d0 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -471,7 +471,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + kernel_fsdp_mesh_axis, _ = kernel_fsdp_info kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." From 09dfd9c2dea1f12fdc6c3066d75b3c55ab0ba5a6 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Mar 2026 14:12:18 -0700 Subject: [PATCH 28/60] Fixes for Hopper Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3b88a711f4..aaec5affa8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2299,7 +2299,12 @@ def grouped_gemm( else: lhs_non_contracting = lhs_shape[:lhs_axis_boundary] if rhs_is_trans: - rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # wgrad: rhs (e.g. grad_T of shape (N, M)) has no G batch dim; include all dims + rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary)) + else: + # fwd/dgrad: rhs (e.g. kernel_T of shape (G, N, K)) has G batch dim at dim 0; skip it + rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) else: rhs_non_contracting = rhs_shape[rhs_axis_boundary:] if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: From e25538e64f890ac7c22273284e66f523b1e85d52 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Mar 2026 15:26:17 -0700 Subject: [PATCH 29/60] Address review comments Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 16 ++------------- .../jax/csrc/extensions/gemm.cpp | 20 +++++++++---------- transformer_engine/jax/dense.py | 13 +++--------- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 94a23251d1..a3d363e42a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1158,7 +1158,6 @@ def grouped_quantize( x: jnp.ndarray, quantizer: GroupedQuantizer, group_sizes: jnp.ndarray = None, - amax: jnp.ndarray = None, flatten_axis: int = -1, ) -> Union[GroupedScaledTensor1x, GroupedNoScaleTensor]: """Quantize a tensor in grouped manner. @@ -1171,7 +1170,6 @@ def grouped_quantize( x: Input tensor to quantize quantizer: The quantizer to use for quantization group_sizes: Array of ints containing the size of each group (default: None) - amax: The amax of x; if None, it is auto-generated. (default: None) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) Returns: @@ -1186,17 +1184,10 @@ def grouped_quantize( if quantizer is None: if isinstance(x, GroupedNoScaleTensor): - assert amax is None, ( - "If the input to grouped_quantize is already a GroupedNoScaleTensor, providing an" - " amax could be ambiguous. Please set amax to None and set the amax on your" - " GroupedNoScaleTensor directly, if needed. Alternatively, please call" - " grouped_quantize with a raw jnp.ndarray along with an amax value if you'd like" - " this function to handle amax for you." - ) return x return GroupedNoScaleTensor( data=x, - amax=amax, + amax=None, first_dims=group_sizes, last_dims=None, original_shape=x.shape, @@ -1226,10 +1217,7 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - if amax is not None: - row_amax = amax - else: - row_amax = jnp.max(jnp.abs(x), axis=range(1, x.ndim)) + row_amax = jnp.max(jnp.abs(x), axis=range(1, x.ndim)) segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[0]) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index fb42197e58..0d1ef405f4 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -667,17 +667,17 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, Buffer_Type const &alpha) { if (lhs_first_dims.element_count() > 0) { - return lhs_first_dims.dimensions()[0]; + return lhs_first_dims.element_count(); } else if (lhs_last_dims.element_count() > 0) { - return lhs_last_dims.dimensions()[0]; + return lhs_last_dims.element_count(); } else if (rhs_first_dims.element_count() > 0) { - return rhs_first_dims.dimensions()[0]; + return rhs_first_dims.element_count(); } else if (rhs_last_dims.element_count() > 0) { - return rhs_last_dims.dimensions()[0]; + return rhs_last_dims.element_count(); } else if (out_first_dims.element_count() > 0) { - return out_first_dims.dimensions()[0]; + return out_first_dims.element_count(); } else if (out_last_dims.element_count() > 0) { - return out_last_dims.dimensions()[0]; + return out_last_dims.element_count(); } else { return alpha.element_count(); // uniform batch: no ragged tensor } @@ -753,9 +753,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data (2D) + .Arg() // lhs_data .Arg() // lhs_sinv - .Arg() // rhs_data (2D) + .Arg() // rhs_data .Arg() // rhs_sinv .Arg() // bias .Arg() // lhs_first_dims (G,) or empty (0,) @@ -1207,9 +1207,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data (2D) + .Arg() // lhs_data .Arg() // lhs_sinv - .Arg() // rhs_data (2D) + .Arg() // rhs_data .Arg() // rhs_sinv .Arg() // bias .Arg() // lhs_first_dims (G,) or empty (0,) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 056a9655d0..96e4a2251b 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -321,7 +321,6 @@ def grouped_dense( group_sizes: jnp.ndarray, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), bias: jnp.ndarray = None, - kernel_amax: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, @@ -338,7 +337,6 @@ def grouped_dense( contracting_dims: Tuple of sequences specifying which dimensions to contract (currently only supports ((1,), (1,))) bias: Bias tensor of shape (G, N) - kernel_amax: The amax values of weight matrix of shape (G,) precision: JAX precision for the GEMM operation preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) @@ -357,7 +355,6 @@ def grouped_dense( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -367,14 +364,13 @@ def grouped_dense( return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7, 9)) def _grouped_dense( x, kernel, group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -387,7 +383,6 @@ def _grouped_dense( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -403,7 +398,6 @@ def _grouped_dense_fwd_rule( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -429,7 +423,7 @@ def _grouped_dense_fwd_rule( ) casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k + kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k ) contracting_dims = (x_contracting_dims, k_contracting_dims) @@ -535,9 +529,8 @@ def _grouped_dense_bwd_rule( group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None - dkernel_amax = None - return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set + return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) From 78674e936b68f45d1d54e08b9cc3a4fc9dff300c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:27:20 +0000 Subject: [PATCH 30/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/dense.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 96e4a2251b..dbd7bbb1ff 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -422,9 +422,7 @@ def _grouped_dense_fwd_rule( flatten_axis=flatten_axis_x, ) - casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k - ) + casted_kernel = tex.grouped_quantize(kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k) contracting_dims = (x_contracting_dims, k_contracting_dims) # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have From 06ebb4494df742224473a7fcfd2ff103a042269a Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 24 Mar 2026 16:32:02 -0700 Subject: [PATCH 31/60] Fixes Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 122 +++++++++++++++++- .../jax/cpp_extensions/quantization.py | 37 +++--- transformer_engine/jax/quantize/tensor.py | 8 +- 3 files changed, 138 insertions(+), 29 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1a2ce48d43..a4322450ae 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -43,6 +43,7 @@ noop_quantizer_set, QuantizeMetaSet, QuantizeMeta, + get_device_compute_capability, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -77,6 +78,9 @@ supported_recipes = helper.get_supported_quantization_recipes() supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] +is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100 +v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)" + def is_shape_supported_by_mxfp8(input_shape): try: @@ -2013,10 +2017,10 @@ def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape): n_groups=input_shape[0], ) lhs_tensor = GroupedNoScaleTensor( - data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape ) rhs_tensor = GroupedNoScaleTensor( - data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape ) # Reference: unquantized grouped GEMM n_groups = input_shape[0] @@ -2054,10 +2058,10 @@ def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): n_groups=input_shape[0], ) lhs_tensor = GroupedNoScaleTensor( - data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape ) rhs_tensor = GroupedNoScaleTensor( - data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape ) n_groups = input_shape[0] lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) @@ -2128,6 +2132,116 @@ def _prim_sum(x, kernel, group_sizes): assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype) +# BF16 grouped GEMM V1/V2 shapes: no special shape alignment needed for BF16 GEMM. +# V2 is selected based solely on hardware (SM100+), not shape. +GROUPED_DENSE_BF16_INPUT_SHAPES = [ + # (n_groups, m, n, k) + (8, 8, 128, 128), + (4, 4, 64, 256), +] + + +@pytest.mark.skipif( + not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason +) +class TestGroupedDenseBF16V2GEMM: + """Tests that explicitly verify V2 BF16 grouped GEMM on SM100+ hardware. + + For BF16, the V2 (CUDA-graph-safe) grouped GEMM is selected when: + - The device compute capability is >= 100 (Blackwell or newer) + - The cuBLAS version supports it + V1 (nvte_multi_tensor_gemm) is the fallback on older hardware. + + V1 BF16 grouped GEMM is tested by TestGroupedDense.test_grouped_gemm_fp16 + (using use_async_d2h_group_sizes=True). + """ + + def _generate_bf16_input(self, input_shape, group_size_multiplier=32): + key = jax.random.PRNGKey(7) + subkeys = jax.random.split(key, 3) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) + group_sizes = group_sizes * group_size_multiplier + m_total = m * group_size_multiplier + + lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) + rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) + return lhs, rhs, group_sizes + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_BF16_INPUT_SHAPES, + ids=[f"bf16_v2_{s}" for s in GROUPED_DENSE_BF16_INPUT_SHAPES], + ) + def test_grouped_gemm_bf16_v2(self, input_shape): + """BF16 grouped GEMM using the V2 (CUDA-graph-safe) kernel on SM100+.""" + lhs, rhs, group_sizes = self._generate_bf16_input(input_shape) + n_groups = input_shape[0] + + lhs_tensor = GroupedNoScaleTensor( + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape + ) + + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + + prim_out = jax.jit( + tex.grouped_gemm, static_argnames=("contracting_dims",) + )( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + ) + + assert prim_out.shape == ref_out.shape + assert prim_out.dtype == jnp.bfloat16 + assert_allclose(prim_out, ref_out, dtype=jnp.bfloat16) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_BF16_INPUT_SHAPES, + ids=[f"bf16_v2_grad_{s}" for s in GROUPED_DENSE_BF16_INPUT_SHAPES], + ) + def test_grouped_dense_grad_bf16_v2(self, input_shape): + """BF16 grouped GEMM gradient test (fwd + dgrad + wgrad) using V2 on SM100+.""" + lhs, rhs, group_sizes = self._generate_bf16_input(input_shape) + n_groups = input_shape[0] + contracting_dims = ((1,), (1,)) + + def _ref_sum(x, kernel, group_sizes): + lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(kernel, n_groups, axis=0) + out = jnp.concatenate( + [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 + ) + return jnp.sum(out) / jnp.sqrt(x.size) + + def _prim_sum(x, kernel, group_sizes): + out = grouped_dense(x, kernel, group_sizes, contracting_dims, bias=None) + return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) + + ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( + lhs, rhs, group_sizes + ) + + assert_allclose(prim_val, ref_val, dtype=jnp.bfloat16) + assert_allclose(prim_dx, ref_dx, dtype=jnp.bfloat16) + assert_allclose(prim_dk, ref_dk, dtype=jnp.bfloat16) + + class TestDebugInspectFFI: @pytest_parametrize_wrapper("shape", [(256, 128)]) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 5512cc037e..65d144c997 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1084,35 +1084,35 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: - assert False, ( - "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" - " scaling_mode {}".format(scaling_mode) - ) + # assert False, ( + # "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" + # " scaling_mode {}".format(scaling_mode) + # ) return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: - assert False, ( - "V2 grouped quantize kernel requires total first logical dimension (product of" - " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" - " flatten_axis {} with total_first_dim {}".format( - x_shape, flatten_axis, total_first_dim - ) - ) + # assert False, ( + # "V2 grouped quantize kernel requires total first logical dimension (product of" + # " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" + # " flatten_axis {} with total_first_dim {}".format( + # x_shape, flatten_axis, total_first_dim + # ) + # ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), # non_group_m = K must also be 128-aligned. if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: - assert False, ( - "V2 grouped quantize kernel requires non-group dimension (product of" - " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," - " but got shape {} and flatten_axis {} with non_group_m {}".format( - x_shape, flatten_axis, non_group_m - ) - ) + # assert False, ( + # "V2 grouped quantize kernel requires non-group dimension (product of" + # " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," + # " but got shape {} and flatten_axis {} with non_group_m {}".format( + # x_shape, flatten_axis, non_group_m + # ) + # ) return False return True @@ -1234,7 +1234,6 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 - assert group_axis == 0 use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) if use_v2: # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4ce09b3721..403e8b536c 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -418,7 +418,7 @@ def group_sizes(self) -> jnp.ndarray: """ if self.first_dims is not None and self.first_dims.size > 0: return self.first_dims - return jnp.ones((self.original_shape[self.group_axis],), dtype=jnp.int32) + return jnp.ones((self.original_shape[0],), dtype=jnp.int32) def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" @@ -690,11 +690,7 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if ( - first_dims is not None - or last_dims is not None - or (original_shape is not None and group_axis is not None) - ): + if first_dims is not None or last_dims is not None or original_shape is not None: assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" From a3f804272e453adb371590d70ba33ee897fb44cf Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 30 Mar 2026 10:38:39 -0700 Subject: [PATCH 32/60] wip Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 165 +++++++++++++++++- transformer_engine/jax/cpp_extensions/gemm.py | 71 +++----- .../jax/cpp_extensions/quantization.py | 50 +++--- .../jax/csrc/extensions/gemm.cpp | 22 +++ transformer_engine/jax/quantize/tensor.py | 17 ++ 5 files changed, 248 insertions(+), 77 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a4322450ae..5e15781102 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1114,6 +1114,61 @@ def test_grouped_qdq( assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) + def test_grouped_quantize_v1_pre_swizzled( + self, input_shape, in_dtype, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + ): + """V1 grouped quantize (K not 128-aligned) must produce pre_swizzled=False.""" + if scaling_mode != ScalingMode.MXFP8_1D_SCALING: + pytest.skip("pre_swizzled is only relevant for MXFP8") + if q_layout != QuantizeLayout.ROWWISE: + pytest.skip("Using ROWWISE layout to get a single GroupedScaledTensor1x") + # Shape with K=32 (not 128-aligned) forces V1 quantize on any GPU. + n_groups = 4 + group_sizes = jnp.array([32, 32, 32, 32], dtype=jnp.int32) + x = jax.random.uniform(jax.random.PRNGKey(0), (128, 32), jnp.bfloat16) + quantizer = QuantizerFactory.create( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + q_dtype=jnp.float8_e4m3fn, + q_layout=QuantizeLayout.ROWWISE, + n_groups=n_groups, + ) + scaled_tensor = tex.grouped_quantize( + x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 + ) + assert isinstance(scaled_tensor, GroupedScaledTensor1x) + assert not scaled_tensor.pre_swizzled, ( + "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" + ) + + @pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) + def test_grouped_quantize_v2_pre_swizzled( + self, input_shape, in_dtype, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + ): + """V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True.""" + if scaling_mode != ScalingMode.MXFP8_1D_SCALING: + pytest.skip("pre_swizzled is only relevant for MXFP8") + if q_layout != QuantizeLayout.ROWWISE: + pytest.skip("Using ROWWISE layout to get a single GroupedScaledTensor1x") + # Shape with M=512 and K=128 (both 128-aligned) allows V2 on SM100+. + n_groups = 4 + group_sizes = jnp.array([128, 128, 128, 128], dtype=jnp.int32) + x = jax.random.uniform(jax.random.PRNGKey(0), (512, 128), jnp.bfloat16) + quantizer = QuantizerFactory.create( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + q_dtype=jnp.float8_e4m3fn, + q_layout=QuantizeLayout.ROWWISE, + n_groups=n_groups, + ) + scaled_tensor = tex.grouped_quantize( + x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 + ) + assert isinstance(scaled_tensor, GroupedScaledTensor1x) + assert scaled_tensor.pre_swizzled, ( + "V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True" + ) + @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) class TestFusedQuantize: @@ -1951,6 +2006,107 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) + def test_grouped_dense_mxfp8_v1_pipeline(self, input_shape): + """V1 pipeline: V1 grouped quantize + V1 grouped GEMM. + + Uses shapes where K or N is not 128-aligned, forcing V1 quantize (pre_swizzled=False) + and V1 GEMM on all GPUs. Verifies correctness and that pre_swizzled=False. + """ + n_groups, m, n, k = input_shape + # Skip shapes where both K and N are 128-aligned; those may use V2 on SM100+. + if k % 128 == 0 and n % 128 == 0: + pytest.skip("Shape is V2-eligible; this test targets V1-only shapes") + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + jnp.bfloat16, input_shape, group_size_multiplier=32 + ) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + + quantizer = QuantizerFactory.create( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + q_dtype=jnp.float8_e4m3fn, + q_layout=QuantizeLayout.ROWWISE, + n_groups=n_groups, + ) + # V1 quantize: K or N not 128-aligned → pre_swizzled=False + casted_lhs = jax.jit(tex.grouped_quantize, static_argnames=("flatten_axis",))( + lhs, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 + ) + assert isinstance(casted_lhs, GroupedScaledTensor1x) + assert not casted_lhs.pre_swizzled, "V1 quantize must produce pre_swizzled=False" + + lhs_tensor = GroupedNoScaleTensor( + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape + ) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=n_groups, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, jnp.float8_e4m3fn) + + @pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) + def test_grouped_dense_mxfp8_v2_pipeline(self, input_shape): + """V2 pipeline: V2 grouped quantize + V2 grouped GEMM (SM100+ required). + + Uses shapes where both K and N are 128-aligned, enabling V2 quantize (pre_swizzled=True) + and V2 GEMM on SM100+. Verifies correctness and that pre_swizzled=True. + """ + n_groups, m, n, k = input_shape + # Skip shapes that are not V2-eligible (K or N not 128-aligned). + if k % 128 != 0 or n % 128 != 0: + pytest.skip("Shape is not V2-eligible (K or N not 128-aligned)") + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + jnp.bfloat16, input_shape, group_size_multiplier=128 + ) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + + quantizer = QuantizerFactory.create( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + q_dtype=jnp.float8_e4m3fn, + q_layout=QuantizeLayout.ROWWISE, + n_groups=n_groups, + ) + # V2 quantize (SM100+, 128-aligned M, K): pre_swizzled=True + casted_lhs = jax.jit(tex.grouped_quantize, static_argnames=("flatten_axis",))( + lhs, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 + ) + assert isinstance(casted_lhs, GroupedScaledTensor1x) + assert casted_lhs.pre_swizzled, "V2 quantize (SM100+) must produce pre_swizzled=True" + + lhs_tensor = GroupedNoScaleTensor( + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape + ) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=n_groups, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, + ) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, jnp.float8_e4m3fn) + # MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are # NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel. @@ -1964,11 +2120,14 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): # divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used. # These shapes must be paired with group_size_multiplier=128 so that each group's # row count is also divisible by 128 (the V2 per-group alignment requirement). +# Additionally, both the last dimension of lhs (K) and of rhs (N) must be 128-aligned +# to match the V2 grouped GEMM constraint (and the updated V2 quantize constraint). GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [ # (n_groups, m, n, k) - # lhs total_rows = m * 128; rhs total_rows = n_groups * k - (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) - (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) + # lhs: (m*128, k); rhs: (n_groups, k, n) + # V2 requires: m*128 % 128==0, k % 128==0 (lhs), n_groups*k % 128==0, n % 128==0 (rhs) + (8, 8, 128, 128), # lhs: M=1024 ✓, K=128 ✓; rhs: G*K=1024 ✓, N=128 ✓ + (4, 4, 128, 256), # lhs: M=512 ✓, K=256 ✓; rhs: G*K=1024 ✓, N=128 ✓ ] diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index aaa975e0aa..5a34337a6b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -381,24 +381,6 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) -def _swizzle_grouped_scale(scale_inv, scale_2d_shape, is_colwise): - """Swizzle a 1D grouped scale_inv buffer using full-tensor swizzle. - - The grouped scale_inv is 1D (worst-case padded). The meaningful prefix has size - equal to prod(scale_2d_shape). We reshape that prefix to 2D, swizzle it, and - write it back, leaving any trailing padding untouched. - """ - useful_size = math.prod(scale_2d_shape) - if useful_size == scale_inv.shape[0]: - # No trailing padding — reshape, swizzle, flatten. - return swizzled_scale(scale_inv.reshape(scale_2d_shape), 1, is_colwise).reshape( - scale_inv.shape - ) - # Split meaningful prefix from trailing padding, swizzle prefix only. - prefix = scale_inv[:useful_size].reshape(scale_2d_shape) - swizzled = swizzled_scale(prefix, 1, is_colwise).reshape((useful_size,)) - return jnp.concatenate([swizzled, scale_inv[useful_size:]]) - def get_lhs_axis_boundary(lhs_cdims, is_transposed): """Get the axis boundary for the LHS operand.""" @@ -1629,9 +1611,11 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + # V1 needs workspace for swizzled scale_inv output buffers + # (nvte_swizzle_scaling_factors is called per-group inside GroupedGemmFFI). + # V2 receives scale_inv already swizzled by nvte_group_quantize (fused swizzle in + # V2 grouped quantize); no extra workspace is needed for re-swizzling. if not use_v2_ffi: - # V1 needs workspace for per-group swizzle output buffers. - # V2: scales are pre-swizzled in JAX, no extra workspace needed. workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding return workspace_size @@ -2388,37 +2372,24 @@ def grouped_gemm( lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, ) + + # V2 grouped GEMM requires MXFP8 inputs to be pre-swizzled by V2 grouped quantize + # (nvte_group_quantize fuses the swizzle). The C++ V2 GEMM FFI does not re-swizzle. if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: - # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). - # Grouped scale_inv is 1D (flat, worst-case padded). When all group sizes are - # multiples of 128 (V2 requirement), the per-group scales are contiguous with no - # inter-group padding gaps. We reshape the meaningful prefix to 2D, swizzle, and - # write it back into the original 1D buffer (extra trailing zeros stay untouched). - lhs_is_colwise = lhs_is_trans - rhs_is_colwise = not rhs_is_trans - lhs_scale_shape = scaling_mode.get_scale_shape( - lhs_data.shape, - is_colwise=lhs_is_colwise, - is_padded=True, - flatten_axis=lhs_axis_boundary, - ) - rhs_scale_shape = scaling_mode.get_scale_shape( - rhs_data.shape, - is_colwise=rhs_is_colwise, - is_padded=True, - flatten_axis=rhs_axis_boundary, - ) - # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D - # input), but _swizzle_grouped_scale needs a flat 2D shape (rows, cols) where - # cols = n_block_y (last dim) and rows = prod(all other dims). This correctly - # flattens the group/K-block axes into a single row dimension so the swizzle - # pattern operates on the full (K-blocks-across-groups × N-blocks) matrix. - lhs_n_block_y = lhs_scale_shape[-1] - rhs_n_block_y = rhs_scale_shape[-1] - lhs_scale_2d = (math.prod(lhs_scale_shape) // lhs_n_block_y, lhs_n_block_y) - rhs_scale_2d = (math.prod(rhs_scale_shape) // rhs_n_block_y, rhs_n_block_y) - lhs_scale_inv = _swizzle_grouped_scale(lhs_scale_inv, lhs_scale_2d, lhs_is_colwise) - rhs_scale_inv = _swizzle_grouped_scale(rhs_scale_inv, rhs_scale_2d, rhs_is_colwise) + if isinstance(lhs, GroupedScaledTensor1x) and not lhs.pre_swizzled: + raise ValueError( + "V2 grouped GEMM requires MXFP8 lhs scale_inv to be pre-swizzled. " + "GroupedScaledTensor1x.pre_swizzled is False. " + "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " + "128-aligned shapes) to produce pre-swizzled tensors." + ) + if isinstance(rhs, GroupedScaledTensor1x) and not rhs.pre_swizzled: + raise ValueError( + "V2 grouped GEMM requires MXFP8 rhs scale_inv to be pre-swizzled. " + "GroupedScaledTensor1x.pre_swizzled is False. " + "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " + "128-aligned shapes) to produce pre-swizzled tensors." + ) if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 65d144c997..382138f4d5 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1071,49 +1071,45 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. V2 requires: - 1. The total first logical dimension (product of x_shape up to flatten_axis) + 1. SM100+ (Blackwell) — V2 grouped quantize fuses the scale_inv swizzle via + nvte_group_quantize. The swizzled scale_inv must then be consumed by the + V2 grouped GEMM, which also requires SM100+. Keeping both decisions tied + to SM100+ prevents a mismatch where V2-quantized (pre-swizzled) tensors + are passed to the V1 grouped GEMM (which would re-swizzle and corrupt). + 2. The total first logical dimension (product of x_shape up to flatten_axis) is divisible by 128. - 2. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + 3. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the per-group row count non_group_m = prod(x_shape[1:eff]) must also be - divisible by 128 (because group_sizes[i] counts slices, not rows, and - actual rows per group = group_sizes[i] * non_group_m). - 3. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must - be 128-aligned -- this is a dynamic constraint assumed by the caller. + divisible by 128. + 4. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned — this is a dynamic constraint assumed by the caller. + 5. The last logical dimension (contracting dim K or output dim N) must be + divisible by 128, matching the V2 grouped GEMM constraint so that the + two always agree on V1 vs V2. Falls back to V1 when constraints are not met. V1 supports arbitrary shapes but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: - # assert False, ( - # "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" - # " scaling_mode {}".format(scaling_mode) - # ) + return False + # Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM. + if get_min_device_compute_capability() < 100: return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: - # assert False, ( - # "V2 grouped quantize kernel requires total first logical dimension (product of" - # " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" - # " flatten_axis {} with total_first_dim {}".format( - # x_shape, flatten_axis, total_first_dim - # ) - # ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), # non_group_m = K must also be 128-aligned. if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: - # assert False, ( - # "V2 grouped quantize kernel requires non-group dimension (product of" - # " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," - # " but got shape {} and flatten_axis {} with non_group_m {}".format( - # x_shape, flatten_axis, non_group_m - # ) - # ) return False + # Last dim must be 128-aligned to match the V2 grouped GEMM requirement. + last_dim = math.prod(x_shape[eff:]) + if last_dim % 128 != 0: + return False return True @staticmethod @@ -1407,6 +1403,11 @@ def grouped_quantize( for i, quantizer_i in enumerate(quantizer.quantizers): quantizer_i.update(updated_amax[i].reshape((1,))) + # V2 grouped quantize (nvte_group_quantize) fuses the scale_inv swizzle into + # the kernel, so the resulting tensors are already swizzled for GEMM. + use_v2 = GroupedQuantizePrimitive._use_v2_kernel( + quantizer.scaling_mode.value, x.shape, flatten_axis + ) out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, @@ -1419,6 +1420,7 @@ def grouped_quantize( flatten_axis=flatten_axis, first_dims=ragged_first_dims, original_shape=original_shape, + pre_swizzled=use_v2, ) return out diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f97e247c0c..05d005c70f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -493,6 +493,8 @@ class JAXX_GroupedTensorWrapper { void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); void set_with_gemm_swizzled_scales(bool val); + void replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, NVTEDType sinv_dtype, + NVTEShape sinv_shape); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -601,6 +603,19 @@ void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { sizeof(v)); } +void JAXX_GroupedTensorWrapper::replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, + NVTEDType sinv_dtype, NVTEShape sinv_shape) { + if (use_colwise) { + m_colwise_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } else { + m_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor, sizeof(m_scale_inv_tensor)); + } +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -806,6 +821,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty rhs_last_dims, out_first_dims, out_last_dims, alpha); // Workspaces. + // V2 GEMM receives scale_inv already swizzled by nvte_group_quantize (V2 grouped quantize + // fuses the swizzle). No extra sinv reservation is needed; the full cublas_workspace is + // available for cuBLAS. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); @@ -836,6 +854,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + // For MXFP8: scale_inv is already swizzled (pre-swizzled by V2 grouped quantize via + // nvte_group_quantize). Pass the buffers directly to make_grouped_tensor which sets + // with_gemm_swizzled_scales(true) for MXFP8 automatically. No re-swizzling needed. auto rhs_tensor = is_mxfp8 ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, @@ -850,6 +871,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty stream, lhs_axis_boundary) : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream, lhs_axis_boundary); + // Output stays NO_SCALING auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 403e8b536c..c5ad0451fd 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -369,11 +369,15 @@ class GroupedScaledTensor1x(ScaledTensor1x): first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping + pre_swizzled: Whether the scale_inv is already swizzled for GEMM. True when produced + by V2 grouped quantize (nvte_group_quantize fuses the swizzle). The V2 grouped + GEMM FFI requires pre_swizzled=True for MXFP8 inputs and will not re-swizzle. """ first_dims: Optional[jnp.ndarray] last_dims: Optional[jnp.ndarray] original_shape: Tuple + pre_swizzled: bool = False def __init__( self, @@ -389,11 +393,13 @@ def __init__( data_layout, flatten_axis, original_shape, + pre_swizzled=False, ): self.flatten_axis = flatten_axis self.first_dims = first_dims self.last_dims = last_dims self.original_shape = original_shape + self.pre_swizzled = pre_swizzled # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, @@ -468,6 +474,7 @@ def tree_flatten(self): self.data_layout, self.flatten_axis, self.original_shape, + self.pre_swizzled, ) return (children, aux_data) @@ -665,6 +672,7 @@ def create_1x( last_dims=None, original_shape=None, has_rht_applied=False, + pre_swizzled=False, ): """Creates a single-scale quantized tensor. @@ -734,6 +742,7 @@ def create_1x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, + pre_swizzled=pre_swizzled, ) # Handling attrs of transposed tensors @@ -771,6 +780,7 @@ def create_2x( original_shape=None, rowwise_has_rht_applied=False, colwise_has_rht_applied=False, + pre_swizzled=False, ): """Creates a double-scale quantized tensor. @@ -812,6 +822,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, @@ -826,6 +837,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -847,6 +859,7 @@ def create( original_shape: Tuple[int] = None, rowwise_has_rht_applied: bool = False, colwise_has_rht_applied: bool = False, + pre_swizzled: bool = False, ): """Creates a scaled tensor based on the quantization axis. @@ -865,6 +878,7 @@ def create( original_shape: The original shape of the tensor before grouping (default: None) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + pre_swizzled: Whether scale_inv is already swizzled (produced by V2 grouped quantize). Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout @@ -888,6 +902,7 @@ def create( original_shape=original_shape, rowwise_has_rht_applied=rowwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) if q_layout.is_colwise_only: @@ -904,6 +919,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensorFactory.create_1x( @@ -919,6 +935,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) From 7e993143dbea07d0bf4d13c11e9653c7e318e284 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 30 Mar 2026 13:58:17 -0700 Subject: [PATCH 33/60] Fix grouped colwise dequantize for transposed ragged tensors and V1 pipeline test skip Two bugs fixed: 1. _grouped_dequantize (dequantizer.py): When data_layout=="T" (colwise) and first_dims is set (ragged LHS groups), original_shape is stored transposed so the group-variable axis is LAST, not first. Non-group dims are original_shape[:-1] and each group's data_shape is (*non_group_shape, group_sizes[i]), not (group_sizes[i], *original_shape[1:]). The old code computed matrix_sizes = group_sizes * total_rows which vastly exceeded the actual flattened data size, causing jnp.split to receive negative split sizes. 2. test_grouped_dense_mxfp8_v1_pipeline (test_custom_call_compute.py): The skip condition for V2-eligible shapes was "if k % 128 == 0 and n % 128 == 0" but V2 grouped quantize only requires k % 128 == 0 (n alignment is not a quantize requirement). The shape (8, 64, 32, 128) with k=128 was incorrectly not skipped, causing the test to fail on SM100+ where V2 quantize is used. Co-Authored-By: Claude Sonnet 4.6 --- tests/jax/test_custom_call_compute.py | 8 ++++--- .../jax/quantize/dequantizer.py | 24 +++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5e15781102..1fbba1b802 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2014,9 +2014,11 @@ def test_grouped_dense_mxfp8_v1_pipeline(self, input_shape): and V1 GEMM on all GPUs. Verifies correctness and that pre_swizzled=False. """ n_groups, m, n, k = input_shape - # Skip shapes where both K and N are 128-aligned; those may use V2 on SM100+. - if k % 128 == 0 and n % 128 == 0: - pytest.skip("Shape is V2-eligible; this test targets V1-only shapes") + # Skip shapes where K is 128-aligned; on SM100+, V2 quantize is used for any + # shape where both the total row count and K are 128-aligned (N alignment is + # not required for quantize, only for GEMM). + if k % 128 == 0: + pytest.skip("Shape is V2-eligible (K is 128-aligned); this test targets V1-only shapes") lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( jnp.bfloat16, input_shape, group_size_multiplier=32 ) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 2501412ab1..43a3d50f4e 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -290,17 +290,31 @@ def _grouped_dequantize(grouped_scaled_tensor): flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] - non_group_shape = tuple(original_shape[i] for i in range(len(original_shape)) if i != 0) + # When data_layout=="T" (colwise, transposed) and first_dims is set (ragged groups), the + # original_shape is stored transposed: the group (variable-size) axis is the LAST dimension + # rather than the first. Non-group dims are original_shape[:-1], not original_shape[1:]. + is_transposed_ragged = ( + grouped_scaled_tensor.data_layout == "T" + and grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 + ) + if is_transposed_ragged: + non_group_shape = original_shape[:-1] + else: + non_group_shape = tuple(original_shape[i] for i in range(len(original_shape)) if i != 0) matrix_sizes = group_sizes * math.prod(non_group_shape) data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1]) scale_inv_ptr = 0 for i, data_i in enumerate(data): - data_shape_i = ( - group_sizes[i], - *original_shape[1:], - ) + if is_transposed_ragged: + data_shape_i = (*non_group_shape, int(group_sizes[i])) + else: + data_shape_i = ( + group_sizes[i], + *original_shape[1:], + ) assert math.prod(data_shape_i) == data_i.size, ( f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" f" {data_i.size}" From 68bcbfc7db4eb8c59194984951ec07f6288bc2dc Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 30 Mar 2026 15:21:01 -0700 Subject: [PATCH 34/60] 2D shape fixes for flattened 1D shape from grouped quantization Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 51 ++++++++++++++++++- .../jax/csrc/extensions/gemm.cpp | 39 +++++++------- 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 5a34337a6b..7c5da1a8f5 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2072,10 +2072,16 @@ def _can_use_v2_grouped_gemm( return False if has_bias: + if enforce_v2_gmm: + raise RuntimeError( + "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel, but" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and has_bias is True." + ) return False if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: return True + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: # V2 MXFP8 requires that the total first dimension of both operands (up to # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. @@ -2083,10 +2089,26 @@ def _can_use_v2_grouped_gemm( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) if lhs_first_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + f" dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_first_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" + " is enabled." + ) return False if rhs_shape is not None and rhs_axis_boundary is not None: rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) if rhs_first_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + f" dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_first_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" + " is enabled." + ) return False # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group @@ -2099,13 +2121,38 @@ def _can_use_v2_grouped_gemm( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) if lhs_last_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + f" dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_last_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" + " is enabled." + ) return False if rhs_shape is not None and rhs_axis_boundary is not None: rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) if rhs_last_dim % 128 != 0: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + f" dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_last_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" + " is enabled." + ) return False return True + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D block" + " scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input parameters do not meet" + " these requirements (scaling_mode=" + f" {scaling_mode}, dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}," + f" rhs_shape={rhs_shape}, lhs_axis_boundary={lhs_axis_boundary}," + f" rhs_axis_boundary={rhs_axis_boundary})." + ) return False @@ -2367,8 +2414,8 @@ def grouped_gemm( scaling_mode, lhs_data.dtype, has_bias, - lhs_shape=lhs_data.shape, - rhs_shape=rhs_data.shape, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 05d005c70f..4ea2405f66 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -687,14 +687,11 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { auto dims = data.dimensions(); - NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - // Flatten dims at axis_boundary to produce a 2D NVTE shape. - // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, - // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). - size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -727,11 +724,11 @@ JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { auto dims = data.dimensions(); - NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; @@ -861,20 +858,26 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty is_mxfp8 ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, - stream, rhs_axis_boundary) + stream, rhs_left_size, rhs_right_size) : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); + int64_offset, num_gemms, stream, rhs_left_size, rhs_right_size); auto lhs_tensor = is_mxfp8 ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, - stream, lhs_axis_boundary) + stream, lhs_left_size, lhs_right_size) : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); - - // Output stays NO_SCALING + int64_offset, num_gemms, stream, lhs_left_size, lhs_right_size); + + // Output stays NO_SCALING. Derive 2D shape from the output buffer's own dims using + // last-dim-as-columns convention (equivalent to axis_boundary=-1 in the old API). + auto out_dims = output->dimensions(); + NVTE_CHECK(out_dims.size() > 0, "output buffer must have at least 1 dimension"); + size_t out_left_size = product(out_dims, 0, out_dims.size() - 1); + size_t out_right_size = static_cast(out_dims[out_dims.size() - 1]); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + int64_capacity, int64_offset, num_gemms, stream, + out_left_size, out_right_size); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From 81cb189986286fbe4685287b9610e92e0ff8eecc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 16:06:24 +0000 Subject: [PATCH 35/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 20 +++++------ transformer_engine/jax/cpp_extensions/gemm.py | 36 +++++++++---------- .../jax/csrc/extensions/gemm.cpp | 8 ++--- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1fbba1b802..079fb5df05 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1137,9 +1137,9 @@ def test_grouped_quantize_v1_pre_swizzled( x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 ) assert isinstance(scaled_tensor, GroupedScaledTensor1x) - assert not scaled_tensor.pre_swizzled, ( - "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" - ) + assert ( + not scaled_tensor.pre_swizzled + ), "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" @pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -1165,9 +1165,9 @@ def test_grouped_quantize_v2_pre_swizzled( x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 ) assert isinstance(scaled_tensor, GroupedScaledTensor1x) - assert scaled_tensor.pre_swizzled, ( - "V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True" - ) + assert ( + scaled_tensor.pre_swizzled + ), "V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True" @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @@ -2302,9 +2302,7 @@ def _prim_sum(x, kernel, group_sizes): ] -@pytest.mark.skipif( - not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason -) +@pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) class TestGroupedDenseBF16V2GEMM: """Tests that explicitly verify V2 BF16 grouped GEMM on SM100+ hardware. @@ -2358,9 +2356,7 @@ def test_grouped_gemm_bf16_v2(self, input_shape): axis=0, ) - prim_out = jax.jit( - tex.grouped_gemm, static_argnames=("contracting_dims",) - )( + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs_tensor, rhs_tensor, contracting_dims=((1,), (1,)), diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7c5da1a8f5..58c34b7662 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -381,7 +381,6 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) - def get_lhs_axis_boundary(lhs_cdims, is_transposed): """Get the axis boundary for the LHS operand.""" return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) @@ -2092,10 +2091,10 @@ def _can_use_v2_grouped_gemm( if enforce_v2_gmm: raise RuntimeError( "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - f" dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" f" {lhs_first_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" - " is enabled." + f" lhs_axis_boundary={lhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." ) return False if rhs_shape is not None and rhs_axis_boundary is not None: @@ -2104,10 +2103,10 @@ def _can_use_v2_grouped_gemm( if enforce_v2_gmm: raise RuntimeError( "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - f" dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" f" {rhs_first_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" - " is enabled." + f" rhs_axis_boundary={rhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." ) return False # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both @@ -2124,10 +2123,10 @@ def _can_use_v2_grouped_gemm( if enforce_v2_gmm: raise RuntimeError( "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - f" dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" f" {lhs_last_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" - " is enabled." + f" lhs_axis_boundary={lhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." ) return False if rhs_shape is not None and rhs_axis_boundary is not None: @@ -2136,22 +2135,21 @@ def _can_use_v2_grouped_gemm( if enforce_v2_gmm: raise RuntimeError( "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - f" dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" f" {rhs_last_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}, and NVTE_JAX_ENFORCE_V2_GROUPED_GEMM" - " is enabled." + f" rhs_axis_boundary={rhs_axis_boundary}, and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." ) return False return True if enforce_v2_gmm: raise RuntimeError( - "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D block" - " scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input parameters do not meet" - " these requirements (scaling_mode=" - f" {scaling_mode}, dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}," - f" rhs_shape={rhs_shape}, lhs_axis_boundary={lhs_axis_boundary}," - f" rhs_axis_boundary={rhs_axis_boundary})." + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" + " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" + f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," + f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." ) return False diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4ea2405f66..587463b75d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -604,7 +604,7 @@ void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { } void JAXX_GroupedTensorWrapper::replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, - NVTEDType sinv_dtype, NVTEShape sinv_shape) { + NVTEDType sinv_dtype, NVTEShape sinv_shape) { if (use_colwise) { m_colwise_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, @@ -875,9 +875,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty NVTE_CHECK(out_dims.size() > 0, "output buffer must have at least 1 dimension"); size_t out_left_size = product(out_dims, 0, out_dims.size() - 1); size_t out_right_size = static_cast(out_dims[out_dims.size() - 1]); - auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - out_left_size, out_right_size); + auto out_tensor = + make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, out_left_size, out_right_size); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From d7b04ccd34584e0d5933c3bb244619cd507b138e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 3 Apr 2026 17:33:44 -0700 Subject: [PATCH 36/60] Fix swizzling Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 10 +-- .../jax/csrc/extensions/gemm.cpp | 64 +++---------------- .../jax/csrc/extensions/quantization.cpp | 28 ++++++-- 3 files changed, 34 insertions(+), 68 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 58c34b7662..d18ef866b8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1610,13 +1610,9 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # V1 needs workspace for swizzled scale_inv output buffers - # (nvte_swizzle_scaling_factors is called per-group inside GroupedGemmFFI). - # V2 receives scale_inv already swizzled by nvte_group_quantize (fused swizzle in - # V2 grouped quantize); no extra workspace is needed for re-swizzling. - if not use_v2_ffi: - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + # Both V1 and V2 quantize now produce pre-swizzled scales, so the GEMM + # does not need extra workspace for nvte_swizzle_scaling_factors. + pass return workspace_size @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6966eb1138..11ee1e470f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1019,20 +1019,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type const size_t tensor_scaling_sinv_aligment = 16; const size_t mxfp8_scaling_sinv_alignment_padding = 256; auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { + if (is_tensor_scaling) { // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); } workspace_size = workspace_size / num_streams; - auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto lhs_scatter_aligned_ptr = workspace_ptr + workspace_size * num_streams; + lhs_scatter_aligned_ptr = move_ptr_to_next_256B_aligned(lhs_scatter_aligned_ptr); auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); @@ -1126,8 +1120,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; std::vector bias_wrapper_list; std::vector pre_gelu_wrapper_list; std::vector out_wrapper_list; @@ -1136,8 +1128,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM std::vector lhs_list; std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; @@ -1210,13 +1200,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). + // MXFP8 scales are pre-swizzled by the quantize kernel (both V1 and V2), + // so we pass them directly to the GEMM without a separate swizzle pass. // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); @@ -1225,32 +1210,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } lhs_i.set_with_gemm_swizzled_scales(true); if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } rhs_i.set_with_gemm_swizzled_scales(true); - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } } else { NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -1268,10 +1238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; lhs_sinv_total_size += lhs_sinv_size_i; rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } } if (has_bias) bias_ptr += n * bias_dtype_bytes; @@ -1312,18 +1278,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 57d0353379..938c225259 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -245,6 +245,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } } + // For MXFP8, produce pre-swizzled scales so the GEMM can consume them directly. + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + output_tensor.set_with_gemm_swizzled_scales(true); + } + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); @@ -452,6 +457,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty } } + // For MXFP8, produce pre-swizzled scales so the GEMM can consume them directly + // without a separate swizzle pass. + if (is_mxfp8_scaling) { + out_i.set_with_gemm_swizzled_scales(true); + } + input_holders.push_back(std::move(inp_i)); output_holders.push_back(std::move(out_i)); @@ -593,11 +604,10 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ static_cast(out_dtype), data_shape}; nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseData, &rw_data, sizeof(rw_data)); - auto sinv_dims = rowwise_sinv->dimensions(); NVTEShape rw_sinv_shape{}; rw_sinv_shape.ndim = 2; - rw_sinv_shape.data[0] = product(sinv_dims, 0, sinv_dims.size() - 1); - rw_sinv_shape.data[1] = sinv_dims.back(); + rw_sinv_shape.data[0] = m; + rw_sinv_shape.data[1] = n / 32; // MXFP8 block size = 32 NVTEBasicTensor rw_sinv{reinterpret_cast(rowwise_sinv->untyped_data()), static_cast(sinv_dtype), rw_sinv_shape}; nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseScaleInv, &rw_sinv, @@ -611,11 +621,10 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseData, &cw_data, sizeof(cw_data)); - auto cw_sinv_dims = colwise_sinv->dimensions(); NVTEShape cw_sinv_shape{}; cw_sinv_shape.ndim = 2; - cw_sinv_shape.data[0] = product(cw_sinv_dims, 0, cw_sinv_dims.size() - 1); - cw_sinv_shape.data[1] = cw_sinv_dims.back(); + cw_sinv_shape.data[0] = m / 32; // MXFP8 block size = 32 + cw_sinv_shape.data[1] = n; NVTEBasicTensor cw_sinv{reinterpret_cast(colwise_sinv->untyped_data()), static_cast(sinv_dtype), cw_sinv_shape}; nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseScaleInv, &cw_sinv, @@ -632,6 +641,13 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ if (total_colwise_sinv_size > 0) nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + // V2 grouped quantize is always paired with V2 grouped GEMM, which expects + // scale_inv in GEMM-swizzled layout. Enable the fused swizzle so the kernel + // writes scales in the layout the GEMM will consume directly. + uint8_t swizzle_flag = 1; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedWithGEMMSwizzledScales, + &swizzle_flag, sizeof(swizzle_flag)); + QuantizationConfigWrapper quant_config{}; nvte_group_quantize(in_grouped, out_grouped, quant_config, stream); From 064f314c083b273c1250d61751438d6bcd33c0ee Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 3 Apr 2026 17:35:33 -0700 Subject: [PATCH 37/60] Remove pre-swizzling from non-grouped quantization Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/quantization.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 938c225259..8670ece7a3 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -245,11 +245,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } } - // For MXFP8, produce pre-swizzled scales so the GEMM can consume them directly. - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - output_tensor.set_with_gemm_swizzled_scales(true); - } - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); From 5edef90a928e0472c208f0c070433fe4018c96c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Apr 2026 00:36:48 +0000 Subject: [PATCH 38/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions/quantization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 8670ece7a3..28b9ea4806 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -640,8 +640,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ // scale_inv in GEMM-swizzled layout. Enable the fused swizzle so the kernel // writes scales in the layout the GEMM will consume directly. uint8_t swizzle_flag = 1; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedWithGEMMSwizzledScales, - &swizzle_flag, sizeof(swizzle_flag)); + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedWithGEMMSwizzledScales, &swizzle_flag, + sizeof(swizzle_flag)); QuantizationConfigWrapper quant_config{}; nvte_group_quantize(in_grouped, out_grouped, quant_config, stream); From c55969c692a5a831bcd57bca0c4eda1b18996196 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 6 Apr 2026 13:51:35 -0700 Subject: [PATCH 39/60] Use avg m,n,k heuristics for cuBLASLt Grouped GEMM Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- .../jax/csrc/extensions/gemm.cpp | 45 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index aaec5affa8..72374fe905 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2027,7 +2027,7 @@ def grouped_gemm_copy_group_sizes( @cache def _should_enforce_v2_grouped_gemm() -> bool: """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" - return os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") == "1" + return bool(int(os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0"))) def _can_use_v2_grouped_gemm( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0d1ef405f4..a1e55f8ca7 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -683,6 +683,37 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con } } +// Compute per-group average m, n, k from the 2D operand shapes and ragged indicators. +// For each operand, left_size and right_size are the total 2D dimensions of the entire buffer. +// - If a dim is ragged (first_dims/last_dims non-empty), its total is the sum of per-group sizes, +// so the average is total / num_gemms. +// - If a dim is static but the tensor has no ragged dims at all, the group batch dimension G is +// folded into left_size (since G is always dim 0 and axis_boundary >= 1), so we divide it out. +// - If a dim is static and the *other* dim of the same tensor is ragged, there is no G in the +// shape, so the static dim's size is already the per-group value. +// +// The transpose flag determines the mapping from per-group (left, right) to (m_or_non_contract, +// k_or_contract): +// lhs represents [m, k] (not transposed) or [k, m] (transposed). +// rhs represents [k, n] (not transposed) or [n, k] (transposed). +// Returns {non_contracting_avg, contracting_avg}, i.e. {avg_m, avg_k} for lhs or {avg_n, avg_k} +// for rhs. +std::pair grouped_gemm_avg_dims( + Buffer_Type const &first_dims, Buffer_Type const &last_dims, size_t left_size, + size_t right_size, size_t num_gemms, bool is_trans) { + bool first_ragged = first_dims.element_count() > 0; + bool last_ragged = last_dims.element_count() > 0; + bool any_ragged = first_ragged || last_ragged; + // Per-group left: divide by num_gemms if first dim is ragged OR if tensor has no ragged dims + // (G is folded into left_size). Keep as-is only when last dim is ragged but first is not. + size_t pg_left = + (first_ragged || !any_ragged) ? left_size / num_gemms : left_size; + size_t pg_right = last_ragged ? right_size / num_gemms : right_size; + int64_t non_contract = static_cast(is_trans ? pg_right : pg_left); + int64_t contract = static_cast(is_trans ? pg_left : pg_right); + return {non_contract, contract}; +} + } // namespace jax } // namespace transformer_engine @@ -741,10 +772,22 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); + auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, + lhs_right_size, num_gemms, lhs_is_trans); + auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, + rhs_right_size, num_gemms, rhs_is_trans); + // Use k from lhs (both sides should agree for well-formed inputs). + (void)avg_k_rhs; + + GroupedMatmulConfigWrapper gemmConfig{}; + gemmConfig.set_avg_m(avg_m); + gemmConfig.set_avg_n(avg_n); + gemmConfig.set_avg_k(avg_k_lhs); + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), - nullptr, // config (use defaults) + gemmConfig, stream); return ffi_with_cuda_error_check(); From 427d5b6ff4590f7e8fe1f34ca390400d4e289fce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 21:00:44 +0000 Subject: [PATCH 40/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/gemm.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a1e55f8ca7..52bc44a2b6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -698,16 +698,16 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con // rhs represents [k, n] (not transposed) or [n, k] (transposed). // Returns {non_contracting_avg, contracting_avg}, i.e. {avg_m, avg_k} for lhs or {avg_n, avg_k} // for rhs. -std::pair grouped_gemm_avg_dims( - Buffer_Type const &first_dims, Buffer_Type const &last_dims, size_t left_size, - size_t right_size, size_t num_gemms, bool is_trans) { +std::pair grouped_gemm_avg_dims(Buffer_Type const &first_dims, + Buffer_Type const &last_dims, size_t left_size, + size_t right_size, size_t num_gemms, + bool is_trans) { bool first_ragged = first_dims.element_count() > 0; bool last_ragged = last_dims.element_count() > 0; bool any_ragged = first_ragged || last_ragged; // Per-group left: divide by num_gemms if first dim is ragged OR if tensor has no ragged dims // (G is folded into left_size). Keep as-is only when last dim is ragged but first is not. - size_t pg_left = - (first_ragged || !any_ragged) ? left_size / num_gemms : left_size; + size_t pg_left = (first_ragged || !any_ragged) ? left_size / num_gemms : left_size; size_t pg_right = last_ragged ? right_size / num_gemms : right_size; int64_t non_contract = static_cast(is_trans ? pg_right : pg_left); int64_t contract = static_cast(is_trans ? pg_left : pg_right); @@ -773,9 +773,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty int64_capacity, int64_offset, num_gemms, stream); auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, - lhs_right_size, num_gemms, lhs_is_trans); + lhs_right_size, num_gemms, lhs_is_trans); auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, - rhs_right_size, num_gemms, rhs_is_trans); + rhs_right_size, num_gemms, rhs_is_trans); // Use k from lhs (both sides should agree for well-formed inputs). (void)avg_k_rhs; @@ -786,9 +786,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - gemmConfig, - stream); + workspace_cublas.data(), gemmConfig, stream); return ffi_with_cuda_error_check(); } From 167c343ace51c5b6ffc309bd2126671d5ac6bca9 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 6 Apr 2026 14:05:56 -0700 Subject: [PATCH 41/60] Update transformer_engine/jax/cpp_extensions/gemm.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/gemm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 72374fe905..c081e451a7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2027,7 +2027,13 @@ def grouped_gemm_copy_group_sizes( @cache def _should_enforce_v2_grouped_gemm() -> bool: """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" - return bool(int(os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0"))) + val = os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") + try: + return bool(int(val)) + except ValueError as e: + raise ValueError( + f"NVTE_JAX_ENFORCE_V2_GROUPED_GEMM must be an integer (0 or 1), got: {val!r}" + ) from e def _can_use_v2_grouped_gemm( From ae97af1d1f2290795a18b6e319e18e6ded41de11 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 6 Apr 2026 13:51:35 -0700 Subject: [PATCH 42/60] Use avg m,n,k heuristics for cuBLASLt Grouped GEMM Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- .../jax/csrc/extensions/gemm.cpp | 45 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d18ef866b8..0e13d256b6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2027,7 +2027,7 @@ def grouped_gemm_copy_group_sizes( @cache def _should_enforce_v2_grouped_gemm() -> bool: """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" - return os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") == "1" + return bool(int(os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0"))) def _can_use_v2_grouped_gemm( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 11ee1e470f..d1fe7ef12c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -790,6 +790,37 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con } } +// Compute per-group average m, n, k from the 2D operand shapes and ragged indicators. +// For each operand, left_size and right_size are the total 2D dimensions of the entire buffer. +// - If a dim is ragged (first_dims/last_dims non-empty), its total is the sum of per-group sizes, +// so the average is total / num_gemms. +// - If a dim is static but the tensor has no ragged dims at all, the group batch dimension G is +// folded into left_size (since G is always dim 0 and axis_boundary >= 1), so we divide it out. +// - If a dim is static and the *other* dim of the same tensor is ragged, there is no G in the +// shape, so the static dim's size is already the per-group value. +// +// The transpose flag determines the mapping from per-group (left, right) to (m_or_non_contract, +// k_or_contract): +// lhs represents [m, k] (not transposed) or [k, m] (transposed). +// rhs represents [k, n] (not transposed) or [n, k] (transposed). +// Returns {non_contracting_avg, contracting_avg}, i.e. {avg_m, avg_k} for lhs or {avg_n, avg_k} +// for rhs. +std::pair grouped_gemm_avg_dims( + Buffer_Type const &first_dims, Buffer_Type const &last_dims, size_t left_size, + size_t right_size, size_t num_gemms, bool is_trans) { + bool first_ragged = first_dims.element_count() > 0; + bool last_ragged = last_dims.element_count() > 0; + bool any_ragged = first_ragged || last_ragged; + // Per-group left: divide by num_gemms if first dim is ragged OR if tensor has no ragged dims + // (G is folded into left_size). Keep as-is only when last dim is ragged but first is not. + size_t pg_left = + (first_ragged || !any_ragged) ? left_size / num_gemms : left_size; + size_t pg_right = last_ragged ? right_size / num_gemms : right_size; + int64_t non_contract = static_cast(is_trans ? pg_right : pg_left); + int64_t contract = static_cast(is_trans ? pg_left : pg_right); + return {non_contract, contract}; +} + } // namespace jax } // namespace transformer_engine @@ -879,10 +910,22 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream, out_left_size, out_right_size); + auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, + lhs_right_size, num_gemms, lhs_is_trans); + auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, + rhs_right_size, num_gemms, rhs_is_trans); + // Use k from lhs (both sides should agree for well-formed inputs). + (void)avg_k_rhs; + + GroupedMatmulConfigWrapper gemmConfig{}; + gemmConfig.set_avg_m(avg_m); + gemmConfig.set_avg_n(avg_n); + gemmConfig.set_avg_k(avg_k_lhs); + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), - nullptr, // config (use defaults) + gemmConfig, stream); return ffi_with_cuda_error_check(); From f1c758250a3e7b50a4fd6a2284e0277c7dd9e7ad Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 6 Apr 2026 14:43:13 -0700 Subject: [PATCH 43/60] Fix rhs transpose flag Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 52bc44a2b6..3ebeeecded 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -775,7 +775,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, lhs_right_size, num_gemms, lhs_is_trans); auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, - rhs_right_size, num_gemms, rhs_is_trans); + rhs_right_size, num_gemms, !rhs_is_trans); // Use k from lhs (both sides should agree for well-formed inputs). (void)avg_k_rhs; From b3ea76a5ba5c677abaced4797b1038d12ae2256f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 6 Apr 2026 14:43:13 -0700 Subject: [PATCH 44/60] Fix rhs transpose flag Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d1fe7ef12c..469d49f465 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -913,7 +913,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, lhs_right_size, num_gemms, lhs_is_trans); auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, - rhs_right_size, num_gemms, rhs_is_trans); + rhs_right_size, num_gemms, !rhs_is_trans); // Use k from lhs (both sides should agree for well-formed inputs). (void)avg_k_rhs; From 6387b8ae1c63015409abf768555f383d798c67e3 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 15:25:01 -0700 Subject: [PATCH 45/60] Address comments Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/gemm.cpp | 79 ++++++++++++------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 3ebeeecded..ce016e54e1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -683,34 +683,57 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con } } -// Compute per-group average m, n, k from the 2D operand shapes and ragged indicators. -// For each operand, left_size and right_size are the total 2D dimensions of the entire buffer. -// - If a dim is ragged (first_dims/last_dims non-empty), its total is the sum of per-group sizes, -// so the average is total / num_gemms. -// - If a dim is static but the tensor has no ragged dims at all, the group batch dimension G is -// folded into left_size (since G is always dim 0 and axis_boundary >= 1), so we divide it out. -// - If a dim is static and the *other* dim of the same tensor is ragged, there is no G in the -// shape, so the static dim's size is already the per-group value. -// -// The transpose flag determines the mapping from per-group (left, right) to (m_or_non_contract, -// k_or_contract): -// lhs represents [m, k] (not transposed) or [k, m] (transposed). -// rhs represents [k, n] (not transposed) or [n, k] (transposed). -// Returns {non_contracting_avg, contracting_avg}, i.e. {avg_m, avg_k} for lhs or {avg_n, avg_k} -// for rhs. +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Compute estimates for average dimensions of a grouped tensor. + * + * Returns a pair of {non_contracting_avg, contracting_avg} dimensions for the given grouped tensor, to estimate per-group GEMM sizes. When a dimension is ragged, we estimate the average size by dividing the dim size by G ("num_gemms"). When a dimension has no ragged dims, we assume it is of shape (G*K, N) or (G*N, K) so we divide the first dim by G to get the average per-group size. + * + * Examples: + * - fwd lhs: shape_2d=[ragged M, K], first_dims=[M,...] (ragged M) → avg_m = (G*M)/G = M, avg_k = K + * - fwd rhs: shape_2d=[G*K, N], last_dims=None (static K) → avg_k = (G*K)/G = K, avg_n = N + * - wgrad lhs: shape_2d=[M, ragged K], last_dims=[K,...] (ragged K) → avg_k = (G*K)/G = K, avg_m = M + * - wgrad rhs: shape_2d=[N, ragged K], last_dims=[K,...] (ragged K) → avg_k = (G*K)/G = K, avg_n = N + * + * \param[in] first_dims XLA buffer of on-device first dimensions. Shape (G,) if ragged, empty otherwise. + * \param[in] last_dims XLA buffer of on-device last dimensions. Shape (G,) if ragged, empty otherwise. + * \param[in] shape_2d Pair of total 2D dimensions (rows, cols) for the operand. + * \param[in] num_gemms Number of GEMMs (G) in the grouped operation. + * \param[in] is_trans Whether the operand is transposed. + * \return Pair of {non_contracting_avg, contracting_avg}, i.e. {avg_m, avg_k} for lhs or + * {avg_n, avg_k} for rhs. + */ std::pair grouped_gemm_avg_dims(Buffer_Type const &first_dims, - Buffer_Type const &last_dims, size_t left_size, - size_t right_size, size_t num_gemms, + Buffer_Type const &last_dims, std::pair const& shape_2d, size_t num_gemms, bool is_trans) { bool first_ragged = first_dims.element_count() > 0; bool last_ragged = last_dims.element_count() > 0; bool any_ragged = first_ragged || last_ragged; - // Per-group left: divide by num_gemms if first dim is ragged OR if tensor has no ragged dims - // (G is folded into left_size). Keep as-is only when last dim is ragged but first is not. - size_t pg_left = (first_ragged || !any_ragged) ? left_size / num_gemms : left_size; - size_t pg_right = last_ragged ? right_size / num_gemms : right_size; - int64_t non_contract = static_cast(is_trans ? pg_right : pg_left); - int64_t contract = static_cast(is_trans ? pg_left : pg_right); + + std::pair per_group_shape_2d{}; + if (first_ragged) { + per_group_shape_2d = { + static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), + shape_2d.second + }; + } + else if (!any_ragged) { + per_group_shape_2d = { + static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), + shape_2d.second + }; + } + else if (last_ragged && !first_ragged) { + per_group_shape_2d = { + shape_2d.first, + static_cast(std::round(static_cast(shape_2d.second) / num_gemms)) + }; + } + else { + NVTE_CHECK(false, "Grouped GEMM with both first_dims and last_dims ragged is not supported."); + } + + int64_t non_contract = static_cast(is_trans ? per_group_shape_2d.second : per_group_shape_2d.first); + int64_t contract = static_cast(is_trans ? per_group_shape_2d.first : per_group_shape_2d.second); return {non_contract, contract}; } @@ -772,12 +795,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); - auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, - lhs_right_size, num_gemms, lhs_is_trans); - auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, - rhs_right_size, num_gemms, !rhs_is_trans); + auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); + auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, {rhs_left_size, rhs_right_size}, num_gemms, !rhs_is_trans); // Use k from lhs (both sides should agree for well-formed inputs). - (void)avg_k_rhs; + NVTE_CHECK(avg_k_lhs == avg_k_rhs, + "Contracting dimension mismatch: lhs avg_k=", avg_k_lhs, + " vs rhs avg_k=", avg_k_rhs); GroupedMatmulConfigWrapper gemmConfig{}; gemmConfig.set_avg_m(avg_m); From 7febb9bf38624b5a5f22667682fc1d51f309eb24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 22:26:07 +0000 Subject: [PATCH 46/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/gemm.cpp | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ce016e54e1..a7f16bb31f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -703,8 +703,9 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type con * {avg_n, avg_k} for rhs. */ std::pair grouped_gemm_avg_dims(Buffer_Type const &first_dims, - Buffer_Type const &last_dims, std::pair const& shape_2d, size_t num_gemms, - bool is_trans) { + Buffer_Type const &last_dims, + std::pair const &shape_2d, + size_t num_gemms, bool is_trans) { bool first_ragged = first_dims.element_count() > 0; bool last_ragged = last_dims.element_count() > 0; bool any_ragged = first_ragged || last_ragged; @@ -712,28 +713,24 @@ std::pair grouped_gemm_avg_dims(Buffer_Type const &first_dims, std::pair per_group_shape_2d{}; if (first_ragged) { per_group_shape_2d = { - static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), - shape_2d.second - }; - } - else if (!any_ragged) { + static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), + shape_2d.second}; + } else if (!any_ragged) { per_group_shape_2d = { - static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), - shape_2d.second - }; - } - else if (last_ragged && !first_ragged) { + static_cast(std::round(static_cast(shape_2d.first) / num_gemms)), + shape_2d.second}; + } else if (last_ragged && !first_ragged) { per_group_shape_2d = { - shape_2d.first, - static_cast(std::round(static_cast(shape_2d.second) / num_gemms)) - }; - } - else { + shape_2d.first, + static_cast(std::round(static_cast(shape_2d.second) / num_gemms))}; + } else { NVTE_CHECK(false, "Grouped GEMM with both first_dims and last_dims ragged is not supported."); } - int64_t non_contract = static_cast(is_trans ? per_group_shape_2d.second : per_group_shape_2d.first); - int64_t contract = static_cast(is_trans ? per_group_shape_2d.first : per_group_shape_2d.second); + int64_t non_contract = + static_cast(is_trans ? per_group_shape_2d.second : per_group_shape_2d.first); + int64_t contract = + static_cast(is_trans ? per_group_shape_2d.first : per_group_shape_2d.second); return {non_contract, contract}; } @@ -795,12 +792,13 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); - auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); - auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, {rhs_left_size, rhs_right_size}, num_gemms, !rhs_is_trans); + auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims( + lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); + auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims( + rhs_first_dims, rhs_last_dims, {rhs_left_size, rhs_right_size}, num_gemms, !rhs_is_trans); // Use k from lhs (both sides should agree for well-formed inputs). - NVTE_CHECK(avg_k_lhs == avg_k_rhs, - "Contracting dimension mismatch: lhs avg_k=", avg_k_lhs, - " vs rhs avg_k=", avg_k_rhs); + NVTE_CHECK(avg_k_lhs == avg_k_rhs, "Contracting dimension mismatch: lhs avg_k=", avg_k_lhs, + " vs rhs avg_k=", avg_k_rhs); GroupedMatmulConfigWrapper gemmConfig{}; gemmConfig.set_avg_m(avg_m); From 2e1a9f50d0b6ad367005e3cb111f9a313248dfe8 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 16:01:52 -0700 Subject: [PATCH 47/60] Fix merge issue Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d7551cb810..6ca907032c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -930,18 +930,6 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream, out_left_size, out_right_size); - auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims(lhs_first_dims, lhs_last_dims, lhs_left_size, - lhs_right_size, num_gemms, lhs_is_trans); - auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims(rhs_first_dims, rhs_last_dims, rhs_left_size, - rhs_right_size, num_gemms, !rhs_is_trans); - // Use k from lhs (both sides should agree for well-formed inputs). - (void)avg_k_rhs; - - GroupedMatmulConfigWrapper gemmConfig{}; - gemmConfig.set_avg_m(avg_m); - gemmConfig.set_avg_n(avg_n); - gemmConfig.set_avg_k(avg_k_lhs); - auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims( lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims( From 7769c51df1e3cdb49d93e8fb6ce961c8108a4789 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 16:39:06 -0700 Subject: [PATCH 48/60] Remove unnecessary changes Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 63 +------------------ transformer_engine/jax/permutation.py | 10 ++- 2 files changed, 6 insertions(+), 67 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 382138f4d5..85f726b043 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -52,53 +52,6 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] -def _build_scale_spec(x_spec, scale_shape, mesh): - """Build a PartitionSpec for the MXFP8 scale tensor compatible with its shape. - - The scale tensor has smaller dimensions than the data tensor (each dimension - divided by the MXFP8 block size). This function ensures that we only shard a - scale dimension by a mesh axis (or tuple of axes) if scale_shape[i] is - divisible by the total axis size. If not, a ValueError is raised with a - helpful diagnostic message. - """ - result = [] - for axis, scale_dim in zip(x_spec, scale_shape): - if axis is None: - result.append(None) - elif isinstance(axis, str): - axis_size = mesh.shape.get(axis, 1) - if scale_dim % axis_size == 0: - result.append(axis) - else: - raise ValueError( - f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " - f"by mesh axis '{axis}' of size {axis_size}: " - f"scale dim {scale_dim} is not divisible by {axis_size}. " - "The data tensor's sharding is incompatible with the MXFP8 block " - "size along this axis. Try reducing expert parallelism (EP) so that " - "EP divides the scale dimension, or increase the tensor size." - ) - elif isinstance(axis, (tuple, list)): - # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. - total_size = 1 - for a in axis: - total_size *= mesh.shape.get(a, 1) - if scale_dim % total_size == 0: - result.append(axis) - else: - raise ValueError( - f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " - f"by mesh axes {tuple(axis)} of combined size {total_size}: " - f"scale dim {scale_dim} is not divisible by {total_size}. " - "The data tensor's sharding is incompatible with the MXFP8 block " - "size along this axis. Try reducing parallelism or increasing the " - "tensor size." - ) - else: - result.append(None) - return tuple(result) - - class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -494,13 +447,7 @@ def infer_sharding_from_operands( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( - arg_infos[0].shape, - is_padded=False, - flatten_axis=flatten_axis, - broadcast_2d_scale_shape_to_1d=True, - ) - scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) + scale_inv_spec = x_spec if q_layout.has_colwise: if ( @@ -582,13 +529,7 @@ def partition( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( - arg_infos[0].shape, - is_padded=False, - flatten_axis=flatten_axis, - broadcast_2d_scale_shape_to_1d=True, - ) - scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) + scale_inv_spec = x_spec if q_layout.has_colwise: if ( diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 2732c4acc5..6a0a3229d9 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -497,9 +497,8 @@ def _token_combine_bwd_rule( hidden_size, ) # The backward kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized values (NaN, inf, or garbage). - # Replace any non-finite values with zeros. - inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) + # Padded positions may contain uninitialized (NaN) values - replace with zeros. + inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) else: inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( output_grad, @@ -528,9 +527,8 @@ def _token_combine_bwd_rule( align_size=128, # Default, sizes already computed in forward ) # The permute kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized values (NaN, inf, or garbage). - # Replace any non-finite values with zeros. - inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) + # Padded positions may contain uninitialized (NaN) values - replace with zeros. + inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) else: inp_grad, _ = permute_with_mask_map( output_grad, From 6fbe4cab4c3cbcb3873cdb41502b7d033be5c6f9 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 17:26:25 -0700 Subject: [PATCH 49/60] Cleanup tests Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 482 ++------------------------ 1 file changed, 31 insertions(+), 451 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 82c4d1dd85..f42e99ba62 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1072,7 +1072,10 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) +@pytest_parametrize_wrapper("input_shape", [ + (8, 16, 32), # V1 MXFP8: K=32 not 128-aligned + (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned +]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @@ -1119,60 +1122,28 @@ def test_grouped_qdq( assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) - @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - def test_grouped_quantize_v1_pre_swizzled( - self, input_shape, in_dtype, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes - ): - """V1 grouped quantize (K not 128-aligned) must produce pre_swizzled=False.""" - if scaling_mode != ScalingMode.MXFP8_1D_SCALING: - pytest.skip("pre_swizzled is only relevant for MXFP8") - if q_layout != QuantizeLayout.ROWWISE: - pytest.skip("Using ROWWISE layout to get a single GroupedScaledTensor1x") - # Shape with K=32 (not 128-aligned) forces V1 quantize on any GPU. - n_groups = 4 - group_sizes = jnp.array([32, 32, 32, 32], dtype=jnp.int32) - x = jax.random.uniform(jax.random.PRNGKey(0), (128, 32), jnp.bfloat16) - quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - q_dtype=jnp.float8_e4m3fn, - q_layout=QuantizeLayout.ROWWISE, - n_groups=n_groups, - ) - scaled_tensor = tex.grouped_quantize( - x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 - ) - assert isinstance(scaled_tensor, GroupedScaledTensor1x) - assert ( - not scaled_tensor.pre_swizzled - ), "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" - - @pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) - @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - def test_grouped_quantize_v2_pre_swizzled( - self, input_shape, in_dtype, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes - ): - """V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True.""" - if scaling_mode != ScalingMode.MXFP8_1D_SCALING: - pytest.skip("pre_swizzled is only relevant for MXFP8") - if q_layout != QuantizeLayout.ROWWISE: - pytest.skip("Using ROWWISE layout to get a single GroupedScaledTensor1x") - # Shape with M=512 and K=128 (both 128-aligned) allows V2 on SM100+. - n_groups = 4 - group_sizes = jnp.array([128, 128, 128, 128], dtype=jnp.int32) - x = jax.random.uniform(jax.random.PRNGKey(0), (512, 128), jnp.bfloat16) - quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - q_dtype=jnp.float8_e4m3fn, - q_layout=QuantizeLayout.ROWWISE, - n_groups=n_groups, - ) - scaled_tensor = tex.grouped_quantize( - x, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 - ) - assert isinstance(scaled_tensor, GroupedScaledTensor1x) - assert ( - scaled_tensor.pre_swizzled - ), "V2 grouped quantize (SM100+, 128-aligned M and K) must produce pre_swizzled=True" + # Verify MXFP8 pre_swizzled flag for ROWWISE grouped quantize with explicit group_sizes. + # V2 grouped quantize (SM100+) fuses the scale swizzle and sets pre_swizzled=True + # when both total M and K are 128-aligned. V1 always produces pre_swizzled=False. + if ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and q_layout == QuantizeLayout.ROWWISE + and with_group_sizes + and isinstance(scaled_tensor, GroupedScaledTensor1x) + ): + total_m = m * 32 + k_dim = n + if is_v2_grouped_gemm_supported and total_m % 128 == 0 and k_dim % 128 == 0: + # V2 path on SM100+: scales are pre-swizzled for GEMM + assert scaled_tensor.pre_swizzled, ( + "V2 grouped quantize (SM100+, 128-aligned M and K) must produce" + " pre_swizzled=True" + ) + elif k_dim % 128 != 0: + # V1 path: non-128-aligned K forces V1 quantize + assert not scaled_tensor.pre_swizzled, ( + "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" + ) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @@ -1772,10 +1743,11 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ] GROUPED_DENSE_INPUT_SHAPES = [ - # (n_groups, m, n, k), the actual m will be multiplied by 32 - (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 - (8, 64, 32, 128), - (8, 64, 128, 256), + # (n_groups, m, n, k), the actual m will be multiplied by group_size_multiplier + (5, 32, 128, 64), # V1 MXFP8: K=64 not 128-aligned; also tests n_groups not a multiple of 4 + (8, 64, 32, 128), # V1 MXFP8 GEMM: N=32 not 128-aligned + (8, 64, 128, 256), # V2 MXFP8 eligible: K=256, N=128 both 128-aligned + (4, 4, 128, 128), # V2 MXFP8 eligible: K=128, N=128 both 128-aligned (smaller shape) ] @@ -2011,398 +1983,6 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) - @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - def test_grouped_dense_mxfp8_v1_pipeline(self, input_shape): - """V1 pipeline: V1 grouped quantize + V1 grouped GEMM. - - Uses shapes where K or N is not 128-aligned, forcing V1 quantize (pre_swizzled=False) - and V1 GEMM on all GPUs. Verifies correctness and that pre_swizzled=False. - """ - n_groups, m, n, k = input_shape - # Skip shapes where K is 128-aligned; on SM100+, V2 quantize is used for any - # shape where both the total row count and K are 128-aligned (N alignment is - # not required for quantize, only for GEMM). - if k % 128 == 0: - pytest.skip("Shape is V2-eligible (K is 128-aligned); this test targets V1-only shapes") - lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - jnp.bfloat16, input_shape, group_size_multiplier=32 - ) - ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - - quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - q_dtype=jnp.float8_e4m3fn, - q_layout=QuantizeLayout.ROWWISE, - n_groups=n_groups, - ) - # V1 quantize: K or N not 128-aligned → pre_swizzled=False - casted_lhs = jax.jit(tex.grouped_quantize, static_argnames=("flatten_axis",))( - lhs, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 - ) - assert isinstance(casted_lhs, GroupedScaledTensor1x) - assert not casted_lhs.pre_swizzled, "V1 quantize must produce pre_swizzled=False" - - lhs_tensor = GroupedNoScaleTensor( - data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape - ) - rhs_tensor = GroupedNoScaleTensor( - data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape - ) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e4m3fn, - is_2x2x=False, - n_groups=n_groups, - ) - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs_tensor, - rhs_tensor, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set, - ) - self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, jnp.float8_e4m3fn) - - @pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) - @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - def test_grouped_dense_mxfp8_v2_pipeline(self, input_shape): - """V2 pipeline: V2 grouped quantize + V2 grouped GEMM (SM100+ required). - - Uses shapes where both K and N are 128-aligned, enabling V2 quantize (pre_swizzled=True) - and V2 GEMM on SM100+. Verifies correctness and that pre_swizzled=True. - """ - n_groups, m, n, k = input_shape - # Skip shapes that are not V2-eligible (K or N not 128-aligned). - if k % 128 != 0 or n % 128 != 0: - pytest.skip("Shape is not V2-eligible (K or N not 128-aligned)") - lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - jnp.bfloat16, input_shape, group_size_multiplier=128 - ) - ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - - quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - q_dtype=jnp.float8_e4m3fn, - q_layout=QuantizeLayout.ROWWISE, - n_groups=n_groups, - ) - # V2 quantize (SM100+, 128-aligned M, K): pre_swizzled=True - casted_lhs = jax.jit(tex.grouped_quantize, static_argnames=("flatten_axis",))( - lhs, quantizer=quantizer, group_sizes=group_sizes, flatten_axis=-1 - ) - assert isinstance(casted_lhs, GroupedScaledTensor1x) - assert casted_lhs.pre_swizzled, "V2 quantize (SM100+) must produce pre_swizzled=True" - - lhs_tensor = GroupedNoScaleTensor( - data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape - ) - rhs_tensor = GroupedNoScaleTensor( - data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape - ) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e4m3fn, - is_2x2x=False, - n_groups=n_groups, - ) - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs_tensor, - rhs_tensor, - contracting_dims=contracting_dims, - quantizer_set=quantizer_set, - ) - self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, jnp.float8_e4m3fn) - - -# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are -# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel. -GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES = [ - # (n_groups, m, n, k) - # lhs total_rows = m * 32; rhs total_rows = n_groups * k - (5, 6, 128, 64), # lhs: 6*32=192 (not 128-aligned); rhs: 5*64=320 (not 128-aligned) -] - -# MXFP8 V2 shapes: lhs total_rows = m * 128 and rhs total_rows = n_groups * k are -# divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used. -# These shapes must be paired with group_size_multiplier=128 so that each group's -# row count is also divisible by 128 (the V2 per-group alignment requirement). -# Additionally, both the last dimension of lhs (K) and of rhs (N) must be 128-aligned -# to match the V2 grouped GEMM constraint (and the updated V2 quantize constraint). -GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [ - # (n_groups, m, n, k) - # lhs: (m*128, k); rhs: (n_groups, k, n) - # V2 requires: m*128 % 128==0, k % 128==0 (lhs), n_groups*k % 128==0, n % 128==0 (rhs) - (8, 8, 128, 128), # lhs: M=1024 ✓, K=128 ✓; rhs: G*K=1024 ✓, N=128 ✓ - (4, 4, 128, 256), # lhs: M=512 ✓, K=256 ✓; rhs: G*K=1024 ✓, N=128 ✓ -] - - -@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) -class TestGroupedDenseMXFP8KernelSelection: - """Tests that explicitly verify V1 and V2 MXFP8 grouped quantize kernel selection. - - V2 is the CUDA-graph-safe kernel and requires: - - total_first_dim (= product of input shape up to flatten_axis) % 128 == 0 - - each individual group_size % 128 == 0 (enforced by the kernel at runtime) - V1 is the fallback that supports arbitrary shapes but performs a D2H copy of - group_sizes (not CUDA-graph safe). - """ - - def _generate_mxfp8_input(self, input_shape, group_size_multiplier): - """Generate inputs with the given group_size_multiplier for MXFP8 tests.""" - key = jax.random.PRNGKey(42) - subkeys = jax.random.split(key, 3) - n_groups, m, n, k = input_shape - - group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) - group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) - group_sizes = jnp.diff(group_sizes) - group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) - group_sizes = group_sizes.at[1].set(0) - group_sizes = group_sizes * group_size_multiplier - m_total = m * group_size_multiplier - - lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) - rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) - return lhs, rhs, group_sizes - - @pytest.mark.parametrize( - "input_shape", - GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES, - ids=[f"v1_{s}" for s in GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES], - ) - def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape): - """MXFP8 grouped GEMM with V1-only shapes (total_first_dim not 128-aligned).""" - lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=32) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e4m3fn, - is_2x2x=False, - n_groups=input_shape[0], - ) - lhs_tensor = GroupedNoScaleTensor( - data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape - ) - rhs_tensor = GroupedNoScaleTensor( - data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape - ) - # Reference: unquantized grouped GEMM - n_groups = input_shape[0] - lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - rhs_splits = jnp.split(rhs, n_groups, axis=0) - ref_out = jnp.concatenate( - [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], - axis=0, - ) - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs_tensor, - rhs_tensor, - contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set, - ) - # Check output has correct shape and dtype; numerical precision is expected to be lower - # due to FP8 quantization but the result should be finite. - assert prim_out.shape == ref_out.shape - assert jnp.all(jnp.isfinite(prim_out)) - - @pytest.mark.parametrize( - "input_shape", - GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, - ids=[f"v2_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], - ) - def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): - """MXFP8 grouped GEMM with V2-eligible shapes (total_first_dim 128-aligned, - group_sizes also 128-aligned).""" - lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e4m3fn, - is_2x2x=False, - n_groups=input_shape[0], - ) - lhs_tensor = GroupedNoScaleTensor( - data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape - ) - rhs_tensor = GroupedNoScaleTensor( - data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape - ) - n_groups = input_shape[0] - lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - rhs_splits = jnp.split(rhs, n_groups, axis=0) - ref_out = jnp.concatenate( - [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], - axis=0, - ) - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs_tensor, - rhs_tensor, - contracting_dims=((1,), (1,)), - quantizer_set=quantizer_set, - ) - assert prim_out.shape == ref_out.shape - assert jnp.all(jnp.isfinite(prim_out)) - # Numerical check within FP8 tolerance - assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - - @pytest.mark.parametrize( - "input_shape", - GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, - ids=[f"v2_grad_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], - ) - def test_grouped_dense_grad_mxfp8_v2(self, input_shape): - """MXFP8 V2 grouped GEMM gradient test (fwd + dgrad + wgrad).""" - lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) - n_groups = input_shape[0] - fwd_dtype = jnp.float8_e4m3fn - bwd_dtype = jnp.float8_e4m3fn - - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.MXFP8_1D_SCALING, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - is_2x2x=True, - n_groups=n_groups, - ) - - contracting_dims = ((1,), (1,)) - - def _ref_sum(x, kernel, group_sizes): - lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - rhs_splits = jnp.split(kernel, n_groups, axis=0) - out = jnp.concatenate( - [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 - ) - return jnp.sum(out) / jnp.sqrt(x.size) - - def _prim_sum(x, kernel, group_sizes): - out = grouped_dense( - x, - kernel, - group_sizes, - contracting_dims, - bias=None, - quantizer_set=quantizer_set, - ) - return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) - - ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) - prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( - lhs, rhs, group_sizes - ) - - assert_allclose(prim_val, ref_val, dtype=fwd_dtype) - assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) - assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype) - - -# BF16 grouped GEMM V1/V2 shapes: no special shape alignment needed for BF16 GEMM. -# V2 is selected based solely on hardware (SM100+), not shape. -GROUPED_DENSE_BF16_INPUT_SHAPES = [ - # (n_groups, m, n, k) - (8, 8, 128, 128), - (4, 4, 64, 256), -] - - -@pytest.mark.skipif(not is_v2_grouped_gemm_supported, reason=v2_grouped_gemm_unsupported_reason) -class TestGroupedDenseBF16V2GEMM: - """Tests that explicitly verify V2 BF16 grouped GEMM on SM100+ hardware. - - For BF16, the V2 (CUDA-graph-safe) grouped GEMM is selected when: - - The device compute capability is >= 100 (Blackwell or newer) - - The cuBLAS version supports it - V1 (nvte_multi_tensor_gemm) is the fallback on older hardware. - - V1 BF16 grouped GEMM is tested by TestGroupedDense.test_grouped_gemm_fp16 - (using use_async_d2h_group_sizes=True). - """ - - def _generate_bf16_input(self, input_shape, group_size_multiplier=32): - key = jax.random.PRNGKey(7) - subkeys = jax.random.split(key, 3) - n_groups, m, n, k = input_shape - - group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) - group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) - group_sizes = jnp.diff(group_sizes) - group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) - group_sizes = group_sizes.at[1].set(0) - group_sizes = group_sizes * group_size_multiplier - m_total = m * group_size_multiplier - - lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) - rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) - return lhs, rhs, group_sizes - - @pytest.mark.parametrize( - "input_shape", - GROUPED_DENSE_BF16_INPUT_SHAPES, - ids=[f"bf16_v2_{s}" for s in GROUPED_DENSE_BF16_INPUT_SHAPES], - ) - def test_grouped_gemm_bf16_v2(self, input_shape): - """BF16 grouped GEMM using the V2 (CUDA-graph-safe) kernel on SM100+.""" - lhs, rhs, group_sizes = self._generate_bf16_input(input_shape) - n_groups = input_shape[0] - - lhs_tensor = GroupedNoScaleTensor( - data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape - ) - rhs_tensor = GroupedNoScaleTensor( - data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape - ) - - lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - rhs_splits = jnp.split(rhs, n_groups, axis=0) - ref_out = jnp.concatenate( - [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], - axis=0, - ) - - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs_tensor, - rhs_tensor, - contracting_dims=((1,), (1,)), - ) - - assert prim_out.shape == ref_out.shape - assert prim_out.dtype == jnp.bfloat16 - assert_allclose(prim_out, ref_out, dtype=jnp.bfloat16) - - @pytest.mark.parametrize( - "input_shape", - GROUPED_DENSE_BF16_INPUT_SHAPES, - ids=[f"bf16_v2_grad_{s}" for s in GROUPED_DENSE_BF16_INPUT_SHAPES], - ) - def test_grouped_dense_grad_bf16_v2(self, input_shape): - """BF16 grouped GEMM gradient test (fwd + dgrad + wgrad) using V2 on SM100+.""" - lhs, rhs, group_sizes = self._generate_bf16_input(input_shape) - n_groups = input_shape[0] - contracting_dims = ((1,), (1,)) - - def _ref_sum(x, kernel, group_sizes): - lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - rhs_splits = jnp.split(kernel, n_groups, axis=0) - out = jnp.concatenate( - [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 - ) - return jnp.sum(out) / jnp.sqrt(x.size) - - def _prim_sum(x, kernel, group_sizes): - out = grouped_dense(x, kernel, group_sizes, contracting_dims, bias=None) - return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) - - ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) - prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( - lhs, rhs, group_sizes - ) - - assert_allclose(prim_val, ref_val, dtype=jnp.bfloat16) - assert_allclose(prim_dx, ref_dx, dtype=jnp.bfloat16) - assert_allclose(prim_dk, ref_dk, dtype=jnp.bfloat16) - class TestDebugInspectFFI: From 7cafd3556173fbcd8d87749b6cdf94ece3ac56d4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 19:02:10 -0700 Subject: [PATCH 50/60] Fix tests Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 25 +++++-- .../jax/cpp_extensions/quantization.py | 3 + .../jax/quantize/dequantizer.py | 75 +++++++++++++++++-- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index f42e99ba62..93adc08717 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1091,8 +1091,17 @@ def test_grouped_qdq( key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) - # *32 so that the input shapes works for MXFP8 - input_shape = (m * 32, n) + # Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned) + # so that per-group row counts are also 128-aligned as required by the V2 kernel. + # Use 32 for other shapes (V1 handles arbitrary group sizes). + v2_eligible = ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and is_v2_grouped_gemm_supported + and (m * 32) % 128 == 0 + and n % 128 == 0 + ) + group_size_multiplier = 128 if v2_eligible else 32 + input_shape = (m * group_size_multiplier, n) if with_group_sizes: group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) @@ -1100,7 +1109,7 @@ def test_grouped_qdq( group_sizes = jnp.diff(group_sizes) assert group_sizes.sum() == m assert jnp.any(group_sizes == 0) # make sure that at least one group has 0 row - group_sizes = group_sizes * 32 + group_sizes = group_sizes * group_size_multiplier else: group_sizes = None input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1]) @@ -1122,16 +1131,15 @@ def test_grouped_qdq( assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) - # Verify MXFP8 pre_swizzled flag for ROWWISE grouped quantize with explicit group_sizes. - # V2 grouped quantize (SM100+) fuses the scale swizzle and sets pre_swizzled=True - # when both total M and K are 128-aligned. V1 always produces pre_swizzled=False. + # Verify MXFP8 pre_swizzled flag for ROWWISE with explicit group_sizes. + # pre_swizzled=True indicates the V2 kernel was used (SM100+, 128-aligned dims). if ( scaling_mode == ScalingMode.MXFP8_1D_SCALING and q_layout == QuantizeLayout.ROWWISE and with_group_sizes and isinstance(scaled_tensor, GroupedScaledTensor1x) ): - total_m = m * 32 + total_m = m * group_size_multiplier k_dim = n if is_v2_grouped_gemm_supported and total_m % 128 == 0 and k_dim % 128 == 0: # V2 path on SM100+: scales are pre-swizzled for GEMM @@ -1142,7 +1150,8 @@ def test_grouped_qdq( elif k_dim % 128 != 0: # V1 path: non-128-aligned K forces V1 quantize assert not scaled_tensor.pre_swizzled, ( - "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" + "V1 grouped quantize (non-128-aligned K) must produce" + " pre_swizzled=False" ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 85f726b043..3ef1444178 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1346,6 +1346,9 @@ def grouped_quantize( # V2 grouped quantize (nvte_group_quantize) fuses the scale_inv swizzle into # the kernel, so the resulting tensors are already swizzled for GEMM. + # Note: V1 also produces swizzled scales (via set_with_gemm_swizzled_scales), + # but pre_swizzled is only set for V2 to maintain pytree compatibility. + # The dequantizer detects MXFP8 swizzling via the scaling_mode instead. use_v2 = GroupedQuantizePrimitive._use_v2_kernel( quantizer.scaling_mode.value, x.shape, flatten_axis ) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 43a3d50f4e..92e9f994fe 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -263,7 +263,37 @@ def dequantize(scaled_tensor): } -@staticmethod +def _unswizzle_mxfp8_grouped_scale(scale_inv_flat, padded_scale_2d, is_colwise): + """Un-swizzle MXFP8 GEMM-swizzled scale_inv back to plain layout. + + Both V1 and V2 MXFP8 grouped quantize produce scale_inv in a GEMM-swizzled + layout. This is the inverse of ``swizzled_scale`` in ``gemm.py``. + + The swizzle pattern (for rowwise) is: + reshape(R//128, 4, 32, C//4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + The inverse is: + reshape(R//128, C//4, 32, 4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + + For colwise the swizzle is applied to the transposed scale, so the inverse + must un-transpose as well. + """ + if is_colwise: + # Colwise forward: reshape_2d → transpose → swizzle_5d → reshape_original + # Inverse: reshape_to_5d → inverse_swizzle → reshape_to_transposed_2d → transpose + cols, rows = padded_scale_2d + scale_2d = scale_inv_flat.reshape(cols, rows) + # The swizzled data lives in the transposed (rows, cols) domain + reshaped = scale_2d.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + # Back to transposed 2D, then un-transpose + return jnp.transpose(unswizzled.reshape(rows, cols)) + else: + rows, cols = padded_scale_2d + reshaped = scale_inv_flat.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + return unswizzled.reshape(rows, cols) + + def _grouped_dequantize(grouped_scaled_tensor): """Dequantize a grouped tensor. @@ -333,22 +363,53 @@ def _grouped_dequantize(grouped_scaled_tensor): ) scale_inv_i = scale_inv[ scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i) - ].reshape(padded_scale_shape_i) - scale_inv_i = jax.lax.slice( - scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i - ) + ] + # MXFP8 grouped quantize (both V1 and V2) always produces GEMM-swizzled + # scales. Detect by scaling_mode (not pre_swizzled, which is only set for V2 + # to maintain pytree compatibility with the GEMM path). + is_colwise = grouped_scaled_tensor.is_colwise + needs_unswizzle = scaling_mode == ScalingMode.MXFP8_1D_SCALING + if needs_unswizzle: + flat_data_2d = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) + padded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=True, flatten_axis=1 + ) + unpadded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=False, flatten_axis=1 + ) + scale_inv_i = _unswizzle_mxfp8_grouped_scale( + scale_inv_i, padded_2d, is_colwise + ) + scale_inv_i = jax.lax.slice( + scale_inv_i, [0, 0], list(unpadded_2d) + ) + else: + scale_inv_i = scale_inv_i.reshape(padded_scale_shape_i) + scale_inv_i = jax.lax.slice( + scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i + ) dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) if len(data_i) == 0: out_i = [] else: + # _dequantize_func is designed for 2D-flattened data. Flatten the + # per-group shape to 2D, dequantize, then reshape back. + flat_shape_i = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) out_i = dequantizer_type._dequantize_func( - data_i.reshape(data_shape_i), + data_i.reshape(flat_shape_i), scale_inv_i, grouped_scaled_tensor.dq_dtype, scaling_mode=grouped_scaled_tensor.scaling_mode, is_colwise=grouped_scaled_tensor.is_colwise, - flatten_axis=grouped_scaled_tensor.flatten_axis, + flatten_axis=1, ) + out_i = out_i.reshape(data_shape_i) output.append(out_i) scale_inv_ptr += math.prod(padded_scale_shape_i) From 49e7a60d80bcd630b8dd50981c3f87cadb61ccf4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 7 Apr 2026 19:12:17 -0700 Subject: [PATCH 51/60] Use GroupedTensorWrapper in grouped quantization Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/quantization.cpp | 68 +++++-------------- 1 file changed, 16 insertions(+), 52 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 28b9ea4806..2f1e02f868 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -562,68 +562,37 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ offsets_shape.data[0] = n_groups + 1; // Build input grouped tensor (plain float data, no quantization on the input side). - NVTEGroupedTensor in_grouped = nvte_create_grouped_tensor( - get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), n_groups, data_shape); - { - NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), - static_cast(in_dtype), data_shape}; - nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedRowwiseData, &in_data, sizeof(in_data)); - NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, - sz_shape}; - nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedFirstDims, &sz_tensor, sizeof(sz_tensor)); - NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), - NVTEDType::kNVTEInt64, offsets_shape}; - nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, - sizeof(offsets_tensor)); - } + GroupedTensorWrapper in_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + in_grouped.set_rowwise_data(reinterpret_cast(inputs.untyped_data()), in_dtype, + data_shape) + .set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); // Build output grouped tensor. - NVTEGroupedTensor out_grouped = nvte_create_grouped_tensor( - get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), n_groups, data_shape); - - // Set group sizes and offsets on output tensor (same device pointers). - { - NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, - sz_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedFirstDims, &sz_tensor, - sizeof(sz_tensor)); - NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), - NVTEDType::kNVTEInt64, offsets_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, - sizeof(offsets_tensor)); - } + GroupedTensorWrapper out_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING)); + out_grouped.set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); // Rowwise output data + scale_inv. if (is_quantize_rowwise(quantize_layout)) { - NVTEBasicTensor rw_data{reinterpret_cast(rowwise_out->untyped_data()), - static_cast(out_dtype), data_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseData, &rw_data, sizeof(rw_data)); - NVTEShape rw_sinv_shape{}; rw_sinv_shape.ndim = 2; rw_sinv_shape.data[0] = m; rw_sinv_shape.data[1] = n / 32; // MXFP8 block size = 32 - NVTEBasicTensor rw_sinv{reinterpret_cast(rowwise_sinv->untyped_data()), - static_cast(sinv_dtype), rw_sinv_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseScaleInv, &rw_sinv, - sizeof(rw_sinv)); + out_grouped.set_rowwise_data(rowwise_out->untyped_data(), out_dtype, data_shape) + .set_rowwise_scale_inv(rowwise_sinv->untyped_data(), sinv_dtype, rw_sinv_shape); } // Colwise output data + scale_inv. if (is_quantize_colwise(quantize_layout)) { - NVTEBasicTensor cw_data{reinterpret_cast(colwise_out->untyped_data()), - static_cast(out_dtype), data_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseData, &cw_data, - sizeof(cw_data)); - NVTEShape cw_sinv_shape{}; cw_sinv_shape.ndim = 2; cw_sinv_shape.data[0] = m / 32; // MXFP8 block size = 32 cw_sinv_shape.data[1] = n; - NVTEBasicTensor cw_sinv{reinterpret_cast(colwise_sinv->untyped_data()), - static_cast(sinv_dtype), cw_sinv_shape}; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseScaleInv, &cw_sinv, - sizeof(cw_sinv)); + out_grouped.set_columnwise_data(colwise_out->untyped_data(), out_dtype, data_shape) + .set_columnwise_scale_inv(colwise_sinv->untyped_data(), sinv_dtype, cw_sinv_shape); } // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). @@ -639,15 +608,10 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ // V2 grouped quantize is always paired with V2 grouped GEMM, which expects // scale_inv in GEMM-swizzled layout. Enable the fused swizzle so the kernel // writes scales in the layout the GEMM will consume directly. - uint8_t swizzle_flag = 1; - nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedWithGEMMSwizzledScales, &swizzle_flag, - sizeof(swizzle_flag)); + out_grouped.set_with_gemm_swizzled_scales(true); QuantizationConfigWrapper quant_config{}; - nvte_group_quantize(in_grouped, out_grouped, quant_config, stream); - - nvte_destroy_grouped_tensor(in_grouped); - nvte_destroy_grouped_tensor(out_grouped); + nvte_group_quantize(in_grouped.data(), out_grouped.data(), quant_config, stream); return ffi_with_cuda_error_check(); } From 644520b92d13b9da3b50df29fbb708f8f71e0c91 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 02:14:02 +0000 Subject: [PATCH 52/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 24 ++++++++++--------- .../jax/csrc/extensions/quantization.cpp | 4 ++-- .../jax/quantize/dequantizer.py | 12 +++------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 93adc08717..8a64c07763 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1072,10 +1072,13 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("input_shape", [ - (8, 16, 32), # V1 MXFP8: K=32 not 128-aligned - (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned -]) +@pytest_parametrize_wrapper( + "input_shape", + [ + (8, 16, 32), # V1 MXFP8: K=32 not 128-aligned + (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned + ], +) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @@ -1149,10 +1152,9 @@ def test_grouped_qdq( ) elif k_dim % 128 != 0: # V1 path: non-128-aligned K forces V1 quantize - assert not scaled_tensor.pre_swizzled, ( - "V1 grouped quantize (non-128-aligned K) must produce" - " pre_swizzled=False" - ) + assert ( + not scaled_tensor.pre_swizzled + ), "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @@ -1753,10 +1755,10 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): GROUPED_DENSE_INPUT_SHAPES = [ # (n_groups, m, n, k), the actual m will be multiplied by group_size_multiplier - (5, 32, 128, 64), # V1 MXFP8: K=64 not 128-aligned; also tests n_groups not a multiple of 4 - (8, 64, 32, 128), # V1 MXFP8 GEMM: N=32 not 128-aligned + (5, 32, 128, 64), # V1 MXFP8: K=64 not 128-aligned; also tests n_groups not a multiple of 4 + (8, 64, 32, 128), # V1 MXFP8 GEMM: N=32 not 128-aligned (8, 64, 128, 256), # V2 MXFP8 eligible: K=256, N=128 both 128-aligned - (4, 4, 128, 128), # V2 MXFP8 eligible: K=128, N=128 both 128-aligned (smaller shape) + (4, 4, 128, 128), # V2 MXFP8 eligible: K=128, N=128 both 128-aligned (smaller shape) ] diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 2f1e02f868..db9cf94db5 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -564,8 +564,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_ // Build input grouped tensor (plain float data, no quantization on the input side). GroupedTensorWrapper in_grouped(n_groups, data_shape, get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); - in_grouped.set_rowwise_data(reinterpret_cast(inputs.untyped_data()), in_dtype, - data_shape) + in_grouped + .set_rowwise_data(reinterpret_cast(inputs.untyped_data()), in_dtype, data_shape) .set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 92e9f994fe..b46e4ff9d5 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -361,9 +361,7 @@ def _grouped_dequantize(grouped_scaled_tensor): is_padded=False, flatten_axis=flatten_axis, ) - scale_inv_i = scale_inv[ - scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i) - ] + scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)] # MXFP8 grouped quantize (both V1 and V2) always produces GEMM-swizzled # scales. Detect by scaling_mode (not pre_swizzled, which is only set for V2 # to maintain pytree compatibility with the GEMM path). @@ -380,12 +378,8 @@ def _grouped_dequantize(grouped_scaled_tensor): unpadded_2d = scaling_mode.get_scale_shape( flat_data_2d, is_colwise=is_colwise, is_padded=False, flatten_axis=1 ) - scale_inv_i = _unswizzle_mxfp8_grouped_scale( - scale_inv_i, padded_2d, is_colwise - ) - scale_inv_i = jax.lax.slice( - scale_inv_i, [0, 0], list(unpadded_2d) - ) + scale_inv_i = _unswizzle_mxfp8_grouped_scale(scale_inv_i, padded_2d, is_colwise) + scale_inv_i = jax.lax.slice(scale_inv_i, [0, 0], list(unpadded_2d)) else: scale_inv_i = scale_inv_i.reshape(padded_scale_shape_i) scale_inv_i = jax.lax.slice( From 56fce55dbe1544905b51f37f2ada492ac9279415 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 8 Apr 2026 09:16:59 -0700 Subject: [PATCH 53/60] Fix merge conflict issue Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index fecb477793..6ca907032c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -943,19 +943,6 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty gemmConfig.set_avg_n(avg_n); gemmConfig.set_avg_k(avg_k_lhs); - auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims( - lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); - auto [avg_n, avg_k_rhs] = grouped_gemm_avg_dims( - rhs_first_dims, rhs_last_dims, {rhs_left_size, rhs_right_size}, num_gemms, !rhs_is_trans); - // Use k from lhs (both sides should agree for well-formed inputs). - NVTE_CHECK(avg_k_lhs == avg_k_rhs, "Contracting dimension mismatch: lhs avg_k=", avg_k_lhs, - " vs rhs avg_k=", avg_k_rhs); - - GroupedMatmulConfigWrapper gemmConfig{}; - gemmConfig.set_avg_m(avg_m); - gemmConfig.set_avg_n(avg_n); - gemmConfig.set_avg_k(avg_k_lhs); - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), workspace_cublas.data(), gemmConfig, stream); From 9ea2482200f2680744cb61bbdea7acc31a80b622 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 9 Apr 2026 17:22:35 -0700 Subject: [PATCH 54/60] Address comments Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 50 +----- transformer_engine/jax/cpp_extensions/gemm.py | 158 +++++++++--------- .../jax/cpp_extensions/quantization.py | 45 ++--- .../jax/csrc/extensions/quantization.cpp | 6 +- 4 files changed, 112 insertions(+), 147 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8a64c07763..58de167bb6 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -43,7 +43,6 @@ noop_quantizer_set, QuantizeMetaSet, QuantizeMeta, - get_device_compute_capability, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -78,9 +77,6 @@ supported_recipes = helper.get_supported_quantization_recipes() supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] -is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100 -v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)" - def is_shape_supported_by_mxfp8(input_shape): try: @@ -1079,6 +1075,13 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned ], ) +@pytest_parametrize_wrapper( + "group_size_multiplier", + [ + 32, # V1 MXFP8: group size must be multiple of 32 + 128, # V2 MXFP8 eligible: group size must be multiple of 128 + ], +) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @@ -1088,22 +1091,12 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w ) class TestGroupedQuantize: def test_grouped_qdq( - self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + self, in_dtype, input_shape, group_size_multiplier, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes ): n_groups, m, n = input_shape key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) - # Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned) - # so that per-group row counts are also 128-aligned as required by the V2 kernel. - # Use 32 for other shapes (V1 handles arbitrary group sizes). - v2_eligible = ( - scaling_mode == ScalingMode.MXFP8_1D_SCALING - and is_v2_grouped_gemm_supported - and (m * 32) % 128 == 0 - and n % 128 == 0 - ) - group_size_multiplier = 128 if v2_eligible else 32 input_shape = (m * group_size_multiplier, n) if with_group_sizes: @@ -1134,28 +1127,6 @@ def test_grouped_qdq( assert_dequantized_grouped_scaled_tensor(scaled_tensor, x) - # Verify MXFP8 pre_swizzled flag for ROWWISE with explicit group_sizes. - # pre_swizzled=True indicates the V2 kernel was used (SM100+, 128-aligned dims). - if ( - scaling_mode == ScalingMode.MXFP8_1D_SCALING - and q_layout == QuantizeLayout.ROWWISE - and with_group_sizes - and isinstance(scaled_tensor, GroupedScaledTensor1x) - ): - total_m = m * group_size_multiplier - k_dim = n - if is_v2_grouped_gemm_supported and total_m % 128 == 0 and k_dim % 128 == 0: - # V2 path on SM100+: scales are pre-swizzled for GEMM - assert scaled_tensor.pre_swizzled, ( - "V2 grouped quantize (SM100+, 128-aligned M and K) must produce" - " pre_swizzled=True" - ) - elif k_dim % 128 != 0: - # V1 path: non-128-aligned K forces V1 quantize - assert ( - not scaled_tensor.pre_swizzled - ), "V1 grouped quantize (non-128-aligned K) must produce pre_swizzled=False" - @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) class TestFusedQuantize: @@ -1799,10 +1770,7 @@ def _generate_grouped_dense_input( group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # Scale group sizes by the multiplier. - # Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count - # is divisible by 128, satisfying the V2 kernel's per-group alignment requirement. - # Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests. + # Scale group sizes by the multiplier for alignment requirements. group_sizes = group_sizes * group_size_multiplier m = m * group_size_multiplier diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index f819401d52..dc9e4e8839 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -47,7 +47,7 @@ apply_padding_to_scale_inv, QuantizeLayout, ) -from .misc import get_padded_spec, is_all_reduce_in_float32 +from .misc import get_padded_spec, is_all_reduce_in_float32, get_min_device_compute_capability from ..sharding import ( global_mesh_resource, tpsp_axis_size, @@ -66,6 +66,7 @@ "sanitize_dims", "get_non_contracting_dims", "transpose_dims", + "is_v2_grouped_gemm_supported", ] @@ -2035,8 +2036,7 @@ def _should_enforce_v2_grouped_gemm() -> bool: f"NVTE_JAX_ENFORCE_V2_GROUPED_GEMM must be an integer (0 or 1), got: {val!r}" ) from e - -def _can_use_v2_grouped_gemm( +def _is_v2_grouped_gemm_supported( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, @@ -2044,44 +2044,28 @@ def _can_use_v2_grouped_gemm( rhs_shape=None, lhs_axis_boundary=None, rhs_axis_boundary=None, -) -> bool: - """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to - # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). - # Bias can be supported in a kernel or in pure-JAX in the future. - - enforce_v2_gmm = _should_enforce_v2_grouped_gemm() +) -> tuple[bool, str]: + """Determine whether the V2 grouped GEMM implementation can be used based on the input parameters.""" if not _v2_grouped_gemm_available: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is" - " enabled. The reason for V2 grouped GEMM not being available:" - f" {_v2_grouped_gemm_available_reason}" - ) - return False + return False, ( + "TE was not compiled with support for the V2 grouped GEMM kernel, reason: " + f"{_v2_grouped_gemm_available_reason}" + ) # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. - if get_device_compute_capability(0) < 100: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device" - f" compute capability of GPU 0 is {get_device_compute_capability(0)} and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False + if get_min_device_compute_capability() < 100: + return False, ( + "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" + f" compute capability is {get_min_device_compute_capability()}." + ) if has_bias: - if enforce_v2_gmm: - raise RuntimeError( - "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel, but" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and has_bias is True." - ) - return False + return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel." if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: - return True + return True, "" if scaling_mode == ScalingMode.MXFP8_1D_SCALING: # V2 MXFP8 requires that the total first dimension of both operands (up to @@ -2090,27 +2074,22 @@ def _can_use_v2_grouped_gemm( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) if lhs_first_dim % 128 != 0: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" - f" {lhs_first_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}, and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False + return False, ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_first_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ) if rhs_shape is not None and rhs_axis_boundary is not None: rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) if rhs_first_dim % 128 != 0: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" - f" {rhs_first_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}, and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False + return False, ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_first_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ) + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. @@ -2122,38 +2101,63 @@ def _can_use_v2_grouped_gemm( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) if lhs_last_dim % 128 != 0: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" - f" {lhs_last_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}, and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False + return False, ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_last_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ) if rhs_shape is not None and rhs_axis_boundary is not None: rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) if rhs_last_dim % 128 != 0: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" - f" {rhs_last_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}, and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False - return True + return False, ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_last_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ) + return True, "" + + return False, ( + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" + " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" + f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," + f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." + ) - if enforce_v2_gmm: +def is_v2_grouped_gemm_supported( + scaling_mode: ScalingMode, + dtype: jnp.dtype, + has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, +) -> tuple[bool, str]: + """ Determine whether the V2 grouped GEMM implementation can be used based on the input parameters. + + Returns: + A tuple of (is_supported: bool, reason: str) where is_supported indicates whether the V2 grouped GEMM can be used, and reason provides an explanation if it is not supported. + """ + # Use the V2 path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). + # Bias can be supported in a kernel or in pure-JAX in the future. + + enforce_v2_gmm = _should_enforce_v2_grouped_gemm() + + is_v2_supported, reason = _is_v2_grouped_gemm_supported( + scaling_mode, dtype, has_bias, lhs_shape, rhs_shape, lhs_axis_boundary, rhs_axis_boundary + ) + + if enforce_v2_gmm and not is_v2_supported: raise RuntimeError( - "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" - " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" - f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," - f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," - f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." + "The TE V2 grouped GEMM is not supported for the given input parameters, but" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled. The reason for V2 grouped GEMM not being" + f" supported: {reason}" ) - return False + + return is_v2_supported, reason def grouped_gemm( @@ -2410,7 +2414,7 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm( + use_v2_ffi, _ = is_v2_grouped_gemm_supported( scaling_mode, lhs_data.dtype, has_bias, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 3ef1444178..b5455f5e6a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1095,16 +1095,18 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) + updated_amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) if use_v2: - # V2 path: 5th output is int64_workspace laid out as: + # V2 path: int64_workspace laid out as: # [n_groups int64 group_sizes | n_groups+1 int64 offsets] # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. n_groups = group_sizes_aval.size - fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) + int64_workspace_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) else: - # V1 path: 5th output is amax - fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + # V1 path: Unused for V1 codepath + int64_workspace_aval = jax.core.ShapedArray(shape=(0,), dtype=jnp.uint8) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1124,7 +1126,8 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - fifth_out_aval, + updated_amax_aval, + int64_workspace_aval, ) @staticmethod @@ -1134,21 +1137,14 @@ def outer_abstract(*args, **kwargs): """ # Phuong: keeping outer abstract so that we can add fuse dbias later ( - rowwise_out, - colwise_out, - scale_inv, - colwise_scale_inv, - fifth_out, + rowwise_out_aval, + colwise_out_aval, + rowwise_scale_inv_aval, + colwise_scale_inv_aval, + updated_amax_aval, + _, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - # When V2 is used, the inner abstract returns int64_workspace as the 5th output. - # The outer interface always presents amax (float32, n_groups) for a consistent API. - scaling_mode = kwargs.get("scaling_mode") - x_aval = args[0] - group_sizes_aval = args[2] - flatten_axis = kwargs.get("flatten_axis") - if GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis): - fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out + return (rowwise_out_aval, colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval) @staticmethod def lowering( @@ -1216,7 +1212,8 @@ def impl( colwise_out, rowwise_scale_inv, colwise_scale_inv, - fifth, + updated_amax, + _, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1227,13 +1224,7 @@ def impl( flatten_axis=flatten_axis, scale_dtype=scale_dtype, ) - use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x.shape, flatten_axis) - if use_v2: - # fifth is int64_workspace; return a dummy zero amax for interface compatibility - updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) - else: - updated_amax = fifth - return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) + return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax register_primitive(GroupedQuantizePrimitive) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index db9cf94db5..8fd587ca91 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -318,7 +318,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales, Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, - Result_Type colwise_scale_invs, Result_Type amaxs, + Result_Type colwise_scale_invs, Result_Type amaxs, Result_Type _unused, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, @@ -497,6 +497,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // scale_inv .Ret() // scale_inv colwise .Ret() // amax + .Ret() // unused (for compatibility with V2 interface) .Attr("scaling_mode") .Attr("q_layout") .Attr("flatten_axis")); @@ -504,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, Buffer_Type group_sizes, Result_Type rowwise_out, Result_Type colwise_out, Result_Type rowwise_sinv, - Result_Type colwise_sinv, Result_Type int64_workspace, + Result_Type colwise_sinv, Result_Type updated_amaxs, Result_Type int64_workspace, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); @@ -626,6 +627,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, .Ret() // colwise_out .Ret() // rowwise_sinv .Ret() // colwise_sinv + .Ret() // updated_amaxs .Ret() // int64_workspace .Attr("q_layout") .Attr("flatten_axis"), From 2af15e5fa9be9f661c5b8e4c1ff792022fe9e94e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 9 Apr 2026 18:12:59 -0700 Subject: [PATCH 55/60] Clean up grouped_gemm function Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 398 +++++++++--------- 1 file changed, 196 insertions(+), 202 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index dc9e4e8839..fa48c77bbd 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from dataclasses import dataclass from functools import partial, reduce, cache -from typing import Tuple, Sequence, Union +from typing import Tuple, Sequence, Union, Optional from enum import Enum import warnings @@ -2091,13 +2091,7 @@ def _is_v2_grouped_gemm_supported( ) # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both - # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group - # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. - # The quantize kernel, however, pads the colwise scale stride to - # ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than - # ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes - # adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which - # swizzles per-group scales individually) when the condition is not met. + # operands is a multiple of 128. This is because the MXFP8 scales must be padded to a multiple of (128, 4). The nvte_grouped_gemm setup kernels only handle the case when this dim is a multiple of 128 as well. If it is not, the GEMM setup kernel will not compute the scale offsets correctly and will read overlapping scales from the previous group, causing incorrect results. if lhs_shape is not None and lhs_axis_boundary is not None: lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) if lhs_last_dim % 128 != 0: @@ -2159,6 +2153,150 @@ def is_v2_grouped_gemm_supported( return is_v2_supported, reason +def _get_out_dtype_and_scaling_mode(x: Union[GroupedNoScaleTensor, GroupedScaledTensor1x]) -> Tuple[jnp.dtype, ScalingMode]: + if isinstance(x, GroupedScaledTensor1x): + out_dtype = x.dq_dtype + scaling_mode = x.scaling_mode + elif isinstance(x, GroupedNoScaleTensor): + out_dtype = x.data.dtype + scaling_mode = ScalingMode.NO_SCALING + else: + raise TypeError( + f"Input must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(x)}" + ) + return out_dtype, scaling_mode + +def _infer_output_ragged_dims(lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x]) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: + assert isinstance(lhs, (GroupedNoScaleTensor, GroupedScaledTensor1x)), f"Expected lhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" + assert isinstance(rhs, (GroupedNoScaleTensor, GroupedScaledTensor1x)), f"Expected rhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs.first_dims is not None or rhs.last_dims is not None: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = None + out_last_dims = None + elif lhs.first_dims is not None: + out_first_dims = lhs.first_dims + out_last_dims = None + elif lhs.last_dims is not None: + out_first_dims = None + out_last_dims = lhs.last_dims + else: + out_first_dims = out_last_dims = None + + return out_first_dims, out_last_dims + +def _adjust_contracting_dims_for_hopper_fp8_transpose( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + lhs_contract_dim: Sequence[int], + rhs_contract_dim: Sequence[int], + lhs_is_trans: bool, + rhs_is_trans: bool, +) -> Tuple[Sequence[int], Sequence[int]]: + # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs + # thus additional transpose is required + lhs_layout_is_T = lhs.data_layout == "T" + rhs_layout_is_T = rhs.data_layout == "T" + # we can't apply _shape_normalization on the grouped input + # thus we need to ensure that lhs is in N and rhs is in T + if lhs_is_trans != lhs_layout_is_T: + raise RuntimeError("lhs input must be transposed before calling grouped_gemm") + if (not rhs_is_trans) != rhs_layout_is_T: + raise RuntimeError("rhs input must be transposed before calling grouped_gemm") + lhs_is_trans = False + rhs_is_trans = True + lhs_ndim = len(lhs.original_shape) + rhs_ndim = len(rhs.original_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + # For rhs [G, K, N], need to exclude the G dim from contract_dim + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim + rhs_contract_dim = tuple( + (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim + ) + else: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + + return lhs_contract_dim, rhs_contract_dim + +def _quantize_inputs_if_needed( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + quantizer_set: QuantizerSet, + lhs_is_trans: bool, + rhs_is_trans: bool, + lhs_flatten_axis: int, + rhs_flatten_axis: int, +) -> Tuple[Union[GroupedNoScaleTensor, GroupedScaledTensor1x], Union[GroupedNoScaleTensor, GroupedScaledTensor1x]]: + if quantizer_set is noop_quantizer_set: + return lhs, rhs + + assert isinstance(lhs, GroupedNoScaleTensor), f"Expected lhs to be GroupedNoScaleTensor before quantization, got type={type(lhs)}" + assert isinstance(rhs, GroupedNoScaleTensor), f"Expected rhs to be GroupedNoScaleTensor before quantization, got type={type(rhs)}" + + if not isinstance(quantizer_set.x, GroupedQuantizer): + raise TypeError( + "Expected quantizer_set.x to be GroupedQuantizer, but got" + f" type={type(quantizer_set.x)}" + ) + if type(quantizer_set.x) is not type(quantizer_set.kernel): + raise TypeError( + "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" + f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + ) + scaling_mode = quantizer_set.x.scaling_mode + if ( + quantizer_set.x.scaling_mode.is_tensor_scaling() + and is_fp8_gemm_with_all_layouts_supported() + ): + lhs_is_rowwise = rhs_is_rowwise = True + else: + lhs_is_rowwise = not lhs_is_trans + rhs_is_rowwise = rhs_is_trans + quantizer_set.x.q_layout = ( + QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE + ) + quantizer_set.kernel.q_layout = ( + QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE + ) + empty_gs = jnp.empty((0,), jnp.int32) + active_group_sizes = next( + ( + gs + for gs in [lhs.first_dims, lhs.last_dims, rhs.first_dims, rhs.last_dims] + if gs.size > 0 + ), + empty_gs, + ) + lhs_input_data = lhs.data + rhs_input_data = rhs.data + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) + rhs_q = grouped_quantize( + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + ) + return lhs_q, rhs_q + +def _get_num_gemms( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> int: + num_gemms = 0 + for x in [lhs, rhs]: + if x.first_dims is not None: + return x.first_dims.size + if x.last_dims is not None: + return x.last_dims.size + raise ValueError( + "Cannot infer number of gemms since neither lhs nor rhs has first_dims or last_dims. " + "Ensure that at least one of the input tensors has valid first_dims or last_dims." + "For grouped_gemm, at least one tensor must be ragged." + ) def grouped_gemm( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2193,179 +2331,47 @@ def grouped_gemm( empty_gs = jnp.empty((0,), jnp.int32) - # Extract data, dims, and metadata from tensor objects. - # Keep data in its original layout (may be 1D for quantized tensors) to preserve - # JAX sharding; the C++ side uses original_shape to derive m/n/k. - if isinstance(lhs, GroupedNoScaleTensor): - lhs_data = lhs.data - lhs_shape = lhs.original_shape - lhs_scale_inv = jnp.empty((0,), jnp.float32) - scaling_mode = ScalingMode.NO_SCALING - out_dtype = lhs.data.dtype - lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs - lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - elif isinstance(lhs, GroupedScaledTensor1x): - lhs_shape = lhs.original_shape - lhs_data = lhs.data - lhs_scale_inv = lhs.scale_inv - scaling_mode = lhs.scaling_mode - out_dtype = lhs.dq_dtype - lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs - lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - else: - raise TypeError( - f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" - ) - - if isinstance(rhs, GroupedNoScaleTensor): - rhs_data = rhs.data - rhs_shape = rhs.original_shape - rhs_scale_inv = jnp.empty((0,), jnp.float32) - rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs - rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs - elif isinstance(rhs, GroupedScaledTensor1x): - rhs_shape = rhs.original_shape - rhs_data = rhs.data - rhs_scale_inv = rhs.scale_inv - rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs - rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs - if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: - raise ValueError( - f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," - f" rhs.scaling_mode={rhs.scaling_mode}" - ) - if isinstance(lhs, GroupedScaledTensor1x): - scaling_mode = lhs.scaling_mode - else: - raise TypeError( - f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" - ) + out_dtype, scaling_mode = _get_out_dtype_and_scaling_mode(lhs) + rhs_out_dtype, rhs_scaling_mode = _get_out_dtype_and_scaling_mode(rhs) + assert out_dtype == rhs_out_dtype, f"Mismatched output dtypes: {out_dtype} vs {rhs_out_dtype}" + assert scaling_mode == rhs_scaling_mode, f"Mismatched scaling modes: {scaling_mode} vs {rhs_scaling_mode}" + del rhs_out_dtype, rhs_scaling_mode - # Infer output dims from which operand has the ragged non-contracting dim. - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: - # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) - out_first_dims = empty_gs - out_last_dims = empty_gs - elif lhs_first_dims.size > 0: - out_first_dims = lhs_first_dims - out_last_dims = empty_gs - elif lhs_last_dims.size > 0: - out_first_dims = empty_gs - out_last_dims = lhs_last_dims - else: - out_first_dims = out_last_dims = empty_gs + out_first_dims, out_last_dims = _infer_output_ragged_dims(lhs, rhs) out_dtype = preferred_element_type or out_dtype lhs_contract_dim, rhs_contract_dim = contracting_dims - lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 + lhs_is_trans = lhs_contract_dim[-1] != len(lhs.original_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). - rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 + rhs_is_trans = rhs_contract_dim[-1] == len(rhs.original_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - if ( - not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - and quantizer_set != noop_quantizer_set - ): - if not isinstance(quantizer_set.x, GroupedQuantizer): - raise TypeError( - "Expected quantizer_set.x to be GroupedQuantizer, but got" - f" type={type(quantizer_set.x)}" - ) - if type(quantizer_set.x) is not type(quantizer_set.kernel): - raise TypeError( - "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" - f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" - ) - scaling_mode = quantizer_set.x.scaling_mode - if ( - quantizer_set.x.scaling_mode.is_tensor_scaling() - and is_fp8_gemm_with_all_layouts_supported() - ): - lhs_is_rowwise = rhs_is_rowwise = True - else: - lhs_is_rowwise = not lhs_is_trans - rhs_is_rowwise = rhs_is_trans - quantizer_set.x.q_layout = ( - QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE - ) - quantizer_set.kernel.q_layout = ( - QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE - ) - active_group_sizes = next( - ( - gs - for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] - if gs.size > 0 - ), - empty_gs, - ) - lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data - rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data - lhs_q = grouped_quantize( - lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis - ) - rhs_q = grouped_quantize( - rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis - ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data - lhs_scale_inv = lhs_q.scale_inv - rhs_scale_inv = rhs_q.scale_inv - lhs_shape = lhs_q.original_shape - rhs_shape = rhs_q.original_shape + lhs, rhs = _quantize_inputs_if_needed( + lhs, rhs, + quantizer_set, + lhs_is_trans, rhs_is_trans, + lhs_flatten_axis, rhs_flatten_axis + ) - if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: + if lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") - # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs - # thus additional transpose is required if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - lhs_layout_is_T = lhs.data_layout == "T" - rhs_layout_is_T = rhs.data_layout == "T" - else: - lhs_layout_is_T = lhs_q.data_layout == "T" - rhs_layout_is_T = rhs_q.data_layout == "T" - # we can't apply _shape_normalization on the grouped input - # thus we need to ensure that lhs is in N and rhs is in T - if lhs_is_trans != lhs_layout_is_T: - raise RuntimeError("lhs input must be transposed before calling grouped_gemm") - if (not rhs_is_trans) != rhs_layout_is_T: - raise RuntimeError("rhs input must be transposed before calling grouped_gemm") - lhs_is_trans = False - rhs_is_trans = True - lhs_ndim = len(lhs_shape) - rhs_ndim = len(rhs_shape) - if lhs_layout_is_T: - lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) - if rhs_layout_is_T: - # For rhs [G, K, N], need to exclude the G dim from contract_dim - if ( - lhs_first_dims.size > 0 or lhs_last_dims.size > 0 - ): # fwd/dgrad: rhs has G as first dim - rhs_contract_dim = tuple( - (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim - ) - else: - rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + lhs_contract_dim, rhs_contract_dim = _adjust_contracting_dims_for_hopper_fp8_transpose( + lhs, rhs, + lhs_contract_dim, rhs_contract_dim, + lhs_is_trans, rhs_is_trans + ) # Compute N-D axis boundaries from final (post-adjustment) contracting dims. lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) - num_gemms = ( - lhs_first_dims.size - or lhs_last_dims.size - or rhs_first_dims.size - or rhs_last_dims.size - or out_first_dims.size - or out_last_dims.size - ) + num_gemms = _get_num_gemms(lhs, rhs) if num_gemms == 0: raise ValueError( "grouped_gemm requires at least one non-empty dimension array. " @@ -2374,26 +2380,26 @@ def grouped_gemm( # Pre-compute collapsed 2D sizes from original N-D shapes. # These are static Python ints passed as primitive parameters (must be hashable). - lhs_left_size = math.prod(lhs_shape[:lhs_axis_boundary]) - lhs_right_size = math.prod(lhs_shape[lhs_axis_boundary:]) - rhs_left_size = math.prod(rhs_shape[:rhs_axis_boundary]) - rhs_right_size = math.prod(rhs_shape[rhs_axis_boundary:]) + lhs_left_size = math.prod(lhs.original_shape[:lhs_axis_boundary]) + lhs_right_size = math.prod(lhs.original_shape[lhs_axis_boundary:]) + rhs_left_size = math.prod(rhs.original_shape[:rhs_axis_boundary]) + rhs_right_size = math.prod(rhs.original_shape[rhs_axis_boundary:]) # Pre-compute output shape from N-D input shapes (static Python ints). if lhs_is_trans: - lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + lhs_non_contracting = lhs.original_shape[lhs_axis_boundary:] else: - lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + lhs_non_contracting = lhs.original_shape[:lhs_axis_boundary] if rhs_is_trans: - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + if rhs.first_dims is not None or rhs.last_dims is not None: # wgrad: rhs (e.g. grad_T of shape (N, M)) has no G batch dim; include all dims - rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary)) + rhs_non_contracting = tuple(rhs.original_shape[d] for d in range(rhs_axis_boundary)) else: # fwd/dgrad: rhs (e.g. kernel_T of shape (G, N, K)) has G batch dim at dim 0; skip it - rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) + rhs_non_contracting = tuple(rhs.original_shape[d] for d in range(rhs_axis_boundary) if d != 0) else: - rhs_non_contracting = rhs_shape[rhs_axis_boundary:] - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + rhs_non_contracting = rhs.original_shape[rhs_axis_boundary:] + if rhs.first_dims is not None or rhs.last_dims is not None: out_shape = (num_gemms, *lhs_non_contracting, *rhs_non_contracting) else: out_shape = (*lhs_non_contracting, *rhs_non_contracting) @@ -2416,31 +2422,19 @@ def grouped_gemm( use_v2_ffi, _ = is_v2_grouped_gemm_supported( scaling_mode, - lhs_data.dtype, + lhs.data.dtype, has_bias, - lhs_shape=lhs_shape, - rhs_shape=rhs_shape, + lhs_shape=lhs.original_shape, + rhs_shape=rhs.original_shape, lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, ) - # V2 grouped GEMM requires MXFP8 inputs to be pre-swizzled by V2 grouped quantize - # (nvte_group_quantize fuses the swizzle). The C++ V2 GEMM FFI does not re-swizzle. - if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: - if isinstance(lhs, GroupedScaledTensor1x) and not lhs.pre_swizzled: - raise ValueError( - "V2 grouped GEMM requires MXFP8 lhs scale_inv to be pre-swizzled. " - "GroupedScaledTensor1x.pre_swizzled is False. " - "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " - "128-aligned shapes) to produce pre-swizzled tensors." - ) - if isinstance(rhs, GroupedScaledTensor1x) and not rhs.pre_swizzled: - raise ValueError( - "V2 grouped GEMM requires MXFP8 rhs scale_inv to be pre-swizzled. " - "GroupedScaledTensor1x.pre_swizzled is False. " - "Use V2 grouped quantize (nvte_group_quantize, requires SM100+ and " - "128-aligned shapes) to produce pre-swizzled tensors." - ) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # Pre-swizzling is required for both V1 and V2. GroupedQuantize handles this. + assert lhs.pre_swizzled, "lhs must be pre-swizzled for MXFP8 1D scaling" + assert rhs.pre_swizzled, "rhs must be pre-swizzled for MXFP8 1D scaling" + if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha @@ -2450,17 +2444,17 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, + lhs.data, + lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), + rhs.data, + rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), bias, - lhs_first_dims, - lhs_last_dims, - rhs_first_dims, - rhs_last_dims, - out_first_dims, - out_last_dims, + lhs.first_dims if lhs.first_dims is not None else empty_gs, + lhs.last_dims if lhs.last_dims is not None else empty_gs, + rhs.first_dims if rhs.first_dims is not None else empty_gs, + rhs.last_dims if rhs.last_dims is not None else empty_gs, + out_first_dims if out_first_dims is not None else empty_gs, + out_last_dims if out_last_dims is not None else empty_gs, additional_arg_0, additional_arg_1, lhs_is_trans=lhs_is_trans, From 6535819ba459302117bb782e4d9fed68f3095f20 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 10 Apr 2026 08:44:45 -0700 Subject: [PATCH 56/60] Test fixes Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 15 +++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 9 ++++++++- .../jax/cpp_extensions/quantization.py | 19 +++++++++---------- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 58de167bb6..4de4d79e5d 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -27,6 +27,7 @@ from transformer_engine.jax.cpp_extensions.quantization import ( _jax_quantize, _jax_quantize_dbias, + GroupedQuantizePrimitive, ) from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex @@ -1113,6 +1114,20 @@ def test_grouped_qdq( if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] + # V2 MXFP8 quantize requires every individual group size to be a multiple of 128. + # group_size_multiplier=32 can produce groups of 32 or 64 rows which violate this. + # This cannot be checked at runtime (group sizes live on device), so we skip the + # test configuration rather than weaken the kernel-selection logic. + if ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and group_size_multiplier % 128 != 0 + and GroupedQuantizePrimitive._use_v2_kernel(scaling_mode.value, input_shape, flatten_axis) + ): + pytest.skip( + f"MXFP8 V2 quantize requires each group to be 128-aligned; " + f"group_size_multiplier={group_size_multiplier} may produce smaller groups" + ) + x = jax.random.uniform(subkeys[1], input_shape, in_dtype) grouped_quantizer = QuantizerFactory.create( diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index fa48c77bbd..006115c7cb 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2357,6 +2357,11 @@ def grouped_gemm( lhs_flatten_axis, rhs_flatten_axis ) + # Re-read scaling_mode after quantization: if _quantize_inputs_if_needed converted + # GroupedNoScaleTensor → GroupedScaledTensor1x, the original scaling_mode (NO_SCALING) + # would cause the C++ kernel to skip scale_inv setup, triggering a cuBLAS assertion. + _, scaling_mode = _get_out_dtype_and_scaling_mode(lhs) + if lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") @@ -2431,7 +2436,9 @@ def grouped_gemm( ) if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - # Pre-swizzling is required for both V1 and V2. GroupedQuantize handles this. + # Both V1 and V2 quantize produce pre-swizzled scales (V1 via + # set_with_gemm_swizzled_scales, V2 via nvte_group_quantize). Require that + # grouped_quantize has set pre_swizzled=True on the input tensors. assert lhs.pre_swizzled, "lhs must be pre-swizzled for MXFP8 1D scaling" assert rhs.pre_swizzled, "rhs must be pre-swizzled for MXFP8 1D scaling" diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b5455f5e6a..f6e3560cd3 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1023,7 +1023,9 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): per-group row count non_group_m = prod(x_shape[1:eff]) must also be divisible by 128. 4. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must - be 128-aligned — this is a dynamic constraint assumed by the caller. + be 128-aligned — this is a dynamic constraint that cannot be checked here + because group sizes live on device. The caller is responsible for ensuring + this. 5. The last logical dimension (contracting dim K or output dim N) must be divisible by 128, matching the V2 grouped GEMM constraint so that the two always agree on V1 vs V2. @@ -1335,14 +1337,11 @@ def grouped_quantize( for i, quantizer_i in enumerate(quantizer.quantizers): quantizer_i.update(updated_amax[i].reshape((1,))) - # V2 grouped quantize (nvte_group_quantize) fuses the scale_inv swizzle into - # the kernel, so the resulting tensors are already swizzled for GEMM. - # Note: V1 also produces swizzled scales (via set_with_gemm_swizzled_scales), - # but pre_swizzled is only set for V2 to maintain pytree compatibility. - # The dequantizer detects MXFP8 swizzling via the scaling_mode instead. - use_v2 = GroupedQuantizePrimitive._use_v2_kernel( - quantizer.scaling_mode.value, x.shape, flatten_axis - ) + # Both V1 (set_with_gemm_swizzled_scales) and V2 (nvte_group_quantize) produce + # pre-swizzled scale_inv tensors for use by the grouped GEMM kernel. Set + # pre_swizzled=True for all MXFP8 grouped quantization so that grouped_gemm can + # assert this invariant unconditionally. + is_mxfp8 = quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, @@ -1355,7 +1354,7 @@ def grouped_quantize( flatten_axis=flatten_axis, first_dims=ragged_first_dims, original_shape=original_shape, - pre_swizzled=use_v2, + pre_swizzled=is_mxfp8, ) return out From bf6377b23a945a1f266b352cee7001a7f2f48d83 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 10 Apr 2026 10:04:45 -0700 Subject: [PATCH 57/60] Fix old var names in V1 python codepath Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 006115c7cb..23cdb66733 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2213,7 +2213,7 @@ def _adjust_contracting_dims_for_hopper_fp8_transpose( if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim if ( - lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + lhs.first_dims is not None or lhs.last_dims is not None ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim @@ -2268,7 +2268,7 @@ def _quantize_inputs_if_needed( ( gs for gs in [lhs.first_dims, lhs.last_dims, rhs.first_dims, rhs.last_dims] - if gs.size > 0 + if gs is not None and gs.size > 0 ), empty_gs, ) From 16a4bf745c10a0f3548219150797c737374639b9 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 10 Apr 2026 10:29:57 -0700 Subject: [PATCH 58/60] Fix lint Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 3 --- transformer_engine/jax/quantize/dequantizer.py | 10 +++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 23cdb66733..2833dbb44c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1598,7 +1598,6 @@ def _compute_cublas_workspace_size( workspace_size = get_cublas_workspace_size_bytes() * stream_count workspace_alignment_padding = 256 tensor_scaling_sinv_aligment = 16 - mxfp8_scaling_sinv_alignment_padding = 256 # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. workspace_size += workspace_alignment_padding @@ -2248,7 +2247,6 @@ def _quantize_inputs_if_needed( "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" ) - scaling_mode = quantizer_set.x.scaling_mode if ( quantizer_set.x.scaling_mode.is_tensor_scaling() and is_fp8_gemm_with_all_layouts_supported() @@ -2286,7 +2284,6 @@ def _get_num_gemms( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], ) -> int: - num_gemms = 0 for x in [lhs, rhs]: if x.first_dims is not None: return x.first_dims.size diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b46e4ff9d5..ca44c2e4af 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -287,11 +287,11 @@ def _unswizzle_mxfp8_grouped_scale(scale_inv_flat, padded_scale_2d, is_colwise): unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) # Back to transposed 2D, then un-transpose return jnp.transpose(unswizzled.reshape(rows, cols)) - else: - rows, cols = padded_scale_2d - reshaped = scale_inv_flat.reshape(rows // 128, cols // 4, 32, 4, 4) - unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) - return unswizzled.reshape(rows, cols) + + rows, cols = padded_scale_2d + reshaped = scale_inv_flat.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + return unswizzled.reshape(rows, cols) def _grouped_dequantize(grouped_scaled_tensor): From 9ced1c51d6c378315cacb22174a0d5ae20de0313 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:31:56 +0000 Subject: [PATCH 59/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 18 +- transformer_engine/jax/cpp_extensions/gemm.py | 156 +++++++++++------- .../jax/cpp_extensions/quantization.py | 12 +- .../jax/csrc/extensions/quantization.cpp | 42 ++--- 4 files changed, 143 insertions(+), 85 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4de4d79e5d..7e680cc591 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1079,7 +1079,7 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest_parametrize_wrapper( "group_size_multiplier", [ - 32, # V1 MXFP8: group size must be multiple of 32 + 32, # V1 MXFP8: group size must be multiple of 32 128, # V2 MXFP8 eligible: group size must be multiple of 128 ], ) @@ -1092,7 +1092,15 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w ) class TestGroupedQuantize: def test_grouped_qdq( - self, in_dtype, input_shape, group_size_multiplier, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + self, + in_dtype, + input_shape, + group_size_multiplier, + q_dtype, + scaling_mode, + q_layout, + flatten_axis, + with_group_sizes, ): n_groups, m, n = input_shape key = jax.random.PRNGKey(0) @@ -1121,10 +1129,12 @@ def test_grouped_qdq( if ( scaling_mode == ScalingMode.MXFP8_1D_SCALING and group_size_multiplier % 128 != 0 - and GroupedQuantizePrimitive._use_v2_kernel(scaling_mode.value, input_shape, flatten_axis) + and GroupedQuantizePrimitive._use_v2_kernel( + scaling_mode.value, input_shape, flatten_axis + ) ): pytest.skip( - f"MXFP8 V2 quantize requires each group to be 128-aligned; " + "MXFP8 V2 quantize requires each group to be 128-aligned; " f"group_size_multiplier={group_size_multiplier} may produce smaller groups" ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2833dbb44c..e42608e0b4 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2035,6 +2035,7 @@ def _should_enforce_v2_grouped_gemm() -> bool: f"NVTE_JAX_ENFORCE_V2_GROUPED_GEMM must be an integer (0 or 1), got: {val!r}" ) from e + def _is_v2_grouped_gemm_supported( scaling_mode: ScalingMode, dtype: jnp.dtype, @@ -2047,17 +2048,23 @@ def _is_v2_grouped_gemm_supported( """Determine whether the V2 grouped GEMM implementation can be used based on the input parameters.""" if not _v2_grouped_gemm_available: - return False, ( - "TE was not compiled with support for the V2 grouped GEMM kernel, reason: " - f"{_v2_grouped_gemm_available_reason}" + return ( + False, + ( + "TE was not compiled with support for the V2 grouped GEMM kernel, reason: " + f"{_v2_grouped_gemm_available_reason}" + ), ) # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. if get_min_device_compute_capability() < 100: - return False, ( - "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" - f" compute capability is {get_min_device_compute_capability()}." + return ( + False, + ( + "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" + f" compute capability is {get_min_device_compute_capability()}." + ), ) if has_bias: @@ -2073,20 +2080,26 @@ def _is_v2_grouped_gemm_supported( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) if lhs_first_dim % 128 != 0: - return False, ( - "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" - f" {lhs_first_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}." + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_first_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ), ) if rhs_shape is not None and rhs_axis_boundary is not None: rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) if rhs_first_dim % 128 != 0: - return False, ( - "The TE V2 grouped GEMM for MXFP8 requires the product of the first" - " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" - f" {rhs_first_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}." + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_first_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ), ) # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both @@ -2094,31 +2107,41 @@ def _is_v2_grouped_gemm_supported( if lhs_shape is not None and lhs_axis_boundary is not None: lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) if lhs_last_dim % 128 != 0: - return False, ( - "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" - f" {lhs_last_dim} with lhs_shape={lhs_shape} and" - f" lhs_axis_boundary={lhs_axis_boundary}." + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_last_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ), ) if rhs_shape is not None and rhs_axis_boundary is not None: rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) if rhs_last_dim % 128 != 0: - return False, ( - "The TE V2 grouped GEMM for MXFP8 requires the product of the last" - " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" - f" {rhs_last_dim} with rhs_shape={rhs_shape} and" - f" rhs_axis_boundary={rhs_axis_boundary}." + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_last_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ), ) return True, "" - return False, ( - "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" - " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" - f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," - f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," - f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." + return ( + False, + ( + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" + " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" + f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," + f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." + ), ) + def is_v2_grouped_gemm_supported( scaling_mode: ScalingMode, dtype: jnp.dtype, @@ -2128,8 +2151,8 @@ def is_v2_grouped_gemm_supported( lhs_axis_boundary=None, rhs_axis_boundary=None, ) -> tuple[bool, str]: - """ Determine whether the V2 grouped GEMM implementation can be used based on the input parameters. - + """Determine whether the V2 grouped GEMM implementation can be used based on the input parameters. + Returns: A tuple of (is_supported: bool, reason: str) where is_supported indicates whether the V2 grouped GEMM can be used, and reason provides an explanation if it is not supported. """ @@ -2152,7 +2175,10 @@ def is_v2_grouped_gemm_supported( return is_v2_supported, reason -def _get_out_dtype_and_scaling_mode(x: Union[GroupedNoScaleTensor, GroupedScaledTensor1x]) -> Tuple[jnp.dtype, ScalingMode]: + +def _get_out_dtype_and_scaling_mode( + x: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> Tuple[jnp.dtype, ScalingMode]: if isinstance(x, GroupedScaledTensor1x): out_dtype = x.dq_dtype scaling_mode = x.scaling_mode @@ -2165,9 +2191,17 @@ def _get_out_dtype_and_scaling_mode(x: Union[GroupedNoScaleTensor, GroupedScaled ) return out_dtype, scaling_mode -def _infer_output_ragged_dims(lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x]) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: - assert isinstance(lhs, (GroupedNoScaleTensor, GroupedScaledTensor1x)), f"Expected lhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" - assert isinstance(rhs, (GroupedNoScaleTensor, GroupedScaledTensor1x)), f"Expected rhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" + +def _infer_output_ragged_dims( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: + assert isinstance( + lhs, (GroupedNoScaleTensor, GroupedScaledTensor1x) + ), f"Expected lhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" + assert isinstance( + rhs, (GroupedNoScaleTensor, GroupedScaledTensor1x) + ), f"Expected rhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" # Infer output dims from which operand has the ragged non-contracting dim. if rhs.first_dims is not None or rhs.last_dims is not None: @@ -2185,6 +2219,7 @@ def _infer_output_ragged_dims(lhs: Union[GroupedNoScaleTensor, GroupedScaledTens return out_first_dims, out_last_dims + def _adjust_contracting_dims_for_hopper_fp8_transpose( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2222,6 +2257,7 @@ def _adjust_contracting_dims_for_hopper_fp8_transpose( return lhs_contract_dim, rhs_contract_dim + def _quantize_inputs_if_needed( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2230,17 +2266,23 @@ def _quantize_inputs_if_needed( rhs_is_trans: bool, lhs_flatten_axis: int, rhs_flatten_axis: int, -) -> Tuple[Union[GroupedNoScaleTensor, GroupedScaledTensor1x], Union[GroupedNoScaleTensor, GroupedScaledTensor1x]]: +) -> Tuple[ + Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +]: if quantizer_set is noop_quantizer_set: return lhs, rhs - assert isinstance(lhs, GroupedNoScaleTensor), f"Expected lhs to be GroupedNoScaleTensor before quantization, got type={type(lhs)}" - assert isinstance(rhs, GroupedNoScaleTensor), f"Expected rhs to be GroupedNoScaleTensor before quantization, got type={type(rhs)}" + assert isinstance( + lhs, GroupedNoScaleTensor + ), f"Expected lhs to be GroupedNoScaleTensor before quantization, got type={type(lhs)}" + assert isinstance( + rhs, GroupedNoScaleTensor + ), f"Expected rhs to be GroupedNoScaleTensor before quantization, got type={type(rhs)}" if not isinstance(quantizer_set.x, GroupedQuantizer): raise TypeError( - "Expected quantizer_set.x to be GroupedQuantizer, but got" - f" type={type(quantizer_set.x)}" + f"Expected quantizer_set.x to be GroupedQuantizer, but got type={type(quantizer_set.x)}" ) if type(quantizer_set.x) is not type(quantizer_set.kernel): raise TypeError( @@ -2255,9 +2297,7 @@ def _quantize_inputs_if_needed( else: lhs_is_rowwise = not lhs_is_trans rhs_is_rowwise = rhs_is_trans - quantizer_set.x.q_layout = ( - QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE - ) + quantizer_set.x.q_layout = QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) @@ -2272,14 +2312,13 @@ def _quantize_inputs_if_needed( ) lhs_input_data = lhs.data rhs_input_data = rhs.data - lhs_q = grouped_quantize( - lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis - ) + lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) return lhs_q, rhs_q + def _get_num_gemms( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2295,6 +2334,7 @@ def _get_num_gemms( "For grouped_gemm, at least one tensor must be ragged." ) + def grouped_gemm( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2331,7 +2371,9 @@ def grouped_gemm( out_dtype, scaling_mode = _get_out_dtype_and_scaling_mode(lhs) rhs_out_dtype, rhs_scaling_mode = _get_out_dtype_and_scaling_mode(rhs) assert out_dtype == rhs_out_dtype, f"Mismatched output dtypes: {out_dtype} vs {rhs_out_dtype}" - assert scaling_mode == rhs_scaling_mode, f"Mismatched scaling modes: {scaling_mode} vs {rhs_scaling_mode}" + assert ( + scaling_mode == rhs_scaling_mode + ), f"Mismatched scaling modes: {scaling_mode} vs {rhs_scaling_mode}" del rhs_out_dtype, rhs_scaling_mode out_first_dims, out_last_dims = _infer_output_ragged_dims(lhs, rhs) @@ -2348,10 +2390,7 @@ def grouped_gemm( rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) lhs, rhs = _quantize_inputs_if_needed( - lhs, rhs, - quantizer_set, - lhs_is_trans, rhs_is_trans, - lhs_flatten_axis, rhs_flatten_axis + lhs, rhs, quantizer_set, lhs_is_trans, rhs_is_trans, lhs_flatten_axis, rhs_flatten_axis ) # Re-read scaling_mode after quantization: if _quantize_inputs_if_needed converted @@ -2364,9 +2403,7 @@ def grouped_gemm( if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): lhs_contract_dim, rhs_contract_dim = _adjust_contracting_dims_for_hopper_fp8_transpose( - lhs, rhs, - lhs_contract_dim, rhs_contract_dim, - lhs_is_trans, rhs_is_trans + lhs, rhs, lhs_contract_dim, rhs_contract_dim, lhs_is_trans, rhs_is_trans ) # Compute N-D axis boundaries from final (post-adjustment) contracting dims. @@ -2398,7 +2435,9 @@ def grouped_gemm( rhs_non_contracting = tuple(rhs.original_shape[d] for d in range(rhs_axis_boundary)) else: # fwd/dgrad: rhs (e.g. kernel_T of shape (G, N, K)) has G batch dim at dim 0; skip it - rhs_non_contracting = tuple(rhs.original_shape[d] for d in range(rhs_axis_boundary) if d != 0) + rhs_non_contracting = tuple( + rhs.original_shape[d] for d in range(rhs_axis_boundary) if d != 0 + ) else: rhs_non_contracting = rhs.original_shape[rhs_axis_boundary:] if rhs.first_dims is not None or rhs.last_dims is not None: @@ -2439,7 +2478,6 @@ def grouped_gemm( assert lhs.pre_swizzled, "lhs must be pre-swizzled for MXFP8 1D scaling" assert rhs.pre_swizzled, "rhs must be pre-swizzled for MXFP8 1D scaling" - if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f6e3560cd3..7138cfcf40 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1105,7 +1105,9 @@ def abstract( # [n_groups int64 group_sizes | n_groups+1 int64 offsets] # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. n_groups = group_sizes_aval.size - int64_workspace_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) + int64_workspace_aval = jax.core.ShapedArray( + shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 + ) else: # V1 path: Unused for V1 codepath int64_workspace_aval = jax.core.ShapedArray(shape=(0,), dtype=jnp.uint8) @@ -1146,7 +1148,13 @@ def outer_abstract(*args, **kwargs): updated_amax_aval, _, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return (rowwise_out_aval, colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval) + return ( + rowwise_out_aval, + colwise_out_aval, + rowwise_scale_inv_aval, + colwise_scale_inv_aval, + updated_amax_aval, + ) @staticmethod def lowering( diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 8fd587ca91..650139a61c 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -318,9 +318,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales, Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, - Result_Type colwise_scale_invs, Result_Type amaxs, Result_Type _unused, - JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, - int64_t flatten_axis) { + Result_Type colwise_scale_invs, Result_Type amaxs, + Result_Type _unused, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -486,27 +486,29 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // input - .Arg() // scale - .Arg() // group_sizes - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Ret() // unused (for compatibility with V2 interface) - .Attr("scaling_mode") - .Attr("q_layout") - .Attr("flatten_axis")); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + GroupedQuantizeHandler, GroupedQuantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // group_sizes + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // unused (for compatibility with V2 interface) + .Attr("scaling_mode") + .Attr("q_layout") + .Attr("flatten_axis")); Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, Buffer_Type group_sizes, Result_Type rowwise_out, Result_Type colwise_out, Result_Type rowwise_sinv, - Result_Type colwise_sinv, Result_Type updated_amaxs, Result_Type int64_workspace, - JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { + Result_Type colwise_sinv, Result_Type updated_amaxs, + Result_Type int64_workspace, JAXX_Quantize_Layout quantize_layout, + int64_t flatten_axis) { (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); From 4da6b809b428d9360787884064cedefc30e015b5 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 11 Apr 2026 10:47:51 -0700 Subject: [PATCH 60/60] Fix Hopper V1 FP8 grouped GEMMs Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e42608e0b4..8351634b1d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2227,7 +2227,7 @@ def _adjust_contracting_dims_for_hopper_fp8_transpose( rhs_contract_dim: Sequence[int], lhs_is_trans: bool, rhs_is_trans: bool, -) -> Tuple[Sequence[int], Sequence[int]]: +) -> Tuple[bool, bool, Sequence[int], Sequence[int]]: # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # thus additional transpose is required lhs_layout_is_T = lhs.data_layout == "T" @@ -2255,7 +2255,7 @@ def _adjust_contracting_dims_for_hopper_fp8_transpose( else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - return lhs_contract_dim, rhs_contract_dim + return lhs_is_trans, rhs_is_trans, lhs_contract_dim, rhs_contract_dim def _quantize_inputs_if_needed( @@ -2402,8 +2402,10 @@ def grouped_gemm( raise ValueError("FP8 GEMM does not support E5M2 * E5M2") if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): - lhs_contract_dim, rhs_contract_dim = _adjust_contracting_dims_for_hopper_fp8_transpose( - lhs, rhs, lhs_contract_dim, rhs_contract_dim, lhs_is_trans, rhs_is_trans + lhs_is_trans, rhs_is_trans, lhs_contract_dim, rhs_contract_dim = ( + _adjust_contracting_dims_for_hopper_fp8_transpose( + lhs, rhs, lhs_contract_dim, rhs_contract_dim, lhs_is_trans, rhs_is_trans + ) ) # Compute N-D axis boundaries from final (post-adjustment) contracting dims.