Skip to content

Commit 3c94174

Browse files
Fix: invalid tile size (#1043)
1 parent 7a4d97e commit 3c94174

File tree

6 files changed

+27
-14
lines changed

6 files changed

+27
-14
lines changed

crates/cubecl-convolution/src/components/selection.rs

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
use cubecl_core::{Runtime, client::ComputeClient};
22
use cubecl_matmul::components::stage::PartitionBuffering;
33

4-
use cubecl_matmul::components::{MatmulElems, MatmulSelection, TilingScheme, adjust_dtypes};
4+
use cubecl_matmul::components::{
5+
MatmulAvailabilityError, MatmulElems, MatmulSelection, TilingScheme, adjust_dtypes,
6+
};
57
use cubecl_matmul::{
68
components::tile::TileMatmulFamily,
79
kernels::layered::{NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, find_instruction_size},
@@ -79,14 +81,14 @@ pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
7981
problem: &ConvolutionProblem,
8082
plane_dim: u32,
8183
dtypes: &mut MatmulElems,
82-
) -> MatmulSelection {
84+
) -> Result<MatmulSelection, MatmulAvailabilityError> {
8385
adjust_dtypes::<R>(client, dtypes, TMM::requires_accelerator());
8486

8587
// rough heuristic based on previous bench results where 512 channels with a 3x3 kernel seemed
8688
// to be the rough cutoff for the k=4 size.
8789
let stage_k = if problem.k >= 4096 { 4 } else { 2 };
8890

89-
let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n);
91+
let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;
9092

9193
let hardware = &client.properties().hardware;
9294
let num_sm = hardware
@@ -111,7 +113,7 @@ pub fn convolution_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
111113
.build()
112114
.unwrap();
113115

114-
MatmulSelection::builder(tiling_scheme, plane_dim)
116+
Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
115117
.partition_buffering(PartitionBuffering::Single)
116-
.build()
118+
.build())
117119
}

crates/cubecl-convolution/src/kernels/layered/algorithm/multi_stage_tma.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,6 @@ impl<
7070
problem,
7171
plane_dim,
7272
matmul_elems,
73-
))
73+
)?)
7474
}
7575
}

crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ impl<
7070
) -> Result<MatmulSelection, MatmulSetupError> {
7171
Ok(convolution_matmul_selection::<TMM, R>(
7272
client, problem, plane_dim, dtypes,
73-
))
73+
)?)
7474
}
7575
}
7676

crates/cubecl-convolution/src/kernels/layered/algorithm/simple_tma.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ impl<
8080
) -> Result<MatmulSelection, MatmulSetupError> {
8181
Ok(convolution_matmul_selection::<TMM, R>(
8282
client, problem, plane_dim, dtypes,
83-
))
83+
)?)
8484
}
8585
}
8686

crates/cubecl-matmul/src/components/error.rs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ pub enum MatmulAvailabilityError {
4141
size: Option<TileSize>,
4242
},
4343

44+
/// Impossible to find a supported tile size for the problem.
45+
TileSizeNotFound,
46+
4447
/// The layout of the matmul is unsupported
4548
LayoutUnsupported {
4649
lhs: MatrixLayout,
@@ -172,6 +175,9 @@ impl Debug for MatmulAvailabilityError {
172175
MatmulAvailabilityError::PlaneOpsUnavailable => {
173176
writeln!(f, "Plane-wide operations like plane_sum are not available.")
174177
}
178+
MatmulAvailabilityError::TileSizeNotFound => {
179+
writeln!(f, "No tile size is available for the problem.")
180+
}
175181
}
176182
}
177183
}

crates/cubecl-matmul/src/kernels/layered/selector/plane.rs

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ pub fn plane_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
5050
));
5151
}
5252

53-
let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n);
53+
let tile_size = find_instruction_size::<R, TMM>(client, dtypes, problem.m, problem.n)?;
5454

5555
if options.tiny_selection_enabled && is_tiny(problem, &tile_size) {
5656
return Ok(selection_tiny::<R>(client, problem, tile_size, plane_dim));
@@ -258,7 +258,7 @@ pub fn find_instruction_size<R: Runtime, TMM: TileMatmulFamily>(
258258
elems: &MatmulElems,
259259
m: usize,
260260
n: usize,
261-
) -> TileSize {
261+
) -> Result<TileSize, MatmulAvailabilityError> {
262262
let supported = |m: u32, n: u32, k: u32| {
263263
TMM::is_supported::<R>(
264264
client,
@@ -273,7 +273,7 @@ pub fn find_instruction_size<R: Runtime, TMM: TileMatmulFamily>(
273273
)
274274
};
275275

276-
if m >= 4 * n && supported(32, 8, 16) {
276+
let val = if m >= 4 * n && supported(32, 8, 16) {
277277
(32, 8, 16).into()
278278
} else if n >= 4 * n && supported(8, 32, 16) {
279279
(8, 32, 16).into()
@@ -282,16 +282,21 @@ pub fn find_instruction_size<R: Runtime, TMM: TileMatmulFamily>(
282282
} else if supported(8, 8, 8) {
283283
(8, 8, 8).into()
284284
} else {
285-
TMM::supported_sizes::<R>(
285+
match TMM::supported_sizes::<R>(
286286
client,
287287
elems.lhs_register,
288288
elems.rhs_register,
289289
elems.acc_register,
290290
)
291291
.first()
292292
.copied()
293-
.unwrap_or_else(|| (16, 16, 8).into())
294-
}
293+
{
294+
Some(val) => val,
295+
None => return Err(MatmulAvailabilityError::TileSizeNotFound),
296+
}
297+
};
298+
299+
Ok(val)
295300
}
296301

297302
fn selection_tiny<R: Runtime>(

0 commit comments

Comments
 (0)