Skip to content

Commit d9d056e

Browse files
committed
cust: use usize in BlockSize and GridSize.
Use `usize` instead of `u32`, because using `usize` for dimensions and indices is more natural in Rust and avoids lots of casts.
1 parent 55cd2bc commit d9d056e

File tree

7 files changed

+67
-111
lines changed

7 files changed

+67
-111
lines changed

crates/cust/src/function.rs

Lines changed: 35 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,43 @@ use crate::module::Module;
1818
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1919
pub struct GridSize {
2020
/// Width of grid in blocks
21-
pub x: u32,
21+
pub x: usize,
2222
/// Height of grid in blocks
23-
pub y: u32,
23+
pub y: usize,
2424
/// Depth of grid in blocks
25-
pub z: u32,
25+
pub z: usize,
2626
}
2727
impl GridSize {
2828
/// Create a one-dimensional grid of `x` blocks
2929
#[inline]
30-
pub fn x(x: u32) -> GridSize {
30+
pub fn x(x: usize) -> GridSize {
3131
GridSize { x, y: 1, z: 1 }
3232
}
3333

3434
/// Create a two-dimensional grid of `x * y` blocks
3535
#[inline]
36-
pub fn xy(x: u32, y: u32) -> GridSize {
36+
pub fn xy(x: usize, y: usize) -> GridSize {
3737
GridSize { x, y, z: 1 }
3838
}
3939

4040
/// Create a three-dimensional grid of `x * y * z` blocks
4141
#[inline]
42-
pub fn xyz(x: u32, y: u32, z: u32) -> GridSize {
42+
pub fn xyz(x: usize, y: usize, z: usize) -> GridSize {
4343
GridSize { x, y, z }
4444
}
4545
}
46-
impl From<u32> for GridSize {
47-
fn from(x: u32) -> GridSize {
46+
impl From<usize> for GridSize {
47+
fn from(x: usize) -> GridSize {
4848
GridSize::x(x)
4949
}
5050
}
51-
impl From<(u32, u32)> for GridSize {
52-
fn from((x, y): (u32, u32)) -> GridSize {
51+
impl From<(usize, usize)> for GridSize {
52+
fn from((x, y): (usize, usize)) -> GridSize {
5353
GridSize::xy(x, y)
5454
}
5555
}
56-
impl From<(u32, u32, u32)> for GridSize {
57-
fn from((x, y, z): (u32, u32, u32)) -> GridSize {
56+
impl From<(usize, usize, usize)> for GridSize {
57+
fn from((x, y, z): (usize, usize, usize)) -> GridSize {
5858
GridSize::xyz(x, y, z)
5959
}
6060
}
@@ -64,52 +64,28 @@ impl From<&GridSize> for GridSize {
6464
}
6565
}
6666
#[cfg(feature = "vek")]
67-
impl From<vek::Vec2<u32>> for GridSize {
68-
fn from(vec: vek::Vec2<u32>) -> Self {
69-
GridSize::xy(vec.x, vec.y)
70-
}
71-
}
72-
#[cfg(feature = "vek")]
73-
impl From<vek::Vec3<u32>> for GridSize {
74-
fn from(vec: vek::Vec3<u32>) -> Self {
75-
GridSize::xyz(vec.x, vec.y, vec.z)
76-
}
77-
}
78-
#[cfg(feature = "vek")]
7967
impl From<vek::Vec2<usize>> for GridSize {
8068
fn from(vec: vek::Vec2<usize>) -> Self {
81-
GridSize::xy(vec.x as u32, vec.y as u32)
69+
GridSize::xy(vec.x, vec.y)
8270
}
8371
}
8472
#[cfg(feature = "vek")]
8573
impl From<vek::Vec3<usize>> for GridSize {
8674
fn from(vec: vek::Vec3<usize>) -> Self {
87-
GridSize::xyz(vec.x as u32, vec.y as u32, vec.z as u32)
88-
}
89-
}
90-
91-
#[cfg(feature = "glam")]
92-
impl From<glam::UVec2> for GridSize {
93-
fn from(vec: glam::UVec2) -> Self {
94-
GridSize::xy(vec.x, vec.y)
95-
}
96-
}
97-
#[cfg(feature = "glam")]
98-
impl From<glam::UVec3> for GridSize {
99-
fn from(vec: glam::UVec3) -> Self {
10075
GridSize::xyz(vec.x, vec.y, vec.z)
10176
}
10277
}
78+
10379
#[cfg(feature = "glam")]
10480
impl From<glam::USizeVec2> for GridSize {
10581
fn from(vec: glam::USizeVec2) -> Self {
106-
GridSize::xy(vec.x as u32, vec.y as u32)
82+
GridSize::xy(vec.x, vec.y)
10783
}
10884
}
10985
#[cfg(feature = "glam")]
11086
impl From<glam::USizeVec3> for GridSize {
11187
fn from(vec: glam::USizeVec3) -> Self {
112-
GridSize::xyz(vec.x as u32, vec.y as u32, vec.z as u32)
88+
GridSize::xyz(vec.x, vec.y, vec.z)
11389
}
11490
}
11591

@@ -123,43 +99,43 @@ impl From<glam::USizeVec3> for GridSize {
12399
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124100
pub struct BlockSize {
125101
/// X dimension of each thread block
126-
pub x: u32,
102+
pub x: usize,
127103
/// Y dimension of each thread block
128-
pub y: u32,
104+
pub y: usize,
129105
/// Z dimension of each thread block
130-
pub z: u32,
106+
pub z: usize,
131107
}
132108
impl BlockSize {
133109
/// Create a one-dimensional block of `x` threads
134110
#[inline]
135-
pub fn x(x: u32) -> BlockSize {
111+
pub fn x(x: usize) -> BlockSize {
136112
BlockSize { x, y: 1, z: 1 }
137113
}
138114

139115
/// Create a two-dimensional block of `x * y` threads
140116
#[inline]
141-
pub fn xy(x: u32, y: u32) -> BlockSize {
117+
pub fn xy(x: usize, y: usize) -> BlockSize {
142118
BlockSize { x, y, z: 1 }
143119
}
144120

145121
/// Create a three-dimensional block of `x * y * z` threads
146122
#[inline]
147-
pub fn xyz(x: u32, y: u32, z: u32) -> BlockSize {
123+
pub fn xyz(x: usize, y: usize, z: usize) -> BlockSize {
148124
BlockSize { x, y, z }
149125
}
150126
}
151-
impl From<u32> for BlockSize {
152-
fn from(x: u32) -> BlockSize {
127+
impl From<usize> for BlockSize {
128+
fn from(x: usize) -> BlockSize {
153129
BlockSize::x(x)
154130
}
155131
}
156-
impl From<(u32, u32)> for BlockSize {
157-
fn from((x, y): (u32, u32)) -> BlockSize {
132+
impl From<(usize, usize)> for BlockSize {
133+
fn from((x, y): (usize, usize)) -> BlockSize {
158134
BlockSize::xy(x, y)
159135
}
160136
}
161-
impl From<(u32, u32, u32)> for BlockSize {
162-
fn from((x, y, z): (u32, u32, u32)) -> BlockSize {
137+
impl From<(usize, usize, usize)> for BlockSize {
138+
fn from((x, y, z): (usize, usize, usize)) -> BlockSize {
163139
BlockSize::xyz(x, y, z)
164140
}
165141
}
@@ -169,52 +145,28 @@ impl From<&BlockSize> for BlockSize {
169145
}
170146
}
171147
#[cfg(feature = "vek")]
172-
impl From<vek::Vec2<u32>> for BlockSize {
173-
fn from(vec: vek::Vec2<u32>) -> Self {
174-
BlockSize::xy(vec.x, vec.y)
175-
}
176-
}
177-
#[cfg(feature = "vek")]
178-
impl From<vek::Vec3<u32>> for BlockSize {
179-
fn from(vec: vek::Vec3<u32>) -> Self {
180-
BlockSize::xyz(vec.x, vec.y, vec.z)
181-
}
182-
}
183-
#[cfg(feature = "vek")]
184148
impl From<vek::Vec2<usize>> for BlockSize {
185149
fn from(vec: vek::Vec2<usize>) -> Self {
186-
BlockSize::xy(vec.x as u32, vec.y as u32)
150+
BlockSize::xy(vec.x, vec.y)
187151
}
188152
}
189153
#[cfg(feature = "vek")]
190154
impl From<vek::Vec3<usize>> for BlockSize {
191155
fn from(vec: vek::Vec3<usize>) -> Self {
192-
BlockSize::xyz(vec.x as u32, vec.y as u32, vec.z as u32)
193-
}
194-
}
195-
196-
#[cfg(feature = "glam")]
197-
impl From<glam::UVec2> for BlockSize {
198-
fn from(vec: glam::UVec2) -> Self {
199-
BlockSize::xy(vec.x, vec.y)
200-
}
201-
}
202-
#[cfg(feature = "glam")]
203-
impl From<glam::UVec3> for BlockSize {
204-
fn from(vec: glam::UVec3) -> Self {
205156
BlockSize::xyz(vec.x, vec.y, vec.z)
206157
}
207158
}
159+
208160
#[cfg(feature = "glam")]
209161
impl From<glam::USizeVec2> for BlockSize {
210162
fn from(vec: glam::USizeVec2) -> Self {
211-
BlockSize::xy(vec.x as u32, vec.y as u32)
163+
BlockSize::xy(vec.x, vec.y)
212164
}
213165
}
214166
#[cfg(feature = "glam")]
215167
impl From<glam::USizeVec3> for BlockSize {
216168
fn from(vec: glam::USizeVec3) -> Self {
217-
BlockSize::xyz(vec.x as u32, vec.y as u32, vec.z as u32)
169+
BlockSize::xyz(vec.x, vec.y, vec.z)
218170
}
219171
}
220172

@@ -448,7 +400,7 @@ impl Function<'_> {
448400
&self,
449401
dynamic_smem_size: usize,
450402
block_size_limit: BlockSize,
451-
) -> CudaResult<(u32, u32)> {
403+
) -> CudaResult<(usize, usize)> {
452404
let mut min_grid_size = MaybeUninit::uninit();
453405
let mut block_size = MaybeUninit::uninit();
454406

@@ -465,8 +417,8 @@ impl Function<'_> {
465417
)
466418
.to_result()?;
467419
Ok((
468-
min_grid_size.assume_init() as u32,
469-
block_size.assume_init() as u32,
420+
min_grid_size.assume_init() as usize,
421+
block_size.assume_init() as usize,
470422
))
471423
}
472424
}

crates/cust/src/graph.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ impl KernelInvocation {
8585
pub fn to_raw(self) -> driver_sys::CUDA_KERNEL_NODE_PARAMS {
8686
driver_sys::CUDA_KERNEL_NODE_PARAMS {
8787
func: self.func,
88-
gridDimX: self.grid_dim.x,
89-
gridDimY: self.grid_dim.y,
90-
gridDimZ: self.grid_dim.z,
91-
blockDimX: self.block_dim.x,
92-
blockDimY: self.block_dim.y,
93-
blockDimZ: self.block_dim.z,
88+
gridDimX: self.grid_dim.x as u32,
89+
gridDimY: self.grid_dim.y as u32,
90+
gridDimZ: self.grid_dim.z as u32,
91+
blockDimX: self.block_dim.x as u32,
92+
blockDimY: self.block_dim.y as u32,
93+
blockDimZ: self.block_dim.z as u32,
9494
kernelParams: Box::into_raw(self.params),
9595
sharedMemBytes: self.shared_mem_bytes,
9696
extra: ptr::null_mut(),
@@ -109,8 +109,16 @@ impl KernelInvocation {
109109
pub unsafe fn from_raw(raw: driver_sys::CUDA_KERNEL_NODE_PARAMS) -> Self {
110110
Self {
111111
func: raw.func,
112-
grid_dim: GridSize::xyz(raw.gridDimX, raw.gridDimY, raw.gridDimZ),
113-
block_dim: BlockSize::xyz(raw.blockDimX, raw.gridDimY, raw.gridDimZ),
112+
grid_dim: GridSize::xyz(
113+
raw.gridDimX as usize,
114+
raw.gridDimY as usize,
115+
raw.gridDimZ as usize,
116+
),
117+
block_dim: BlockSize::xyz(
118+
raw.blockDimX as usize,
119+
raw.blockDimY as usize,
120+
raw.blockDimZ as usize,
121+
),
114122
params: Box::from_raw(raw.kernelParams),
115123
shared_mem_bytes: raw.sharedMemBytes,
116124
params_len: None,

crates/cust/src/stream.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ impl Stream {
274274
unsafe {
275275
driver_sys::cuLaunchKernel(
276276
func.to_raw(),
277-
grid_size.x,
278-
grid_size.y,
279-
grid_size.z,
280-
block_size.x,
281-
block_size.y,
282-
block_size.z,
277+
grid_size.x as u32,
278+
grid_size.y as u32,
279+
grid_size.z as u32,
280+
block_size.x as u32,
281+
block_size.y as u32,
282+
block_size.z as u32,
283283
shared_mem_bytes,
284284
self.inner,
285285
args.as_ptr() as *mut _,

examples/gemm/src/main.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,16 +365,12 @@ pub fn gemm_naive(
365365
// This will try to maximize how much of the GPU is used by finding the best launch configuration for the
366366
// current CUDA device/architecture.
367367
let (_, block_size) = kernel.suggested_launch_configuration(0, 0.into())?;
368-
let block_size = block_size as usize;
369368
let (block_size_x, block_size_y) = if block_size > m * n {
370-
(block_size.div_ceil(m) as u32, m as u32)
369+
(block_size.div_ceil(m), m)
371370
} else {
372-
(1, block_size as u32)
371+
(1, block_size)
373372
};
374-
let (grid_size_x, grid_size_y) = (
375-
(m as u32).div_ceil(block_size_x),
376-
(n as u32).div_ceil(block_size_y),
377-
);
373+
let (grid_size_x, grid_size_y) = (m.div_ceil(block_size_x), n.div_ceil(block_size_y));
378374
unsafe {
379375
launch!(
380376
kernel<<<
@@ -438,12 +434,12 @@ pub fn gemm_tiled(
438434
});
439435
let kernel = &*kernel_cell;
440436

441-
let (grid_size_x, grid_size_y) = (n.div_ceil(TILE_SIZE) as u32, m.div_ceil(TILE_SIZE) as u32);
437+
let (grid_size_x, grid_size_y) = (n.div_ceil(TILE_SIZE), m.div_ceil(TILE_SIZE));
442438
unsafe {
443439
launch!(
444440
kernel<<<
445441
(grid_size_x, grid_size_y),
446-
(TILE_SIZE as u32, TILE_SIZE as u32),
442+
(TILE_SIZE, TILE_SIZE),
447443
0,
448444
stream
449445
>>>(

examples/i128_demo/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ fn main() -> Result<(), Box<dyn Error>> {
119119
let urem_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?;
120120
let srem_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?;
121121

122-
let block_size = 128u32;
123-
let grid_size = (len as u32).div_ceil(block_size);
122+
let block_size = 128usize;
123+
let grid_size = len.div_ceil(block_size);
124124

125125
unsafe {
126126
launch!(
@@ -241,7 +241,7 @@ fn main() -> Result<(), Box<dyn Error>> {
241241

242242
let trap_launch = unsafe {
243243
launch!(
244-
kernel<<<1u32, 1u32, 0, trap_stream>>>(
244+
kernel<<<1, 1, 0, trap_stream>>>(
245245
trap_a.as_device_ptr(),
246246
trap_a.len(),
247247
trap_b.as_device_ptr(),

examples/vecadd/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ fn main() -> Result<(), Box<dyn Error>> {
4545
// current CUDA device/architecture.
4646
let (_, block_size) = vecadd.suggested_launch_configuration(0, 0.into())?;
4747

48-
let grid_size = (NUMBERS_LEN as u32).div_ceil(block_size);
48+
let grid_size = NUMBERS_LEN.div_ceil(block_size);
4949

5050
println!("using {grid_size} blocks and {block_size} threads per block");
5151

samples/introduction/async_api/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ fn main() -> Result<(), cust::error::CudaError> {
3737
let value = 26;
3838

3939
let blocks = BlockSize::xy(512, 1);
40-
let grids = GridSize::xy((N / (blocks.x as usize)).try_into().unwrap(), 1);
40+
let grids = GridSize::xy(N / blocks.x, 1);
4141

4242
let start_event = Event::new(EventFlags::DEFAULT)?;
4343
let stop_event = Event::new(EventFlags::DEFAULT)?;

0 commit comments

Comments
 (0)