diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ddb74fd636..7e680cc591 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -27,6 +27,7 @@ from transformer_engine.jax.cpp_extensions.quantization import ( _jax_quantize, _jax_quantize_dbias, + GroupedQuantizePrimitive, ) from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex @@ -1068,7 +1069,20 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) +@pytest_parametrize_wrapper( + "input_shape", + [ + (8, 16, 32), # V1 MXFP8: K=32 not 128-aligned + (4, 8, 128), # V2 MXFP8 eligible: K=128, M*32=256 both 128-aligned + ], +) +@pytest_parametrize_wrapper( + "group_size_multiplier", + [ + 32, # V1 MXFP8: group size must be multiple of 32 + 128, # V2 MXFP8 eligible: group size must be multiple of 128 + ], +) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @@ -1078,14 +1092,21 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w ) class TestGroupedQuantize: def test_grouped_qdq( - self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes + self, + in_dtype, + input_shape, + group_size_multiplier, + q_dtype, + scaling_mode, + q_layout, + flatten_axis, + with_group_sizes, ): n_groups, m, n = input_shape key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) - # *32 so that the input shapes works for MXFP8 - input_shape = (m * 32, n) + input_shape = (m * group_size_multiplier, n) if with_group_sizes: group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) @@ -1093,7 +1114,7 @@ def test_grouped_qdq( group_sizes = jnp.diff(group_sizes) assert group_sizes.sum() == m assert jnp.any(group_sizes == 0) # make sure that at least one group has 0 row - group_sizes = group_sizes * 32 + group_sizes = group_sizes * group_size_multiplier else: group_sizes = None input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1]) @@ -1101,6 +1122,22 @@ def test_grouped_qdq( if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] + # V2 MXFP8 quantize requires every individual group size to be a multiple of 128. + # group_size_multiplier=32 can produce groups of 32 or 64 rows which violate this. + # This cannot be checked at runtime (group sizes live on device), so we skip the + # test configuration rather than weaken the kernel-selection logic. + if ( + scaling_mode == ScalingMode.MXFP8_1D_SCALING + and group_size_multiplier % 128 != 0 + and GroupedQuantizePrimitive._use_v2_kernel( + scaling_mode.value, input_shape, flatten_axis + ) + ): + pytest.skip( + "MXFP8 V2 quantize requires each group to be 128-aligned; " + f"group_size_multiplier={group_size_multiplier} may produce smaller groups" + ) + x = jax.random.uniform(subkeys[1], input_shape, in_dtype) grouped_quantizer = QuantizerFactory.create( @@ -1713,10 +1750,11 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ] GROUPED_DENSE_INPUT_SHAPES = [ - # (n_groups, m, n, k), the actual m will be multiplied by 32 - (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 - (8, 64, 32, 128), - (8, 64, 128, 256), + # (n_groups, m, n, k), the actual m will be multiplied by group_size_multiplier + (5, 32, 128, 64), # V1 MXFP8: K=64 not 128-aligned; also tests n_groups not a multiple of 4 + (8, 64, 32, 128), # V1 MXFP8 GEMM: N=32 not 128-aligned + (8, 64, 128, 256), # V2 MXFP8 eligible: K=256, N=128 both 128-aligned + (4, 4, 128, 128), # V2 MXFP8 eligible: K=128, N=128 both 128-aligned (smaller shape) ] @@ -1742,7 +1780,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): ref_out.append(jnp.squeeze(out_i)) return ref_out - def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): + def _generate_grouped_dense_input( + self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32 + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape @@ -1755,9 +1795,9 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # Scale group sizes by the multiplier for alignment requirements. + group_sizes = group_sizes * group_size_multiplier + m = m * group_size_multiplier lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1831,8 +1871,10 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout quantizer.q_dtype = bwd_dtype out_dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout + out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32 ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) @@ -1906,10 +1948,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, with_bias=True, + group_size_multiplier=128 if is_mxfp8 else 32, ) quantizer_set = QuantizerFactory.create_set( diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index a8e0b6df83..985c53f760 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -1414,6 +1414,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, if (idx < n) dst[idx] = static_cast(src[idx]); } +// Like convert_int32_to_int64_kernel but scales each element by multiplier. +// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. +__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, + size_t n, int64_t multiplier) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; +} + +// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim). +// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small. +__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim) { + offsets[0] = 0; + for (size_t i = 0; i < n_groups; i++) { + offsets[i + 1] = offsets[i] + first_dims[i] * last_dim; + } +} + } // namespace void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { @@ -1424,3 +1442,23 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud convert_int32_to_int64_kernel<<>>(src, dst, n); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, + multiplier); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); + // Always write at least offsets[0]=0 (needed even for n_groups==0). + compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, + last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 6999dd857f..fcd08a40a9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -356,6 +356,35 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); */ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); +/*! \brief Convert int32 array to int64 while scaling each element by a multiplier. + * + * Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n). + * CUDA-graph safe (no host-device synchronization). + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] multiplier Scale factor applied to each element. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream); + +/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. + * + * Writes n_groups+1 values to offsets: offsets[0]=0, + * offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups]. + * This is CUDA-graph safe (no host-device synchronization). + * + * \param[in] first_dims Device pointer to int64 array of length n_groups. + * \param[out] offsets Device pointer to int64 array of length n_groups+1. + * \param[in] n_groups Number of groups. + * \param[in] last_dim Common last dimension (number of columns). + * \param[in] stream CUDA stream. + */ +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c081e451a7..8351634b1d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from dataclasses import dataclass from functools import partial, reduce, cache -from typing import Tuple, Sequence, Union +from typing import Tuple, Sequence, Union, Optional from enum import Enum import warnings @@ -47,7 +47,7 @@ apply_padding_to_scale_inv, QuantizeLayout, ) -from .misc import get_padded_spec, is_all_reduce_in_float32 +from .misc import get_padded_spec, is_all_reduce_in_float32, get_min_device_compute_capability from ..sharding import ( global_mesh_resource, tpsp_axis_size, @@ -66,6 +66,7 @@ "sanitize_dims", "get_non_contracting_dims", "transpose_dims", + "is_v2_grouped_gemm_supported", ] @@ -1597,7 +1598,6 @@ def _compute_cublas_workspace_size( workspace_size = get_cublas_workspace_size_bytes() * stream_count workspace_alignment_padding = 256 tensor_scaling_sinv_aligment = 16 - mxfp8_scaling_sinv_alignment_padding = 256 # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. workspace_size += workspace_alignment_padding @@ -1610,9 +1610,9 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # We also pad scale_inv swizzle buffers size for 256 bytes alignment. - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + # Both V1 and V2 quantize now produce pre-swizzled scales, so the GEMM + # does not need extra workspace for nvte_swizzle_scaling_factors. + pass return workspace_size @staticmethod @@ -2036,48 +2036,303 @@ def _should_enforce_v2_grouped_gemm() -> bool: ) from e -def _can_use_v2_grouped_gemm( +def _is_v2_grouped_gemm_supported( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, -) -> bool: - """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. - # Bias can be supported in a kernel or in pure-JAX in the future. - - enforce_v2_gmm = _should_enforce_v2_grouped_gemm() + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, +) -> tuple[bool, str]: + """Determine whether the V2 grouped GEMM implementation can be used based on the input parameters.""" if not _v2_grouped_gemm_available: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is" - " enabled. The reason for V2 grouped GEMM not being available:" - f" {_v2_grouped_gemm_available_reason}" - ) - return False + return ( + False, + ( + "TE was not compiled with support for the V2 grouped GEMM kernel, reason: " + f"{_v2_grouped_gemm_available_reason}" + ), + ) # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. - if get_device_compute_capability(0) < 100: - if enforce_v2_gmm: - raise RuntimeError( - "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device" - f" compute capability of GPU 0 is {get_device_compute_capability(0)} and" - " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." - ) - return False + if get_min_device_compute_capability() < 100: + return ( + False, + ( + "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" + f" compute capability is {get_min_device_compute_capability()}." + ), + ) - if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias: - return True + if has_bias: + return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel." + + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: + return True, "" + + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # V2 MXFP8 requires that the total first dimension of both operands (up to + # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. + # Individual group sizes must also be 128-aligned (dynamic constraint). + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) + if lhs_first_dim % 128 != 0: + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_first_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ), + ) + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) + if rhs_first_dim % 128 != 0: + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the first" + " dimensions (up to axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_first_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ), + ) - if enforce_v2_gmm: + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both + # operands is a multiple of 128. This is because the MXFP8 scales must be padded to a multiple of (128, 4). The nvte_grouped_gemm setup kernels only handle the case when this dim is a multiple of 128 as well. If it is not, the GEMM setup kernel will not compute the scale offsets correctly and will read overlapping scales from the previous group, causing incorrect results. + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) + if lhs_last_dim % 128 != 0: + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of LHS to be divisible by 128, but got" + f" {lhs_last_dim} with lhs_shape={lhs_shape} and" + f" lhs_axis_boundary={lhs_axis_boundary}." + ), + ) + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) + if rhs_last_dim % 128 != 0: + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires the product of the last" + " dimensions (after axis_boundary) of RHS to be divisible by 128, but got" + f" {rhs_last_dim} with rhs_shape={rhs_shape} and" + f" rhs_axis_boundary={rhs_axis_boundary}." + ), + ) + return True, "" + + return ( + False, + ( + "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" + " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" + f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," + f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." + ), + ) + + +def is_v2_grouped_gemm_supported( + scaling_mode: ScalingMode, + dtype: jnp.dtype, + has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, +) -> tuple[bool, str]: + """Determine whether the V2 grouped GEMM implementation can be used based on the input parameters. + + Returns: + A tuple of (is_supported: bool, reason: str) where is_supported indicates whether the V2 grouped GEMM can be used, and reason provides an explanation if it is not supported. + """ + # Use the V2 path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). + # Bias can be supported in a kernel or in pure-JAX in the future. + + enforce_v2_gmm = _should_enforce_v2_grouped_gemm() + + is_v2_supported, reason = _is_v2_grouped_gemm_supported( + scaling_mode, dtype, has_bias, lhs_shape, rhs_shape, lhs_axis_boundary, rhs_axis_boundary + ) + + if enforce_v2_gmm and not is_v2_supported: raise RuntimeError( - "The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and" - f" without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}" + "The TE V2 grouped GEMM is not supported for the given input parameters, but" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled. The reason for V2 grouped GEMM not being" + f" supported: {reason}" ) - return False + + return is_v2_supported, reason + + +def _get_out_dtype_and_scaling_mode( + x: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> Tuple[jnp.dtype, ScalingMode]: + if isinstance(x, GroupedScaledTensor1x): + out_dtype = x.dq_dtype + scaling_mode = x.scaling_mode + elif isinstance(x, GroupedNoScaleTensor): + out_dtype = x.data.dtype + scaling_mode = ScalingMode.NO_SCALING + else: + raise TypeError( + f"Input must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(x)}" + ) + return out_dtype, scaling_mode + + +def _infer_output_ragged_dims( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: + assert isinstance( + lhs, (GroupedNoScaleTensor, GroupedScaledTensor1x) + ), f"Expected lhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" + assert isinstance( + rhs, (GroupedNoScaleTensor, GroupedScaledTensor1x) + ), f"Expected rhs to be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs.first_dims is not None or rhs.last_dims is not None: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = None + out_last_dims = None + elif lhs.first_dims is not None: + out_first_dims = lhs.first_dims + out_last_dims = None + elif lhs.last_dims is not None: + out_first_dims = None + out_last_dims = lhs.last_dims + else: + out_first_dims = out_last_dims = None + + return out_first_dims, out_last_dims + + +def _adjust_contracting_dims_for_hopper_fp8_transpose( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + lhs_contract_dim: Sequence[int], + rhs_contract_dim: Sequence[int], + lhs_is_trans: bool, + rhs_is_trans: bool, +) -> Tuple[bool, bool, Sequence[int], Sequence[int]]: + # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs + # thus additional transpose is required + lhs_layout_is_T = lhs.data_layout == "T" + rhs_layout_is_T = rhs.data_layout == "T" + # we can't apply _shape_normalization on the grouped input + # thus we need to ensure that lhs is in N and rhs is in T + if lhs_is_trans != lhs_layout_is_T: + raise RuntimeError("lhs input must be transposed before calling grouped_gemm") + if (not rhs_is_trans) != rhs_layout_is_T: + raise RuntimeError("rhs input must be transposed before calling grouped_gemm") + lhs_is_trans = False + rhs_is_trans = True + lhs_ndim = len(lhs.original_shape) + rhs_ndim = len(rhs.original_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + # For rhs [G, K, N], need to exclude the G dim from contract_dim + if ( + lhs.first_dims is not None or lhs.last_dims is not None + ): # fwd/dgrad: rhs has G as first dim + rhs_contract_dim = tuple( + (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim + ) + else: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + + return lhs_is_trans, rhs_is_trans, lhs_contract_dim, rhs_contract_dim + + +def _quantize_inputs_if_needed( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + quantizer_set: QuantizerSet, + lhs_is_trans: bool, + rhs_is_trans: bool, + lhs_flatten_axis: int, + rhs_flatten_axis: int, +) -> Tuple[ + Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +]: + if quantizer_set is noop_quantizer_set: + return lhs, rhs + + assert isinstance( + lhs, GroupedNoScaleTensor + ), f"Expected lhs to be GroupedNoScaleTensor before quantization, got type={type(lhs)}" + assert isinstance( + rhs, GroupedNoScaleTensor + ), f"Expected rhs to be GroupedNoScaleTensor before quantization, got type={type(rhs)}" + + if not isinstance(quantizer_set.x, GroupedQuantizer): + raise TypeError( + f"Expected quantizer_set.x to be GroupedQuantizer, but got type={type(quantizer_set.x)}" + ) + if type(quantizer_set.x) is not type(quantizer_set.kernel): + raise TypeError( + "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" + f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + ) + if ( + quantizer_set.x.scaling_mode.is_tensor_scaling() + and is_fp8_gemm_with_all_layouts_supported() + ): + lhs_is_rowwise = rhs_is_rowwise = True + else: + lhs_is_rowwise = not lhs_is_trans + rhs_is_rowwise = rhs_is_trans + quantizer_set.x.q_layout = QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE + quantizer_set.kernel.q_layout = ( + QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE + ) + empty_gs = jnp.empty((0,), jnp.int32) + active_group_sizes = next( + ( + gs + for gs in [lhs.first_dims, lhs.last_dims, rhs.first_dims, rhs.last_dims] + if gs is not None and gs.size > 0 + ), + empty_gs, + ) + lhs_input_data = lhs.data + rhs_input_data = rhs.data + lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + rhs_q = grouped_quantize( + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + ) + return lhs_q, rhs_q + + +def _get_num_gemms( + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], +) -> int: + for x in [lhs, rhs]: + if x.first_dims is not None: + return x.first_dims.size + if x.last_dims is not None: + return x.last_dims.size + raise ValueError( + "Cannot infer number of gemms since neither lhs nor rhs has first_dims or last_dims. " + "Ensure that at least one of the input tensors has valid first_dims or last_dims." + "For grouped_gemm, at least one tensor must be ragged." + ) def grouped_gemm( @@ -2113,179 +2368,51 @@ def grouped_gemm( empty_gs = jnp.empty((0,), jnp.int32) - # Extract data, dims, and metadata from tensor objects. - # Keep data in its original layout (may be 1D for quantized tensors) to preserve - # JAX sharding; the C++ side uses original_shape to derive m/n/k. - if isinstance(lhs, GroupedNoScaleTensor): - lhs_data = lhs.data - lhs_shape = lhs.original_shape - lhs_scale_inv = jnp.empty((0,), jnp.float32) - scaling_mode = ScalingMode.NO_SCALING - out_dtype = lhs.data.dtype - lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs - lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - elif isinstance(lhs, GroupedScaledTensor1x): - lhs_shape = lhs.original_shape - lhs_data = lhs.data - lhs_scale_inv = lhs.scale_inv - scaling_mode = lhs.scaling_mode - out_dtype = lhs.dq_dtype - lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs - lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs - else: - raise TypeError( - f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" - ) - - if isinstance(rhs, GroupedNoScaleTensor): - rhs_data = rhs.data - rhs_shape = rhs.original_shape - rhs_scale_inv = jnp.empty((0,), jnp.float32) - rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs - rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs - elif isinstance(rhs, GroupedScaledTensor1x): - rhs_shape = rhs.original_shape - rhs_data = rhs.data - rhs_scale_inv = rhs.scale_inv - rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs - rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs - if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: - raise ValueError( - f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," - f" rhs.scaling_mode={rhs.scaling_mode}" - ) - if isinstance(lhs, GroupedScaledTensor1x): - scaling_mode = lhs.scaling_mode - else: - raise TypeError( - f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" - ) + out_dtype, scaling_mode = _get_out_dtype_and_scaling_mode(lhs) + rhs_out_dtype, rhs_scaling_mode = _get_out_dtype_and_scaling_mode(rhs) + assert out_dtype == rhs_out_dtype, f"Mismatched output dtypes: {out_dtype} vs {rhs_out_dtype}" + assert ( + scaling_mode == rhs_scaling_mode + ), f"Mismatched scaling modes: {scaling_mode} vs {rhs_scaling_mode}" + del rhs_out_dtype, rhs_scaling_mode - # Infer output dims from which operand has the ragged non-contracting dim. - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: - # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) - out_first_dims = empty_gs - out_last_dims = empty_gs - elif lhs_first_dims.size > 0: - out_first_dims = lhs_first_dims - out_last_dims = empty_gs - elif lhs_last_dims.size > 0: - out_first_dims = empty_gs - out_last_dims = lhs_last_dims - else: - out_first_dims = out_last_dims = empty_gs + out_first_dims, out_last_dims = _infer_output_ragged_dims(lhs, rhs) out_dtype = preferred_element_type or out_dtype lhs_contract_dim, rhs_contract_dim = contracting_dims - lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 + lhs_is_trans = lhs_contract_dim[-1] != len(lhs.original_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). - rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 + rhs_is_trans = rhs_contract_dim[-1] == len(rhs.original_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - if ( - not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - and quantizer_set != noop_quantizer_set - ): - if not isinstance(quantizer_set.x, GroupedQuantizer): - raise TypeError( - "Expected quantizer_set.x to be GroupedQuantizer, but got" - f" type={type(quantizer_set.x)}" - ) - if type(quantizer_set.x) is not type(quantizer_set.kernel): - raise TypeError( - "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" - f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" - ) - scaling_mode = quantizer_set.x.scaling_mode - if ( - quantizer_set.x.scaling_mode.is_tensor_scaling() - and is_fp8_gemm_with_all_layouts_supported() - ): - lhs_is_rowwise = rhs_is_rowwise = True - else: - lhs_is_rowwise = not lhs_is_trans - rhs_is_rowwise = rhs_is_trans - quantizer_set.x.q_layout = ( - QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE - ) - quantizer_set.kernel.q_layout = ( - QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE - ) - active_group_sizes = next( - ( - gs - for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] - if gs.size > 0 - ), - empty_gs, - ) - lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data - rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data - lhs_q = grouped_quantize( - lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis - ) - rhs_q = grouped_quantize( - rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis - ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data - lhs_scale_inv = lhs_q.scale_inv - rhs_scale_inv = rhs_q.scale_inv - lhs_shape = lhs_q.original_shape - rhs_shape = rhs_q.original_shape + lhs, rhs = _quantize_inputs_if_needed( + lhs, rhs, quantizer_set, lhs_is_trans, rhs_is_trans, lhs_flatten_axis, rhs_flatten_axis + ) + + # Re-read scaling_mode after quantization: if _quantize_inputs_if_needed converted + # GroupedNoScaleTensor → GroupedScaledTensor1x, the original scaling_mode (NO_SCALING) + # would cause the C++ kernel to skip scale_inv setup, triggering a cuBLAS assertion. + _, scaling_mode = _get_out_dtype_and_scaling_mode(lhs) - if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: + if lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") - # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs - # thus additional transpose is required if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - lhs_layout_is_T = lhs.data_layout == "T" - rhs_layout_is_T = rhs.data_layout == "T" - else: - lhs_layout_is_T = lhs_q.data_layout == "T" - rhs_layout_is_T = rhs_q.data_layout == "T" - # we can't apply _shape_normalization on the grouped input - # thus we need to ensure that lhs is in N and rhs is in T - if lhs_is_trans != lhs_layout_is_T: - raise RuntimeError("lhs input must be transposed before calling grouped_gemm") - if (not rhs_is_trans) != rhs_layout_is_T: - raise RuntimeError("rhs input must be transposed before calling grouped_gemm") - lhs_is_trans = False - rhs_is_trans = True - lhs_ndim = len(lhs_shape) - rhs_ndim = len(rhs_shape) - if lhs_layout_is_T: - lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) - if rhs_layout_is_T: - # For rhs [G, K, N], need to exclude the G dim from contract_dim - if ( - lhs_first_dims.size > 0 or lhs_last_dims.size > 0 - ): # fwd/dgrad: rhs has G as first dim - rhs_contract_dim = tuple( - (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim - ) - else: - rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + lhs_is_trans, rhs_is_trans, lhs_contract_dim, rhs_contract_dim = ( + _adjust_contracting_dims_for_hopper_fp8_transpose( + lhs, rhs, lhs_contract_dim, rhs_contract_dim, lhs_is_trans, rhs_is_trans + ) + ) # Compute N-D axis boundaries from final (post-adjustment) contracting dims. lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) - num_gemms = ( - lhs_first_dims.size - or lhs_last_dims.size - or rhs_first_dims.size - or rhs_last_dims.size - or out_first_dims.size - or out_last_dims.size - ) + num_gemms = _get_num_gemms(lhs, rhs) if num_gemms == 0: raise ValueError( "grouped_gemm requires at least one non-empty dimension array. " @@ -2294,26 +2421,28 @@ def grouped_gemm( # Pre-compute collapsed 2D sizes from original N-D shapes. # These are static Python ints passed as primitive parameters (must be hashable). - lhs_left_size = math.prod(lhs_shape[:lhs_axis_boundary]) - lhs_right_size = math.prod(lhs_shape[lhs_axis_boundary:]) - rhs_left_size = math.prod(rhs_shape[:rhs_axis_boundary]) - rhs_right_size = math.prod(rhs_shape[rhs_axis_boundary:]) + lhs_left_size = math.prod(lhs.original_shape[:lhs_axis_boundary]) + lhs_right_size = math.prod(lhs.original_shape[lhs_axis_boundary:]) + rhs_left_size = math.prod(rhs.original_shape[:rhs_axis_boundary]) + rhs_right_size = math.prod(rhs.original_shape[rhs_axis_boundary:]) # Pre-compute output shape from N-D input shapes (static Python ints). if lhs_is_trans: - lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + lhs_non_contracting = lhs.original_shape[lhs_axis_boundary:] else: - lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + lhs_non_contracting = lhs.original_shape[:lhs_axis_boundary] if rhs_is_trans: - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + if rhs.first_dims is not None or rhs.last_dims is not None: # wgrad: rhs (e.g. grad_T of shape (N, M)) has no G batch dim; include all dims - rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary)) + rhs_non_contracting = tuple(rhs.original_shape[d] for d in range(rhs_axis_boundary)) else: # fwd/dgrad: rhs (e.g. kernel_T of shape (G, N, K)) has G batch dim at dim 0; skip it - rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) + rhs_non_contracting = tuple( + rhs.original_shape[d] for d in range(rhs_axis_boundary) if d != 0 + ) else: - rhs_non_contracting = rhs_shape[rhs_axis_boundary:] - if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + rhs_non_contracting = rhs.original_shape[rhs_axis_boundary:] + if rhs.first_dims is not None or rhs.last_dims is not None: out_shape = (num_gemms, *lhs_non_contracting, *rhs_non_contracting) else: out_shape = (*lhs_non_contracting, *rhs_non_contracting) @@ -2334,7 +2463,23 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi, _ = is_v2_grouped_gemm_supported( + scaling_mode, + lhs.data.dtype, + has_bias, + lhs_shape=lhs.original_shape, + rhs_shape=rhs.original_shape, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + ) + + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # Both V1 and V2 quantize produce pre-swizzled scales (V1 via + # set_with_gemm_swizzled_scales, V2 via nvte_group_quantize). Require that + # grouped_quantize has set pre_swizzled=True on the input tensors. + assert lhs.pre_swizzled, "lhs must be pre-swizzled for MXFP8 1D scaling" + assert rhs.pre_swizzled, "rhs must be pre-swizzled for MXFP8 1D scaling" + if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta @@ -2343,17 +2488,17 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, + lhs.data, + lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), + rhs.data, + rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), bias, - lhs_first_dims, - lhs_last_dims, - rhs_first_dims, - rhs_last_dims, - out_first_dims, - out_last_dims, + lhs.first_dims if lhs.first_dims is not None else empty_gs, + lhs.last_dims if lhs.last_dims is not None else empty_gs, + rhs.first_dims if rhs.first_dims is not None else empty_gs, + rhs.last_dims if rhs.last_dims is not None else empty_gs, + out_first_dims if out_first_dims is not None else empty_gs, + out_last_dims if out_last_dims is not None else empty_gs, additional_arg_0, additional_arg_1, lhs_is_trans=lhs_is_trans, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a3d363e42a..7138cfcf40 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -994,7 +994,8 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" + name = "te_grouped_quantize_ffi" # V1: fallback path (supports all shapes, not CUDA-graph safe) + name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( 3, @@ -1006,6 +1007,54 @@ class GroupedQuantizePrimitive(BasePrimitive): inner_primitive = None outer_primitive = None + @staticmethod + def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): + """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. + + V2 requires: + 1. SM100+ (Blackwell) — V2 grouped quantize fuses the scale_inv swizzle via + nvte_group_quantize. The swizzled scale_inv must then be consumed by the + V2 grouped GEMM, which also requires SM100+. Keeping both decisions tied + to SM100+ prevents a mismatch where V2-quantized (pre-swizzled) tensors + are passed to the V1 grouped GEMM (which would re-swizzle and corrupt). + 2. The total first logical dimension (product of x_shape up to flatten_axis) + is divisible by 128. + 3. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + per-group row count non_group_m = prod(x_shape[1:eff]) must also be + divisible by 128. + 4. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned — this is a dynamic constraint that cannot be checked here + because group sizes live on device. The caller is responsible for ensuring + this. + 5. The last logical dimension (contracting dim K or output dim N) must be + divisible by 128, matching the V2 grouped GEMM constraint so that the + two always agree on V1 vs V2. + + Falls back to V1 when constraints are not met. V1 supports arbitrary shapes + but performs a D2H copy of group_sizes (not CUDA-graph safe). + """ + if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + return False + # Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM. + if get_min_device_compute_capability() < 100: + return False + ndim = len(x_shape) + eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim + total_first_dim = math.prod(x_shape[:eff]) + if total_first_dim % 128 != 0: + return False + # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), + # non_group_m = K must also be 128-aligned. + if eff > 1: + non_group_m = math.prod(x_shape[1:eff]) + if non_group_m % 128 != 0: + return False + # Last dim must be 128-aligned to match the V2 grouped GEMM requirement. + last_dim = math.prod(x_shape[eff:]) + if last_dim % 128 != 0: + return False + return True + @staticmethod def abstract( x_aval, @@ -1048,7 +1097,20 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + updated_amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2 path: int64_workspace laid out as: + # [n_groups int64 group_sizes | n_groups+1 int64 offsets] + # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. + n_groups = group_sizes_aval.size + int64_workspace_aval = jax.core.ShapedArray( + shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 + ) + else: + # V1 path: Unused for V1 codepath + int64_workspace_aval = jax.core.ShapedArray(shape=(0,), dtype=jnp.uint8) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1068,7 +1130,8 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - amax_aval, + updated_amax_aval, + int64_workspace_aval, ) @staticmethod @@ -1078,13 +1141,20 @@ def outer_abstract(*args, **kwargs): """ # Phuong: keeping outer abstract so that we can add fuse dbias later ( - rowwise_out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, + rowwise_out_aval, + colwise_out_aval, + rowwise_scale_inv_aval, + colwise_scale_inv_aval, + updated_amax_aval, + _, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + return ( + rowwise_out_aval, + colwise_out_aval, + rowwise_scale_inv_aval, + colwise_scale_inv_aval, + updated_amax_aval, + ) @staticmethod def lowering( @@ -1107,6 +1177,21 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. + # Requires total_first_dim % 128 == 0 (checked above) and all individual + # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + ctx, + x, + scale, + group_sizes, + q_layout=q_layout.value.value, + flatten_axis=flatten_axis, + ) + # V1: supports arbitrary shapes but not CUDA-graph safe (performs D2H copy of group_sizes). + # Used for non-MXFP8 scaling modes and for MXFP8 when total_first_dim % 128 != 0. return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1138,6 +1223,7 @@ def impl( rowwise_scale_inv, colwise_scale_inv, updated_amax, + _, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1148,7 +1234,7 @@ def impl( flatten_axis=flatten_axis, scale_dtype=scale_dtype, ) - return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) + return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax register_primitive(GroupedQuantizePrimitive) @@ -1259,6 +1345,11 @@ def grouped_quantize( for i, quantizer_i in enumerate(quantizer.quantizers): quantizer_i.update(updated_amax[i].reshape((1,))) + # Both V1 (set_with_gemm_swizzled_scales) and V2 (nvte_group_quantize) produce + # pre-swizzled scale_inv tensors for use by the grouped GEMM kernel. Set + # pre_swizzled=True for all MXFP8 grouped quantization so that grouped_gemm can + # assert this invariant unconditionally. + is_mxfp8 = quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, @@ -1271,6 +1362,7 @@ def grouped_quantize( flatten_axis=flatten_axis, first_dims=ragged_first_dims, original_shape=original_shape, + pre_swizzled=is_mxfp8, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a74b209e4f..3ba0e7e9b2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeV2Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a7f16bb31f..6ca907032c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -481,6 +481,8 @@ class JAXX_GroupedTensorWrapper { m_grouped_tensor(other.m_grouped_tensor), m_data_tensor(other.m_data_tensor), m_scale_inv_tensor(other.m_scale_inv_tensor), + m_colwise_data_tensor(other.m_colwise_data_tensor), + m_colwise_scale_inv_tensor(other.m_colwise_scale_inv_tensor), m_sizes_tensor(other.m_sizes_tensor), m_offsets_tensor(other.m_offsets_tensor) { other.m_grouped_tensor = nullptr; @@ -489,6 +491,10 @@ class JAXX_GroupedTensorWrapper { ~JAXX_GroupedTensorWrapper(); void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_with_gemm_swizzled_scales(bool val); + void replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, NVTEDType sinv_dtype, + NVTEShape sinv_shape); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -505,6 +511,8 @@ class JAXX_GroupedTensorWrapper { // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. NVTEBasicTensor m_data_tensor{}; NVTEBasicTensor m_scale_inv_tensor{}; + NVTEBasicTensor m_colwise_data_tensor{}; + NVTEBasicTensor m_colwise_scale_inv_tensor{}; NVTEBasicTensor m_sizes_tensor{}; NVTEBasicTensor m_offsets_tensor{}; @@ -556,6 +564,58 @@ void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, } } +void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_colwise_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseData, + &m_colwise_data_tensor, sizeof(m_colwise_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM columnwise scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_colwise_scale_inv_tensor = + NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, + logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { + auto v = static_cast(val); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedWithGEMMSwizzledScales, &v, + sizeof(v)); +} + +void JAXX_GroupedTensorWrapper::replace_scale_inv(bool use_colwise, uint8_t *sinv_ptr, + NVTEDType sinv_dtype, NVTEShape sinv_shape) { + if (use_colwise) { + m_colwise_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } else { + m_scale_inv_tensor = NVTEBasicTensor{sinv_ptr, sinv_dtype, sinv_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor, sizeof(m_scale_inv_tensor)); + } +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -619,22 +679,19 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// V2 variant (NO_SCALING): derives data shape from the XLA buffer directly, converts group_sizes // int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked -// before each slot is used. Only NO_SCALING is supported. +// before each slot is used. Only NO_SCALING is supported by this overload. JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { auto dims = data.dimensions(); - NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - // Flatten dims at axis_boundary to produce a 2D NVTE shape. - // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, - // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). - size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -660,6 +717,56 @@ JAXX_GroupedTensorWrapper make_grouped_tensor( return wrapper; } +// V2 variant with scaling support (MXFP8 or NO_SCALING). Accepts scale_inv buffer and +// use_colwise flag to wire rowwise or columnwise data+scales for the grouped tensor. +// Pre-swizzled scales are indicated via set_with_gemm_swizzled_scales(true). +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, + bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, size_t left_size, size_t right_size) { + auto dims = data.dimensions(); + NVTE_CHECK(product(dims) == left_size * right_size, + "grouped GEMM data buffer element count does not match the provided 2D shape."); + NVTEShape dataShape{.data = {left_size, right_size}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + if (is_mxfp8 && use_colwise) { + wrapper.set_columnwise(data, scale_inv); + } else if (is_mxfp8) { + wrapper.set_rowwise(data, scale_inv); + } else { + // NO_SCALING: no scale_inv needed + wrapper.set_rowwise(data, std::nullopt); + } + if (is_mxfp8) { + wrapper.set_with_gemm_swizzled_scales(true); + } + + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} + // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, @@ -752,13 +859,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary, lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Only non-quantized grouped GEMM is supported in current implementation."); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING, + "Only NO_SCALING and MXFP8_1D_SCALING are supported in the V2 grouped GEMM."); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims, alpha); // Workspaces. + // V2 GEMM receives scale_inv already swizzled by nvte_group_quantize (V2 grouped quantize + // fuses the swizzle). No extra sinv reservation is needed; the full cublas_workspace is + // available for cuBLAS. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); @@ -783,14 +896,39 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; + + // For MXFP8: in JAX, rhs=cuBLAS_A, lhs=cuBLAS_B (swapped). + // Colwise is needed when the operand's contracting dim is NOT the last dim in its layout. + const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + + // For MXFP8: scale_inv is already swizzled (pre-swizzled by V2 grouped quantize via + // nvte_group_quantize). Pass the buffers directly to make_grouped_tensor which sets + // with_gemm_swizzled_scales(true) for MXFP8 automatically. No re-swizzling needed. auto rhs_tensor = - make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); + is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, + rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, rhs_left_size, rhs_right_size) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_left_size, rhs_right_size); auto lhs_tensor = - make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); - auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + is_mxfp8 + ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, + lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, lhs_left_size, lhs_right_size) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_left_size, lhs_right_size); + + // Output stays NO_SCALING. Derive 2D shape from the output buffer's own dims using + // last-dim-as-columns convention (equivalent to axis_boundary=-1 in the old API). + auto out_dims = output->dimensions(); + NVTE_CHECK(out_dims.size() > 0, "output buffer must have at least 1 dimension"); + size_t out_left_size = product(out_dims, 0, out_dims.size() - 1); + size_t out_right_size = static_cast(out_dims[out_dims.size() - 1]); + auto out_tensor = + make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, out_left_size, out_right_size); auto [avg_m, avg_k_lhs] = grouped_gemm_avg_dims( lhs_first_dims, lhs_last_dims, {lhs_left_size, lhs_right_size}, num_gemms, lhs_is_trans); @@ -943,20 +1081,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type const size_t tensor_scaling_sinv_aligment = 16; const size_t mxfp8_scaling_sinv_alignment_padding = 256; auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { + if (is_tensor_scaling) { // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); } workspace_size = workspace_size / num_streams; - auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned + auto lhs_scatter_aligned_ptr = workspace_ptr + workspace_size * num_streams; + lhs_scatter_aligned_ptr = move_ptr_to_next_256B_aligned(lhs_scatter_aligned_ptr); auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); @@ -1050,8 +1182,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; std::vector bias_wrapper_list; std::vector pre_gelu_wrapper_list; std::vector out_wrapper_list; @@ -1060,8 +1190,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM std::vector lhs_list; std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; @@ -1134,13 +1262,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). + // MXFP8 scales are pre-swizzled by the quantize kernel (both V1 and V2), + // so we pass them directly to the GEMM without a separate swizzle pass. // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); @@ -1149,32 +1272,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } lhs_i.set_with_gemm_swizzled_scales(true); if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } rhs_i.set_with_gemm_swizzled_scales(true); - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } } else { NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -1192,10 +1300,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; lhs_sinv_total_size += lhs_sinv_size_i; rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } } if (has_bias) bias_ptr += n * bias_dtype_bytes; @@ -1236,18 +1340,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..e3bc122403 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -33,6 +33,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); + dict["te_grouped_quantize_v2_ffi"] = EncapsulateFFI(GroupedQuantizeV2Handler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c5a766f7f2..650139a61c 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -9,6 +9,7 @@ #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/gemm.h" #include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" @@ -318,8 +319,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_scale_invs, Result_Type amaxs, - JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, - int64_t flatten_axis) { + Result_Type _unused, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -451,6 +452,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty } } + // For MXFP8, produce pre-swizzled scales so the GEMM can consume them directly + // without a separate swizzle pass. + if (is_mxfp8_scaling) { + out_i.set_with_gemm_swizzled_scales(true); + } + input_holders.push_back(std::move(inp_i)); output_holders.push_back(std::move(out_i)); @@ -479,20 +486,154 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, +XLA_FFI_DEFINE_HANDLER_SYMBOL( + GroupedQuantizeHandler, GroupedQuantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // group_sizes + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // unused (for compatibility with V2 interface) + .Attr("scaling_mode") + .Attr("q_layout") + .Attr("flatten_axis")); + +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, + Buffer_Type group_sizes, Result_Type rowwise_out, + Result_Type colwise_out, Result_Type rowwise_sinv, + Result_Type colwise_sinv, Result_Type updated_amaxs, + Result_Type int64_workspace, JAXX_Quantize_Layout quantize_layout, + int64_t flatten_axis) { + (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(rowwise_sinv->element_type()); + + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for GroupedQuantizeV2."); + NVTE_CHECK(sinv_dtype == DType::kFloat8E8M0, + "scale_inv must be E8M0 for MXFP8 grouped quantize."); + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + size_t n_groups = group_sizes.dimensions()[0]; + + // Workspace layout (CUDA-graph safe, all device-side): + // int64_ptr[0 .. n_groups-1] : per-group ROW counts (int64) + // int64_ptr[n_groups .. 2*n_groups] : exclusive prefix-sum offsets (n_groups+1 values) + auto *int64_ptr = reinterpret_cast(int64_workspace->untyped_data()); + auto *offsets_ptr_out = int64_ptr + n_groups; // n_groups+1 values follow group_sizes + + // non_group_m handles multi-dim tensors (e.g., kernel shape G×K×N with flatten_axis=2): + // group_sizes[i] counts "slices" along the outermost group axis (e.g., 1 per expert), + // while the kernel expects actual ROW counts (e.g., K rows per expert). + // non_group_m = product(input_dims[1..flatten_axis)) converts slice→row count. + // For the lhs case (shape M×K, flatten_axis=1), non_group_m=1 (no-op). + int64_t non_group_m = + (flatten_axis > 1) ? product(input_dims, 1, static_cast(flatten_axis)) : 1; + + // Convert int32 group_sizes to int64 row counts on device (CUDA-graph safe, no D2H). + nvte_convert_int32_to_int64_with_multiplier( + reinterpret_cast(group_sizes.untyped_data()), int64_ptr, n_groups, + non_group_m, stream); + + // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, static_cast(n), + stream); + + NVTEShape data_shape{}; + data_shape.data[0] = m; + data_shape.data[1] = n; + data_shape.ndim = 2; + + NVTEShape sz_shape{}; + sz_shape.ndim = 1; + sz_shape.data[0] = n_groups; + + // Offsets tensor has n_groups+1 elements (exclusive prefix sums with sentinel). + NVTEShape offsets_shape{}; + offsets_shape.ndim = 1; + offsets_shape.data[0] = n_groups + 1; + + // Build input grouped tensor (plain float data, no quantization on the input side). + GroupedTensorWrapper in_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + in_grouped + .set_rowwise_data(reinterpret_cast(inputs.untyped_data()), in_dtype, data_shape) + .set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); + + // Build output grouped tensor. + GroupedTensorWrapper out_grouped(n_groups, data_shape, + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING)); + out_grouped.set_first_dims(reinterpret_cast(int64_ptr), DType::kInt64, sz_shape) + .set_tensor_offsets(reinterpret_cast(offsets_ptr_out), DType::kInt64, offsets_shape); + + // Rowwise output data + scale_inv. + if (is_quantize_rowwise(quantize_layout)) { + NVTEShape rw_sinv_shape{}; + rw_sinv_shape.ndim = 2; + rw_sinv_shape.data[0] = m; + rw_sinv_shape.data[1] = n / 32; // MXFP8 block size = 32 + out_grouped.set_rowwise_data(rowwise_out->untyped_data(), out_dtype, data_shape) + .set_rowwise_scale_inv(rowwise_sinv->untyped_data(), sinv_dtype, rw_sinv_shape); + } + + // Colwise output data + scale_inv. + if (is_quantize_colwise(quantize_layout)) { + NVTEShape cw_sinv_shape{}; + cw_sinv_shape.ndim = 2; + cw_sinv_shape.data[0] = m / 32; // MXFP8 block size = 32 + cw_sinv_shape.data[1] = n; + out_grouped.set_columnwise_data(colwise_out->untyped_data(), out_dtype, data_shape) + .set_columnwise_scale_inv(colwise_sinv->untyped_data(), sinv_dtype, cw_sinv_shape); + } + + // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). + size_t total_rowwise_sinv_size = + is_quantize_rowwise(quantize_layout) ? product(rowwise_sinv->dimensions()) : 0; + size_t total_colwise_sinv_size = + is_quantize_colwise(quantize_layout) ? product(colwise_sinv->dimensions()) : 0; + if (total_rowwise_sinv_size > 0) + nvte_memset(rowwise_sinv->untyped_data(), 0, total_rowwise_sinv_size, stream); + if (total_colwise_sinv_size > 0) + nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + + // V2 grouped quantize is always paired with V2 grouped GEMM, which expects + // scale_inv in GEMM-swizzled layout. Enable the fused swizzle so the kernel + // writes scales in the layout the GEMM will consume directly. + out_grouped.set_with_gemm_swizzled_scales(true); + + QuantizationConfigWrapper quant_config{}; + nvte_group_quantize(in_grouped.data(), out_grouped.data(), quant_config, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // input - .Arg() // scale - .Arg() // group_sizes - .Ret() // output - .Ret() // colwise output - .Ret() // scale_inv - .Ret() // scale_inv colwise - .Ret() // amax - .Attr("scaling_mode") + .Arg() // inputs + .Arg() // scale (unused, for input arity match) + .Arg() // group_sizes (int32) + .Ret() // rowwise_out + .Ret() // colwise_out + .Ret() // rowwise_sinv + .Ret() // colwise_sinv + .Ret() // updated_amaxs + .Ret() // int64_workspace .Attr("q_layout") - .Attr("flatten_axis")); + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 31ce6e72e9..17c9a242f0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,6 +16,9 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, +) from ..dense import dense, grouped_dense @@ -1358,7 +1361,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None): +def wrap_function_in_te_state_module( + f, + quantization_recipe, + name: Optional[str] = None, + quantization_checkpoint_name: Optional[str] = None, +): """Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1386,6 +1394,7 @@ def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name=quantization_checkpoint_name, fp8_recipe=quantization_recipe, n_groups=n_groups, ) @@ -1443,10 +1452,15 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") -def make_grouped_dense_cls(quantization_recipe): +def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Optional[str] = None): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: - raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + allowed_grouped_gemm_recipes = [MXFP8BlockScaling] + assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( + "Only the following quantization recipes are supported for grouped GEMM or `None` for" + f" BF16 without quantization: {allowed_grouped_gemm_recipes}. Got" + f" {type(quantization_recipe)}." + ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused @@ -1463,5 +1477,8 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot" + te_grouped_dot_general, + quantization_recipe, + "ragged_dot", + quantization_checkpoint_name=quantization_checkpoint_name, )() diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 5abb2e74df..ca44c2e4af 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -263,7 +263,37 @@ def dequantize(scaled_tensor): } -@staticmethod +def _unswizzle_mxfp8_grouped_scale(scale_inv_flat, padded_scale_2d, is_colwise): + """Un-swizzle MXFP8 GEMM-swizzled scale_inv back to plain layout. + + Both V1 and V2 MXFP8 grouped quantize produce scale_inv in a GEMM-swizzled + layout. This is the inverse of ``swizzled_scale`` in ``gemm.py``. + + The swizzle pattern (for rowwise) is: + reshape(R//128, 4, 32, C//4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + The inverse is: + reshape(R//128, C//4, 32, 4, 4) → transpose(0,3,2,1,4) → reshape(R, C) + + For colwise the swizzle is applied to the transposed scale, so the inverse + must un-transpose as well. + """ + if is_colwise: + # Colwise forward: reshape_2d → transpose → swizzle_5d → reshape_original + # Inverse: reshape_to_5d → inverse_swizzle → reshape_to_transposed_2d → transpose + cols, rows = padded_scale_2d + scale_2d = scale_inv_flat.reshape(cols, rows) + # The swizzled data lives in the transposed (rows, cols) domain + reshaped = scale_2d.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + # Back to transposed 2D, then un-transpose + return jnp.transpose(unswizzled.reshape(rows, cols)) + + rows, cols = padded_scale_2d + reshaped = scale_inv_flat.reshape(rows // 128, cols // 4, 32, 4, 4) + unswizzled = jnp.transpose(reshaped, (0, 3, 2, 1, 4)) + return unswizzled.reshape(rows, cols) + + def _grouped_dequantize(grouped_scaled_tensor): """Dequantize a grouped tensor. @@ -290,12 +320,13 @@ def _grouped_dequantize(grouped_scaled_tensor): flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] - # For transposed (colwise) tensors with ragged groups, the group dimension is the last - # axis of original_shape (e.g. original_shape = (N, M) with groups along M), while the - # non-group dimensions are all axes before it. For the uniform-groups case the group - # dimension stays at axis 0, so the existing axis-0 logic applies. + # When data_layout=="T" (colwise, transposed) and first_dims is set (ragged groups), the + # original_shape is stored transposed: the group (variable-size) axis is the LAST dimension + # rather than the first. Non-group dims are original_shape[:-1], not original_shape[1:]. is_transposed_ragged = ( - grouped_scaled_tensor.data_layout == "T" and group_sizes.size != original_shape[0] + grouped_scaled_tensor.data_layout == "T" + and grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 ) if is_transposed_ragged: non_group_shape = original_shape[:-1] @@ -308,7 +339,7 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv_ptr = 0 for i, data_i in enumerate(data): if is_transposed_ragged: - data_shape_i = (*non_group_shape, group_sizes[i]) + data_shape_i = (*non_group_shape, int(group_sizes[i])) else: data_shape_i = ( group_sizes[i], @@ -330,24 +361,49 @@ def _grouped_dequantize(grouped_scaled_tensor): is_padded=False, flatten_axis=flatten_axis, ) - scale_inv_i = scale_inv[ - scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i) - ].reshape(padded_scale_shape_i) - scale_inv_i = jax.lax.slice( - scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i - ) + scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)] + # MXFP8 grouped quantize (both V1 and V2) always produces GEMM-swizzled + # scales. Detect by scaling_mode (not pre_swizzled, which is only set for V2 + # to maintain pytree compatibility with the GEMM path). + is_colwise = grouped_scaled_tensor.is_colwise + needs_unswizzle = scaling_mode == ScalingMode.MXFP8_1D_SCALING + if needs_unswizzle: + flat_data_2d = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) + padded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=True, flatten_axis=1 + ) + unpadded_2d = scaling_mode.get_scale_shape( + flat_data_2d, is_colwise=is_colwise, is_padded=False, flatten_axis=1 + ) + scale_inv_i = _unswizzle_mxfp8_grouped_scale(scale_inv_i, padded_2d, is_colwise) + scale_inv_i = jax.lax.slice(scale_inv_i, [0, 0], list(unpadded_2d)) + else: + scale_inv_i = scale_inv_i.reshape(padded_scale_shape_i) + scale_inv_i = jax.lax.slice( + scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i + ) dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) if len(data_i) == 0: out_i = [] else: + # _dequantize_func is designed for 2D-flattened data. Flatten the + # per-group shape to 2D, dequantize, then reshape back. + flat_shape_i = ( + math.prod(data_shape_i[:flatten_axis]), + math.prod(data_shape_i[flatten_axis:]), + ) out_i = dequantizer_type._dequantize_func( - data_i.reshape(data_shape_i), + data_i.reshape(flat_shape_i), scale_inv_i, grouped_scaled_tensor.dq_dtype, scaling_mode=grouped_scaled_tensor.scaling_mode, is_colwise=grouped_scaled_tensor.is_colwise, - flatten_axis=grouped_scaled_tensor.flatten_axis, + flatten_axis=1, ) + out_i = out_i.reshape(data_shape_i) output.append(out_i) scale_inv_ptr += math.prod(padded_scale_shape_i) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index b1f49dacdc..c5ad0451fd 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -369,11 +369,15 @@ class GroupedScaledTensor1x(ScaledTensor1x): first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping + pre_swizzled: Whether the scale_inv is already swizzled for GEMM. True when produced + by V2 grouped quantize (nvte_group_quantize fuses the swizzle). The V2 grouped + GEMM FFI requires pre_swizzled=True for MXFP8 inputs and will not re-swizzle. """ first_dims: Optional[jnp.ndarray] last_dims: Optional[jnp.ndarray] original_shape: Tuple + pre_swizzled: bool = False def __init__( self, @@ -389,11 +393,13 @@ def __init__( data_layout, flatten_axis, original_shape, + pre_swizzled=False, ): self.flatten_axis = flatten_axis self.first_dims = first_dims self.last_dims = last_dims self.original_shape = original_shape + self.pre_swizzled = pre_swizzled # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, @@ -408,6 +414,18 @@ def __init__( has_rht_applied=False, ) + @property + def group_sizes(self) -> jnp.ndarray: + """Per-group sizes along the group axis. + + When first_dims is set (ragged groups), returns first_dims. + When first_dims is None (equal-sized groups), returns an array of ones with + length equal to the number of groups. + """ + if self.first_dims is not None and self.first_dims.size > 0: + return self.first_dims + return jnp.ones((self.original_shape[0],), dtype=jnp.int32) + def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" @@ -456,6 +474,7 @@ def tree_flatten(self): self.data_layout, self.flatten_axis, self.original_shape, + self.pre_swizzled, ) return (children, aux_data) @@ -653,6 +672,7 @@ def create_1x( last_dims=None, original_shape=None, has_rht_applied=False, + pre_swizzled=False, ): """Creates a single-scale quantized tensor. @@ -722,6 +742,7 @@ def create_1x( first_dims=first_dims, last_dims=last_dims, original_shape=original_shape, + pre_swizzled=pre_swizzled, ) # Handling attrs of transposed tensors @@ -759,6 +780,7 @@ def create_2x( original_shape=None, rowwise_has_rht_applied=False, colwise_has_rht_applied=False, + pre_swizzled=False, ): """Creates a double-scale quantized tensor. @@ -800,6 +822,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, @@ -814,6 +837,7 @@ def create_2x( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -835,6 +859,7 @@ def create( original_shape: Tuple[int] = None, rowwise_has_rht_applied: bool = False, colwise_has_rht_applied: bool = False, + pre_swizzled: bool = False, ): """Creates a scaled tensor based on the quantization axis. @@ -853,6 +878,7 @@ def create( original_shape: The original shape of the tensor before grouping (default: None) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + pre_swizzled: Whether scale_inv is already swizzled (produced by V2 grouped quantize). Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout @@ -876,6 +902,7 @@ def create( original_shape=original_shape, rowwise_has_rht_applied=rowwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) if q_layout.is_colwise_only: @@ -892,6 +919,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=colwise_has_rht_applied, + pre_swizzled=pre_swizzled, ) return ScaledTensorFactory.create_1x( @@ -907,6 +935,7 @@ def create( last_dims=last_dims, original_shape=original_shape, has_rht_applied=rowwise_has_rht_applied, + pre_swizzled=pre_swizzled, )