Skip to content

Commit f002d33

Browse files
committed
refactor(gpu): create MultiStreamMultiGpu to improve management of multiple streams per GPU
1 parent efc7ae1 commit f002d33

File tree

3 files changed

+270
-536
lines changed

3 files changed

+270
-536
lines changed

backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,162 @@ struct CudaStreams {
183183
}
184184
};
185185

186+
struct MultiStreamMultiGpu {
187+
private:
188+
CudaStreams *_sub_streams;
189+
uint32_t _num_streams_per_gpu;
190+
uint32_t _num_gpus;
191+
192+
cudaEvent_t _incoming_event;
193+
cudaEvent_t *_outgoing_events;
194+
195+
MultiStreamMultiGpu(const MultiStreamMultiGpu &) = delete;
196+
MultiStreamMultiGpu &operator=(const MultiStreamMultiGpu &) = delete;
197+
198+
public:
199+
MultiStreamMultiGpu(const CudaStreams &base_streams,
200+
uint32_t num_streams_per_gpu) {
201+
202+
_sub_streams = nullptr;
203+
_outgoing_events = nullptr;
204+
_incoming_event = nullptr;
205+
206+
_num_streams_per_gpu = num_streams_per_gpu;
207+
_num_gpus = base_streams.count();
208+
209+
if (num_streams_per_gpu > 0) {
210+
_sub_streams = new CudaStreams[num_streams_per_gpu];
211+
for (uint32_t i = 0; i < num_streams_per_gpu; ++i) {
212+
_sub_streams[i].create_on_same_gpus(base_streams);
213+
}
214+
}
215+
216+
if (_num_gpus > 0) {
217+
_incoming_event = cuda_create_event(base_streams.gpu_index(0));
218+
}
219+
220+
uint32_t total_events = num_streams_per_gpu * _num_gpus;
221+
if (total_events > 0) {
222+
_outgoing_events = new cudaEvent_t[total_events];
223+
for (uint32_t s = 0; s < num_streams_per_gpu; ++s) {
224+
for (uint32_t g = 0; g < _num_gpus; ++g) {
225+
_outgoing_events[s * _num_gpus + g] =
226+
cuda_create_event(base_streams.gpu_index(g));
227+
}
228+
}
229+
}
230+
}
231+
232+
CudaStreams &operator[](uint32_t idx) const {
233+
PANIC_IF_FALSE(idx < _num_streams_per_gpu,
234+
"MultiStreamMultiGpu index out of bounds");
235+
return _sub_streams[idx];
236+
}
237+
238+
uint32_t num_streams() const { return _num_streams_per_gpu; }
239+
240+
void sync_from(const CudaStreams &main_stream) {
241+
cuda_event_record(_incoming_event, main_stream.stream(0),
242+
main_stream.gpu_index(0));
243+
244+
for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
245+
for (uint32_t g = 0; g < _num_gpus; ++g) {
246+
cuda_stream_wait_event(_sub_streams[s].stream(g), _incoming_event,
247+
_sub_streams[s].gpu_index(g));
248+
}
249+
}
250+
}
251+
252+
void
253+
sync_specific_streams_from(const CudaStreams &main_stream,
254+
std::initializer_list<uint32_t> stream_indices) {
255+
cuda_event_record(_incoming_event, main_stream.stream(0),
256+
main_stream.gpu_index(0));
257+
258+
for (uint32_t s_idx : stream_indices) {
259+
PANIC_IF_FALSE(s_idx < _num_streams_per_gpu,
260+
"MultiStreamMultiGpu: stream index out of bounds");
261+
262+
for (uint32_t g = 0; g < _num_gpus; ++g) {
263+
cuda_stream_wait_event(_sub_streams[s_idx].stream(g), _incoming_event,
264+
_sub_streams[s_idx].gpu_index(g));
265+
}
266+
}
267+
}
268+
269+
void sync_to(const CudaStreams &main_stream) {
270+
for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
271+
for (uint32_t g = 0; g < _num_gpus; ++g) {
272+
cuda_event_record(_outgoing_events[s * _num_gpus + g],
273+
_sub_streams[s].stream(g),
274+
_sub_streams[s].gpu_index(g));
275+
}
276+
}
277+
278+
for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
279+
for (uint32_t g = 0; g < _num_gpus; ++g) {
280+
cuda_stream_wait_event(main_stream.stream(0),
281+
_outgoing_events[s * _num_gpus + g],
282+
main_stream.gpu_index(0));
283+
}
284+
}
285+
}
286+
287+
void
288+
sync_specific_streams_to(const CudaStreams &main_stream,
289+
std::initializer_list<uint32_t> stream_indices) {
290+
for (uint32_t s_idx : stream_indices) {
291+
PANIC_IF_FALSE(s_idx < _num_streams_per_gpu,
292+
"MultiStreamMultiGpu: stream index out of bounds");
293+
294+
for (uint32_t g = 0; g < _num_gpus; ++g) {
295+
cuda_event_record(_outgoing_events[s_idx * _num_gpus + g],
296+
_sub_streams[s_idx].stream(g),
297+
_sub_streams[s_idx].gpu_index(g));
298+
}
299+
}
300+
301+
for (uint32_t s_idx : stream_indices) {
302+
for (uint32_t g = 0; g < _num_gpus; ++g) {
303+
cuda_stream_wait_event(main_stream.stream(0),
304+
_outgoing_events[s_idx * _num_gpus + g],
305+
main_stream.gpu_index(0));
306+
}
307+
}
308+
}
309+
310+
void release() {
311+
if (_outgoing_events && _sub_streams) {
312+
for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
313+
for (uint32_t g = 0; g < _num_gpus; ++g) {
314+
cuda_event_destroy(_outgoing_events[s * _num_gpus + g],
315+
_sub_streams[s].gpu_index(g));
316+
}
317+
}
318+
delete[] _outgoing_events;
319+
_outgoing_events = nullptr;
320+
}
321+
322+
if (_incoming_event && _sub_streams) {
323+
cuda_event_destroy(_incoming_event, _sub_streams[0].gpu_index(0));
324+
_incoming_event = nullptr;
325+
}
326+
327+
if (_sub_streams) {
328+
for (uint32_t i = 0; i < _num_streams_per_gpu; ++i) {
329+
_sub_streams[i].release();
330+
}
331+
delete[] _sub_streams;
332+
_sub_streams = nullptr;
333+
}
334+
}
335+
336+
~MultiStreamMultiGpu() {
337+
PANIC_IF_FALSE(_sub_streams == nullptr,
338+
"MultiStreamMultiGpu: must call release before destruction");
339+
}
340+
};
341+
186342
struct CudaStreamsBarrier {
187343
private:
188344
std::vector<cudaEvent_t> _events;

0 commit comments

Comments
 (0)