Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions tests/cpp/operator/test_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ void run_grouped_gemm_case(const TestParams& params) {
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
const int compute_capability = getDeviceComputeCapability();
if (compute_capability < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
if (compute_capability == 120) {
GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120.";
}

const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);

Expand Down Expand Up @@ -356,9 +360,13 @@ void run_grouped_gemm_discrete_out_case(const TestParams& params) {
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
const int compute_capability = getDeviceComputeCapability();
if (compute_capability < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
if (compute_capability == 120) {
GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120.";
}

const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);

Expand Down Expand Up @@ -527,9 +535,13 @@ void run_grouped_gemm_discrete_in_case(const TestParams& params) {
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
const int compute_capability = getDeviceComputeCapability();
if (compute_capability < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
if (compute_capability == 120) {
GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120.";
}

const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);

Expand Down
40 changes: 36 additions & 4 deletions tests/pytorch/debug/run_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@
fp8_available = is_fp8_available()


def _cmp_dist(ground_truth, output, parallel_mode):
if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0):
# SM120: distributed column-parallel path may show a single-element
# activation outlier slightly above default fp32 atol, while grads match.
torch.testing.assert_close(
ground_truth["activation"], output["activation"], atol=1.2e-5, rtol=1.3e-6
)
torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"])
torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"])
else:
_cmp(ground_truth, output)


def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
tp_size = WORLD_SIZE
Expand Down Expand Up @@ -445,7 +458,16 @@ def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwa

x.grad.zero_()
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0):
# SM120: distributed column-parallel path may show a single-element
# activation outlier slightly above default fp32 atol, while grads match.
torch.testing.assert_close(
ground_truth["activation"], output["activation"], atol=1.2e-5, rtol=1.3e-6
)
torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"])
torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"])
else:
_cmp_dist(ground_truth, output, parallel_mode)


@run_debug_test
Expand All @@ -466,7 +488,17 @@ def test_disable_fp8_layer(parallel_mode, **kwargs):
y = _run_forward_backward(x, model, parallel_mode)

output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
_cmp(ground_truth, output)
if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0):
# SM120: distributed column-parallel path may show a single-element
# activation outlier slightly above default fp32 atol, while grads match.
# Allow for new atol/rtol values (on SM120) = 1.2e-5, 1.3e-6 instead of 1e-5, 1e-6
torch.testing.assert_close(
ground_truth["activation"], output["activation"], atol=1.2e-5, rtol=1.3e-6
)
torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"])
torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"])
else:
_cmp_dist(ground_truth, output, parallel_mode)


@run_debug_test
Expand Down Expand Up @@ -554,7 +586,7 @@ def test_per_tensor_scaling(
x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
)

_cmp(ground_truth, output)
_cmp_dist(ground_truth, output, parallel_mode)


@run_debug_test
Expand Down Expand Up @@ -617,7 +649,7 @@ def test_fake_quant_fp8(
_get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None
)
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
_cmp_dist(ground_truth, output, parallel_mode)


def _init_distributed():
Expand Down
22 changes: 17 additions & 5 deletions tests/pytorch/test_custom_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,17 @@ def test_custom_recipe_grouped_linear_sanity():
num_gemms = 3
in_features = 64
out_features = 64
batch = 32
base = batch // num_gemms
rem = batch % num_gemms
m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]
# Use 16-aligned splits on SM120 to satisfy FP8 GEMM leading-dimension requirements in backward.
is_sm120 = torch.cuda.get_device_capability() == (12, 0)
if is_sm120:
split_m = 16
batch = num_gemms * split_m
m_splits = [split_m] * num_gemms
else:
batch = 32
base = batch // num_gemms
rem = batch % num_gemms
m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]

model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
Expand Down Expand Up @@ -272,7 +279,12 @@ def test_custom_recipe_factory_invocation_counts_and_cycling():

in_features = 64
out_features = 64
batch = 8
# Use single-aligned batch on SM120 only.
is_sm120 = torch.cuda.get_device_capability() == (12, 0)
if is_sm120:
batch = 16
else:
batch = 8

op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3238,7 +3238,12 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
# SM120 NVFP4 can show a tiny outlier for single element in this bias grad entry while all
# other checks stay within the existing loose sanity tolerances.
b1_tols = tols
if quantization == "nvfp4" and torch.cuda.get_device_capability() == (12, 0):
b1_tols = {"rtol": tols["rtol"], "atol": 0.55}
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **b1_tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)

@pytest.mark.parametrize("bias", (False, True))
Expand Down
13 changes: 13 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
nvfp4_available = is_nvfp4_available()

sm_80plus = get_device_compute_capability() >= (8, 0)
sm_120 = get_device_compute_capability() == (12, 0)

seed = 1234
# Reset RNG states.
Expand Down Expand Up @@ -2703,9 +2704,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
max_seqlen_kv=config.max_seqlen_kv,
)

tols = dtype_tols(dtype)
if sm_120:
# sm120 FusedAttention does not support T3HD/TH3D layouts, so for T3HD/TH3D, the test falls back to using Flash Attn backend
# whereas for BSHD/SBHD, the test uses FusedAttention backend by default. Hence, relaxing the atol tolerance for T3HD/TH3D.
tols["atol"] = max(tols["atol"], 4e-3)
torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
**tols,
)


Expand Down Expand Up @@ -2865,6 +2872,8 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> No
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if torch.cuda.get_device_capability() == (12, 0):
pytest.skip("Grouped GEMM is currently unsupported on SM120.")
if not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")

Expand Down Expand Up @@ -3019,6 +3028,8 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -
"""
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if torch.cuda.get_device_capability() == (12, 0):
pytest.skip("Grouped GEMM is currently unsupported on SM120.")
if not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")
if quant_type == "mxfp8" and not mxfp8_available:
Expand Down Expand Up @@ -3161,6 +3172,8 @@ def test_grouped_gemm_grouped_tensor_mxfp8(
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if torch.cuda.get_device_capability() == (12, 0):
pytest.skip("Grouped GEMM is currently unsupported on SM120.")
if dtype == torch.bfloat16 and not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")

Expand Down
12 changes: 10 additions & 2 deletions transformer_engine/common/cast/dispatch/gated.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp

switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capapbilities are much smaller than sm100, so we disable TMA kernels on sm120
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
// are there any forward only tests we'd like to keep enabled on sm120?
const bool use_tma_kernels =
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
if (use_tma_kernels) {
Tensor dummy_grad_tensor;
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
Expand Down Expand Up @@ -137,7 +142,10 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte

switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capapbilities are much smaller than sm100, so we disable TMA kernels on sm120
const bool use_tma_kernels =
(cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
if (use_tma_kernels) {
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(gated_input, grad, output, p,
stream);
Expand Down
5 changes: 5 additions & 0 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
}

// Ensure async shared->global copy is done reading shared source before reuse.
ptx::cp_async_bulk_wait_group_read<0>();
// Ensure all warps reach the reuse boundary before DBIAS scratch writes.
__syncthreads();

parity ^= 1;

if constexpr (IS_DBIAS) {
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/common/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ bool is_supported_by_CC_100() {
return deviceComputeCapability >= 100;
}

// KL: test function for CC 120
bool is_supported_by_CC_120() {
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

return deviceComputeCapability == 120;
}

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size) {
std::vector<std::vector<Tensor *>> ret;
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,8 @@ void create_2D_tensor_map(

bool is_supported_by_CC_100();

bool is_supported_by_CC_120();

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);

Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {

inline void check_grouped_gemm_requirements(const char *api_name) {
const int current_device = transformer_engine::cuda::current_device();
NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100, api_name,
" requires Blackwell (SM100) or newer architecture.");
const int sm_arch = transformer_engine::cuda::sm_arch(current_device);
NVTE_CHECK(sm_arch >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(sm_arch != 120, api_name,
" is currently unsupported on SM120. Grouped cuBLASLt GEMM heuristic selection "
"returns CUBLAS_STATUS_NOT_SUPPORTED on this architecture (even with relaxed hints)");
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
" requires cuBLAS 13.3+, but run-time cuBLAS version is ",
transformer_engine::cuda::cublas_version());
Expand Down
45 changes: 39 additions & 6 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob

namespace {

inline bool is_sm120_device() {
cudaDeviceProp device_prop{};
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&device_prop, c10::cuda::current_device()));
return device_prop.major == 12 && device_prop.minor == 0;
}

// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
GroupedTensorWrapper &grouped_output_tensor,
Expand Down Expand Up @@ -1019,6 +1025,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
cudaStream_t stream) {
const size_t num_tensors = split_sections.size();
const auto &quantizer = *quantizers.front();
const bool sm120_device = is_sm120_device();

std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
Expand All @@ -1031,6 +1038,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
bool all_aligned_token_dim =
std::all_of(split_sections.begin(), split_sections.end(),
[](size_t split_section) { return split_section % 128 == 0; });
// SM120 fallback: avoid the fully fused grouped row+col RHT kernel path.
all_aligned_token_dim = all_aligned_token_dim && !sm120_device;

// in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice
// so that rowwise and colwise will have different random numbers
Expand All @@ -1049,7 +1058,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
bool with_bulk_generate_rng_states = true;

// Stochastic rounding
bool need_stochastic_rounding = quantizer.stochastic_rounding;
bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device;
auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper(
num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states,
need_separate_rng_states, quant_config_list, quant_config_list_colwise);
Expand Down Expand Up @@ -1138,6 +1147,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
if (quantizer.columnwise_usage) {
std::vector<TensorWrapper> out_transpose_list;
std::vector<NVTETensor> nvte_tensor_out_transpose_list;
std::vector<at::Tensor> rht_output_t_tensors;
rht_output_t_tensors.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; i++) {
bool is_empty_split = input_list[i].numel() == 0;
auto out_columnwise_data = output_list[i].get_columnwise_data();
Expand Down Expand Up @@ -1169,10 +1180,31 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
out_transpose_list.emplace_back(std::move(out_transpose));
nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data());
}
nvte_group_hadamard_transform_cast_fusion_columnwise(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_out_transpose_list.data()),
rht_matrix_nvte.data(), split_sections.data(), num_tensors,
quant_config_list_colwise_to_use[0], stream);
if (sm120_device) {
// SM120 fallback: avoid grouped columnwise RHT fusion path and run unfused per split.
for (size_t i = 0; i < num_tensors; i++) {
if (input_list[i].numel() == 0) {
continue;
}
const int rows = static_cast<int>(split_sections[i]);
const int cols = static_cast<int>(input_list[i].size(input_list[i].ndim() - 1));
auto rht_output_t = allocateTorchTensor(cols, rows, input_list[i].dtype());
rht_output_t_tensors.push_back(rht_output_t);
TensorWrapper rht_output_t_cpp;
rht_output_t_cpp.set_rowwise_data(
rht_output_t.data_ptr(), input_list[i].dtype(),
std::vector<size_t>{static_cast<size_t>(cols), static_cast<size_t>(rows)});
nvte_hadamard_transform(input_list[i].data(), rht_output_t_cpp.data(), 0,
quantizer.rht_matrix_random_sign_mask_t, stream);
nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose_list[i].data(),
quant_config_list_colwise_to_use[i], stream);
}
} else {
nvte_group_hadamard_transform_cast_fusion_columnwise(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_out_transpose_list.data()),
rht_matrix_nvte.data(), split_sections.data(), num_tensors,
quant_config_list_colwise_to_use[0], stream);
}
}
}
}
Expand All @@ -1185,6 +1217,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input,
cudaStream_t stream) {
const size_t num_tensors = input_list.size();
const auto &quantizer = *quantizers.front();
const bool sm120_device = is_sm120_device();

std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
Expand All @@ -1207,7 +1240,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input,
// so that we can generate all rng states at once
bool with_bulk_generate_rng_states = false;

bool need_stochastic_rounding = quantizer.stochastic_rounding;
bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device;

// place holder for colwise rng states, which are not needed in this case
std::vector<QuantizationConfigWrapper> dummy_quant_config_list_colwise;
Expand Down
Loading