Skip to content

Commit ef703e5

Browse files
jberchtold-nvidia, pre-commit-ci[bot], vthumbe1503, greptile-apps[bot], pggPL
authored
[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes (NVIDIA#2748)
* MXFP8 grouped GEMM + tensor-scaled FP8 fixes Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Change version to 13.3 Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Random padding condition shouldn't be done for mxfp8 Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Remove incorrect comment Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * CUBLAS > 13.2 is enough Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * CUBLAS version needed for MXFP8 indeed seems to be 13.3 Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Accidental line removal added back. Plus need changes to trigger CI. Add documentation for scaling factors in common.h Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Update cuBLAS version requirement for MXFP8 support Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * grouped gemm: address code review comments - Replace nvte_set/get_grouped_tensor_swizzled_scales with nvte_set_grouped_tensor_param - Add host-side validation: A and B must use same scaling mode (both MXFP8 or both tensor scaling) - Add host-side validation: A and B must both be FP8 or both non-FP8; restrict inputs to FP8/BF16 - Restrict output (C/D) to BF16/FP32; remove FP16 from supported types - Refactor workspace allocation: replace manual offset arithmetic with moving pointer pattern - Use void* + NVTEScalingMode in setup kernel instead of separate float*/char* scale params - Extract use_columnwise(swap_dims) helper to eliminate duplicated MXFP8 columnwise blocks - Split set_fp8_scale_pointers into set_fp8_scale_pointers / set_mxfp8_scale_pointers - Remove scale_inv_ptrs from GroupedOperandSelection; pass workspace pointers directly - Move swizzled-scales validation into validate_grouped_gemm_inputs for fail-fast behavior - Add use_split_accumulator to GroupedMatmulConfig (Hopper 
only, default false) - Add FP8 test case with per-tensor scales; add BF16/MXFP8 shape-varying test cases Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent 06a23e3 commit ef703e5

7 files changed

Lines changed: 413 additions & 88 deletions

File tree

tests/cpp/operator/test_grouped_gemm.cu

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <transformer_engine/cast.h>
2121
#include <transformer_engine/gemm.h>
2222
#include <transformer_engine/recipe.h>
23+
#include <transformer_engine/swizzle.h>
2324
#include <transformer_engine/transformer_engine.h>
2425

2526
#include "../test_common.h"
@@ -32,6 +33,7 @@ namespace {
3233
enum class InputCase {
3334
kFP8Current,
3435
kBF16,
36+
kMXFP8,
3537
};
3638

3739
enum class ShapeCase {
@@ -44,16 +46,29 @@ enum class ShapeCase {
4446
size_t grouped_setup_workspace_size(const size_t num_tensors) {
4547
const size_t ptr_bytes = num_tensors * sizeof(void*);
4648
const size_t int_bytes = num_tensors * sizeof(int);
47-
// Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
48-
size_t size = 6 * ptr_bytes + 6 * int_bytes;
49+
// Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays
50+
size_t size = 8 * ptr_bytes + 6 * int_bytes;
4951
const size_t alignment = 256;
5052
size = ((size + alignment - 1) / alignment) * alignment;
5153
return size;
5254
}
5355

5456
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
5557
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
56-
fillUniform(&input_fp32);
58+
59+
const size_t numel = shape[0] * shape[1];
60+
std::vector<float> data(numel);
61+
std::mt19937 gen(std::hash<std::string>{}(name));
62+
// Random mean and stddev -> different amax per tensor -> different scales
63+
std::uniform_real_distribution<float> param_dis(0.1f, 10.0f);
64+
float mean = param_dis(gen);
65+
float stddev = param_dis(gen);
66+
std::normal_distribution<float> dis(mean, stddev);
67+
for (size_t i = 0; i < numel; ++i) {
68+
data[i] = dis(gen);
69+
}
70+
NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(),
71+
numel * sizeof(float), cudaMemcpyHostToDevice));
5772

5873
Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
5974

@@ -73,6 +88,64 @@ Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& sha
7388
return t;
7489
}
7590

91+
92+
// Creates an MXFP8 operand with the correct data layout for GEMM.
93+
// MXFP8 GEMM requirements (scales are along K dimension):
94+
// A transposed -> needs rowwise data/scales
95+
// A non-transposed -> needs columnwise data/scales
96+
// B transposed -> needs columnwise data/scales
97+
// B non-transposed -> needs rowwise data/scales
98+
Tensor make_mxfp8_operand(const std::string& name, const std::vector<size_t>& shape,
99+
bool is_A, bool transposed) {
100+
// Determine which data layout we need
101+
bool use_rowwise, use_colwise;
102+
if (is_A) {
103+
// A: transposed -> rowwise, non-transposed -> columnwise
104+
use_rowwise = transposed;
105+
use_colwise = !transposed;
106+
} else {
107+
// B: transposed -> columnwise, non-transposed -> rowwise (opposite of A!)
108+
use_rowwise = !transposed;
109+
use_colwise = transposed;
110+
}
111+
112+
// Create BF16 input with random data
113+
Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16);
114+
fillUniform(&input_bf16);
115+
116+
// Create MXFP8 tensor with only the required data layout
117+
Tensor mxfp8(name, shape, TypeInfo<fp8e4m3>::dtype, use_rowwise, use_colwise,
118+
NVTE_MXFP8_1D_SCALING);
119+
120+
// Quantize BF16 -> MXFP8
121+
nvte_quantize(input_bf16.data(), mxfp8.data(), 0);
122+
123+
// Create output tensor for swizzled scales (same data shape, same layout)
124+
Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo<fp8e4m3>::dtype,
125+
use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING);
126+
mxfp8_swizzled.set_with_gemm_swizzled_scales(true); // Must be set BEFORE swizzle call
127+
128+
// Copy quantized data from mxfp8 to mxfp8_swizzled
129+
if (use_rowwise) {
130+
size_t data_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype());
131+
NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(),
132+
data_bytes, cudaMemcpyDeviceToDevice));
133+
}
134+
if (use_colwise) {
135+
size_t data_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype());
136+
NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(),
137+
data_bytes, cudaMemcpyDeviceToDevice));
138+
}
139+
140+
// Swizzle scales for GEMM
141+
nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0);
142+
143+
// Sync to ensure operations are complete
144+
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
145+
146+
return mxfp8_swizzled;
147+
}
148+
76149
struct TestParams {
77150
InputCase input_case;
78151
bool transa;
@@ -88,16 +161,16 @@ struct TestParams {
88161
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
89162
switch (scase) {
90163
case ShapeCase::kAllSame:
91-
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
164+
return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
92165
case ShapeCase::kSameFirst:
93166
// Same M (first dim), varying N and K
94-
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
167+
return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
95168
case ShapeCase::kSameLast:
96169
// Same N (last dim), varying M and K
97-
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
170+
return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
98171
case ShapeCase::kAllDifferent:
99172
default:
100-
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
173+
return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
101174
}
102175
}
103176

@@ -138,6 +211,13 @@ void run_grouped_gemm_case(const TestParams& params) {
138211
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
139212
break;
140213
}
214+
case InputCase::kMXFP8: {
215+
A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape,
216+
/*is_A=*/true, params.transa));
217+
B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape,
218+
/*is_A=*/false, params.transb));
219+
break;
220+
}
141221
}
142222
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
143223
std::vector<size_t>{M, N},
@@ -246,7 +326,9 @@ void run_grouped_gemm_case(const TestParams& params) {
246326
cublas_ws.data(),
247327
nullptr, // config (use defaults)
248328
0);
329+
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
249330

331+
// Compare results
250332
for (size_t i = 0; i < num_gemms; ++i) {
251333
Tensor grouped_split("grouped_D" + std::to_string(i),
252334
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
@@ -277,7 +359,7 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
277359
}
278360

279361
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
280-
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
362+
constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"};
281363
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
282364
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
283365
"tb" + (info.param.transb ? "T" : "N");
@@ -288,16 +370,27 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest
288370

289371
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
290372
const std::vector<TestParams> kTestParams = {
291-
// Basic tests
373+
// FP8 tests (each tensor has random mean/stddev -> different scales)
292374
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
293375
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
294376
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
377+
// BF16 tests
295378
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
296379
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
297380
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
298381
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
299382
// Test NULL C (valid when beta=0)
300383
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
384+
// MXFP8 tests
385+
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false},
386+
{InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false},
387+
{InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false},
388+
{InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false},
389+
{InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false},
390+
{InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false},
391+
{InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false},
392+
// MXFP8 with NULL C
393+
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true},
301394
};
302395

303396
INSTANTIATE_TEST_SUITE_P(OperatorTest,

tests/cpp/test_common.cu

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,14 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
10611061
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
10621062
const NVTEScalingMode scaling_mode) {
10631063
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
1064-
const NVTEShape shape = tensors[0]->rowwise_shape();
1064+
1065+
// Check which data layouts are available (all tensors must have the same)
1066+
const bool has_rowwise = tensors[0]->rowwise();
1067+
const bool has_columnwise = tensors[0]->columnwise();
1068+
NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout.");
1069+
1070+
const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape()
1071+
: tensors[0]->columnwise_shape();
10651072
const DType dtype = tensors[0]->dtype();
10661073
const size_t num_tensors = tensors.size();
10671074
const size_t elem_size = typeToNumBits(dtype) / 8;
@@ -1076,7 +1083,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
10761083
std::vector<int64_t> first_dims(num_tensors);
10771084
std::vector<int64_t> last_dims(num_tensors);
10781085
for (size_t i = 0; i < num_tensors; ++i) {
1079-
const auto s = tensors[i]->rowwise_shape();
1086+
const auto s = has_rowwise ? tensors[i]->rowwise_shape()
1087+
: tensors[i]->columnwise_shape();
10801088
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
10811089
first_dims[i] = static_cast<int64_t>(s.data[0]);
10821090
last_dims[i] = static_cast<int64_t>(s.data[1]);
@@ -1105,10 +1113,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11051113
};
11061114

11071115
const bool need_offsets = !same_first || !same_last;
1116+
const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
11081117
if (need_offsets) {
11091118
offsets[0] = 0;
11101119
for (size_t i = 1; i < num_tensors; ++i) {
1111-
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
1120+
offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
11121121
}
11131122
} else {
11141123
for (size_t i = 0; i < num_tensors; ++i) {
@@ -1146,21 +1155,24 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
11461155
: (logical_first * logical_last);
11471156
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
11481157

1149-
grouped.data = cuda_alloc(total_bytes);
1150-
for (size_t i = 0; i < num_tensors; ++i) {
1151-
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
1152-
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
1153-
tensors[i]->rowwise_dptr(),
1154-
grouped.tensor_bytes[i],
1155-
cudaMemcpyDeviceToDevice));
1156-
}
1157-
1158-
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
11591158
NVTEGroupedTensor h = grouped.handle.get();
1160-
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));
11611159

1162-
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
1163-
if (include_columnwise) {
1160+
// Copy rowwise data if available
1161+
if (has_rowwise) {
1162+
grouped.data = cuda_alloc(total_bytes);
1163+
for (size_t i = 0; i < num_tensors; ++i) {
1164+
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
1165+
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
1166+
tensors[i]->rowwise_dptr(),
1167+
grouped.tensor_bytes[i],
1168+
cudaMemcpyDeviceToDevice));
1169+
}
1170+
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
1171+
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));
1172+
}
1173+
1174+
// Copy columnwise data if available
1175+
if (has_columnwise) {
11641176
grouped.columnwise_data = cuda_alloc(total_bytes);
11651177
for (size_t i = 0; i < num_tensors; ++i) {
11661178
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
@@ -1202,11 +1214,17 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
12021214
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
12031215
}
12041216

1205-
if (isFp8Type(dtype)) {
1217+
if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
1218+
// FP8 tensor scaling: one float scale_inv per tensor
1219+
// For delayed scaling, rowwise and columnwise share the same scale
12061220
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
12071221
for (size_t i = 0; i < num_tensors; ++i) {
12081222
tensors[i]->to_cpu();
1209-
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
1223+
if (has_rowwise) {
1224+
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
1225+
} else {
1226+
scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr<float>()[0];
1227+
}
12101228
}
12111229
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
12121230
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
@@ -1217,6 +1235,68 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
12171235
sizeof(scale_tensor));
12181236
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
12191237
sizeof(scale_tensor));
1238+
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
1239+
// MXFP8: E8M0 scale_inv per block of 32 elements
1240+
// Helper to gather scale_inv from individual tensors into a contiguous buffer
1241+
auto gather_scales = [&](
1242+
auto get_shape_fn,
1243+
auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> {
1244+
// Compute total size and offsets
1245+
size_t total_bytes = 0;
1246+
std::vector<size_t> scale_offsets(num_tensors);
1247+
std::vector<size_t> numels(num_tensors);
1248+
1249+
for (size_t i = 0; i < num_tensors; ++i) {
1250+
scale_offsets[i] = total_bytes;
1251+
const NVTEShape shape = get_shape_fn(tensors[i]);
1252+
size_t numel = 1;
1253+
for (size_t d = 0; d < shape.ndim; ++d) {
1254+
numel *= shape.data[d];
1255+
}
1256+
numels[i] = numel;
1257+
total_bytes += numel; // E8M0 is 1 byte per element
1258+
}
1259+
1260+
// Allocate and copy
1261+
CudaPtr<> buffer = cuda_alloc(total_bytes);
1262+
for (size_t i = 0; i < num_tensors; ++i) {
1263+
tensors[i]->to_cpu();
1264+
NVTE_CHECK_CUDA(cudaGetLastError());
1265+
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
1266+
const void* src = get_cpu_ptr_fn(tensors[i]);
1267+
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
1268+
}
1269+
return {std::move(buffer), total_bytes};
1270+
};
1271+
1272+
// Gather rowwise scale_inv if available
1273+
if (has_rowwise) {
1274+
auto [row_buffer, row_total] = gather_scales(
1275+
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
1276+
[](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); });
1277+
grouped.scale_inv = std::move(row_buffer);
1278+
1279+
NVTEShape row_shape = nvte_make_shape(&row_total, 1);
1280+
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape};
1281+
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
1282+
}
1283+
1284+
// Gather columnwise scale_inv if available
1285+
if (has_columnwise) {
1286+
auto [col_buffer, col_total] = gather_scales(
1287+
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
1288+
[](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); });
1289+
grouped.columnwise_scale_inv = std::move(col_buffer);
1290+
1291+
NVTEShape col_shape = nvte_make_shape(&col_total, 1);
1292+
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape};
1293+
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
1294+
}
1295+
1296+
// Mark as having swizzled scales (required for GEMM)
1297+
const uint8_t swizzled = 1;
1298+
nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled,
1299+
sizeof(swizzled));
12201300
}
12211301

12221302
return grouped;

tests/cpp/test_common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ struct GroupedBuffers {
535535
GroupedTensorHandle handle;
536536
CudaPtr<> data;
537537
CudaPtr<> scale_inv;
538+
CudaPtr<> columnwise_scale_inv;
538539
CudaPtr<int64_t> first_dims_dev;
539540
CudaPtr<int64_t> last_dims_dev;
540541
CudaPtr<int64_t> offsets_dev;

transformer_engine/common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,8 @@ struct GroupedTensor {
378378
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
379379
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
380380
logical_shape(nvte_make_shape(nullptr, 1)),
381-
nvte_tensor(0) {}
381+
nvte_tensor(0),
382+
with_gemm_swizzled_scales(false) {}
382383

383384
explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
384385

transformer_engine/common/gemm/config.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ struct GroupedMatmulConfig {
4444
// Number of streaming multiprocessors to use in GEMM kernel
4545
int sm_count = 0;
4646

47+
// Split accumulator mode. Only taken into account on Hopper.
48+
bool use_split_accumulator = false;
49+
4750
// Note: API transfers the value type, not std::optional
48-
static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
49-
sizeof(decltype(avg_n)::value_type),
50-
sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
51+
static constexpr size_t attr_sizes[] = {
52+
sizeof(decltype(avg_m)::value_type), sizeof(decltype(avg_n)::value_type),
53+
sizeof(decltype(avg_k)::value_type), sizeof(sm_count), sizeof(uint8_t)};
5154
};
5255

5356
} // namespace transformer_engine

0 commit comments

Comments
 (0)