@@ -183,6 +183,162 @@ struct CudaStreams {
   }
 };
 
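+// A set of num_streams_per_gpu CudaStreams, each spanning the same GPUs as the
+// base CudaStreams it was built from, plus the CUDA events needed to fork work
+// from a main stream onto the sub-streams and join it back afterwards.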
+struct MultiStreamMultiGpu {
+private:
+  CudaStreams *_sub_streams;
+  uint32_t _num_streams_per_gpu;
+  uint32_t _num_gpus;
+
+  cudaEvent_t _incoming_event;
+  cudaEvent_t *_outgoing_events;
+
+  MultiStreamMultiGpu(const MultiStreamMultiGpu &) = delete;
+  MultiStreamMultiGpu &operator=(const MultiStreamMultiGpu &) = delete;
+
+public:
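+  // Creates num_streams_per_gpu CudaStreams on the same GPUs as base_streams,
+  // one incoming event on the first GPU, and one outgoing event per
+  // (sub-stream, GPU) pair.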
+  MultiStreamMultiGpu(const CudaStreams &base_streams,
+                      uint32_t num_streams_per_gpu) {
+
+    _sub_streams = nullptr;
+    _outgoing_events = nullptr;
+    _incoming_event = nullptr;
+
+    _num_streams_per_gpu = num_streams_per_gpu;
+    _num_gpus = base_streams.count();
+
+    if (num_streams_per_gpu > 0) {
+      _sub_streams = new CudaStreams[num_streams_per_gpu];
+      for (uint32_t i = 0; i < num_streams_per_gpu; ++i) {
+        _sub_streams[i].create_on_same_gpus(base_streams);
+      }
+    }
+
+    if (_num_gpus > 0) {
+      _incoming_event = cuda_create_event(base_streams.gpu_index(0));
+    }
+
+    uint32_t total_events = num_streams_per_gpu * _num_gpus;
+    if (total_events > 0) {
+      _outgoing_events = new cudaEvent_t[total_events];
+      for (uint32_t s = 0; s < num_streams_per_gpu; ++s) {
+        for (uint32_t g = 0; g < _num_gpus; ++g) {
+          _outgoing_events[s * _num_gpus + g] =
+              cuda_create_event(base_streams.gpu_index(g));
+        }
+      }
+    }
+  }
+
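+  // Returns the idx-th set of sub-streams (one stream per GPU).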
+  CudaStreams &operator[](uint32_t idx) const {
+    PANIC_IF_FALSE(idx < _num_streams_per_gpu,
+                   "MultiStreamMultiGpu index out of bounds");
+    return _sub_streams[idx];
+  }
+
+  uint32_t num_streams() const { return _num_streams_per_gpu; }
+
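+  // Fork: makes every sub-stream on every GPU wait for the work already
+  // enqueued on the first stream of main_stream.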
+  void sync_from(const CudaStreams &main_stream) {
+    cuda_event_record(_incoming_event, main_stream.stream(0),
+                      main_stream.gpu_index(0));
+
+    for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_stream_wait_event(_sub_streams[s].stream(g), _incoming_event,
+                               _sub_streams[s].gpu_index(g));
+      }
+    }
+  }
+
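+  // Same as sync_from, but only the listed sub-stream indices wait on the
+  // main stream.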
+  void
+  sync_specific_streams_from(const CudaStreams &main_stream,
+                             std::initializer_list<uint32_t> stream_indices) {
+    cuda_event_record(_incoming_event, main_stream.stream(0),
+                      main_stream.gpu_index(0));
+
+    for (uint32_t s_idx : stream_indices) {
+      PANIC_IF_FALSE(s_idx < _num_streams_per_gpu,
+                     "MultiStreamMultiGpu: stream index out of bounds");
+
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_stream_wait_event(_sub_streams[s_idx].stream(g), _incoming_event,
+                               _sub_streams[s_idx].gpu_index(g));
+      }
+    }
+  }
+
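+  // Join: makes the first stream of main_stream wait for the work enqueued on
+  // every sub-stream of every GPU.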
+  void sync_to(const CudaStreams &main_stream) {
+    for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_event_record(_outgoing_events[s * _num_gpus + g],
+                          _sub_streams[s].stream(g),
+                          _sub_streams[s].gpu_index(g));
+      }
+    }
+
+    for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_stream_wait_event(main_stream.stream(0),
+                               _outgoing_events[s * _num_gpus + g],
+                               main_stream.gpu_index(0));
+      }
+    }
+  }
+
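+  // Same as sync_to, but the main stream only waits on the listed sub-stream
+  // indices.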
+  void
+  sync_specific_streams_to(const CudaStreams &main_stream,
+                           std::initializer_list<uint32_t> stream_indices) {
+    for (uint32_t s_idx : stream_indices) {
+      PANIC_IF_FALSE(s_idx < _num_streams_per_gpu,
+                     "MultiStreamMultiGpu: stream index out of bounds");
+
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_event_record(_outgoing_events[s_idx * _num_gpus + g],
+                          _sub_streams[s_idx].stream(g),
+                          _sub_streams[s_idx].gpu_index(g));
+      }
+    }
+
+    for (uint32_t s_idx : stream_indices) {
+      for (uint32_t g = 0; g < _num_gpus; ++g) {
+        cuda_stream_wait_event(main_stream.stream(0),
+                               _outgoing_events[s_idx * _num_gpus + g],
+                               main_stream.gpu_index(0));
+      }
+    }
+  }
+
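+  // Destroys the events and releases the sub-streams. Must be called before
+  // the object is destroyed.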
+  void release() {
+    if (_outgoing_events && _sub_streams) {
+      for (uint32_t s = 0; s < _num_streams_per_gpu; ++s) {
+        for (uint32_t g = 0; g < _num_gpus; ++g) {
+          cuda_event_destroy(_outgoing_events[s * _num_gpus + g],
+                             _sub_streams[s].gpu_index(g));
+        }
+      }
+      delete[] _outgoing_events;
+      _outgoing_events = nullptr;
+    }
+
+    if (_incoming_event && _sub_streams) {
+      cuda_event_destroy(_incoming_event, _sub_streams[0].gpu_index(0));
+      _incoming_event = nullptr;
+    }
+
+    if (_sub_streams) {
+      for (uint32_t i = 0; i < _num_streams_per_gpu; ++i) {
+        _sub_streams[i].release();
+      }
+      delete[] _sub_streams;
+      _sub_streams = nullptr;
+    }
+  }
+
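+  // Resources are freed explicitly through release(); the destructor only
+  // checks that this happened.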
+  ~MultiStreamMultiGpu() {
+    PANIC_IF_FALSE(_sub_streams == nullptr,
+                   "MultiStreamMultiGpu: must call release before destruction");
+  }
+};
+
 struct CudaStreamsBarrier {
 private:
   std::vector<cudaEvent_t> _events;
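
A minimal usage sketch of the new struct, for illustration only: the `example` function, the stream count, and the loop body are placeholders and are not part of this change.

// Illustrative sketch (not part of the diff): fork work from a main
// CudaStreams onto per-GPU sub-streams and join it back.
void example(const CudaStreams &main_streams) {
  MultiStreamMultiGpu workers(main_streams, 4); // 4 sub-streams per GPU

  workers.sync_from(main_streams); // sub-streams wait on main-stream work
  for (uint32_t s = 0; s < workers.num_streams(); ++s) {
    // enqueue independent work on workers[s].stream(g) for each GPU g
  }
  workers.sync_to(main_streams); // main stream waits on all sub-stream work

  workers.release(); // must be called before destruction
}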