@@ -280,27 +280,11 @@ impl Scheduler {
 
             loop {
                 // 1. Receive requests
-                if state.is_empty() {
-                    // Fully idle - block until new request arrives
-                    tokio::select! {
-                        biased;
-                        Some(request) = request_rx.recv() => {
-                            state.receive(request);
-                        }
-                        _ = cancel_token_clone.cancelled() => {
-                            break;
-                        }
-                    }
-                } else {
-                    // Has active/waiting work - collect any pending requests without blocking
-                    while let Ok(request) = request_rx.try_recv() {
-                        state.receive(request);
-                    }
-
-                    // Check for cancellation
-                    if cancel_token_clone.is_cancelled() {
-                        break;
-                    }
+                if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
+                    .await
+                    .is_none()
+                {
+                    break;
                 }
 
                 // Start timing for this forward pass (schedule + simulate)
@@ -310,106 +294,30 @@ impl Scheduler {
                 try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
 
                 // 3. Simulate prefill + decode
-                let mut total_time = Duration::ZERO;
-
-                // Process prefilling
-                while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
-                    state.try_prefill(&args.perf_model)
-                {
-                    // NOTE: Prefill cost/time is always incremented for new blocks, even if they
-                    // could be cached by other requests in the same batch. This matches vLLM behavior.
-                    // For decode workers, skip adding prefill compute time
-                    if args.worker_type != WorkerType::Decode {
-                        total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
-                    }
-
-                    if let Some(creation_signal) = maybe_creation_signal
-                        && !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
-                    {
-                        panic!("Block allocation for prefilling cannot fail.");
-                    }
-
-                    // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
-                    if !is_full_prefill {
-                        break;
-                    }
-                }
-
-                // Compute decode timing
-                let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
-                // Compute average context length across all active decode requests
-                let (total_length, count) = state
-                    .decode
-                    .keys()
-                    .filter_map(|uuid| state.requests.get(uuid))
-                    .fold((0usize, 0usize), |(sum, cnt), req| {
-                        if let Request::Active(seq) = req {
-                            (sum + seq.len(), cnt + 1)
-                        } else {
-                            (sum, cnt)
-                        }
-                    });
-                let context_length = if count > 0 { total_length / count } else { 0 };
-                let decoding_time = args
-                    .perf_model
-                    .predict_decode_time(active_kv_tokens, context_length);
-                total_time += Duration::from_secs_f64(decoding_time / 1000.0);
-
-                state.reset_active_tokens();
-
-                // Process decoding
-                let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
-                for uuid in uuids {
-                    let Some(sequence) = state.run(uuid) else {
-                        continue;
-                    };
-                    let signals = sequence.generate();
-
-                    // Process all signals with the KvManager
-                    // Handling of preemption on failure
-                    if !process_signals(&mut kv_manager, &signals) {
-                        sequence.pop(); // revert the failed generation op
-                        for signal in state.preempt() {
-                            kv_manager.process(&signal);
-                        }
-                        continue;
-                    }
-
-                    // Check completion and send notification
-                    let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
-                    let should_output =
-                        sequence.generated_tokens() > sequence.already_generated_tokens();
-
-                    let mut send_failed = false;
-                    if should_output {
-                        send_failed = output_tx.as_ref().is_some_and(|tx| {
-                            tx.send(OutputSignal {
-                                uuid,
-                                completed: is_complete,
-                            })
-                            .is_err()
-                        });
-                    }
-
-                    if send_failed {
-                        for signal in &sequence.free_signal() {
-                            kv_manager.process(signal);
-                        }
-                    }
-
-                    if send_failed || is_complete {
-                        state.complete(&uuid);
-                        continue;
-                    }
-                }
-
-                // Send metrics once per forward pass (after all prefill and decode processing)
-                {
-                    let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank);
-                    let _ = metrics_tx.send(metrics);
-                }
-
-                // 4. Sleep to maintain target iteration timing
+                let prefill_time = simulate_prefill(
+                    &mut state,
+                    &mut kv_manager,
+                    &args.perf_model,
+                    args.worker_type,
+                );
+                let decode_time = simulate_decode(
+                    &mut state,
+                    &mut kv_manager,
+                    &output_tx,
+                    &args.perf_model,
+                    args.block_size,
+                );
+                let total_time = prefill_time + decode_time;
+
+                // 4. Send metrics once per forward pass (after all prefill and decode processing)
+                let _ = metrics_tx.send(get_fwd_pass_metrics(
+                    &state,
+                    &kv_manager,
+                    &hit_rates,
+                    dp_rank,
+                ));
+
+                // 5. Sleep to maintain target iteration timing
                 let target_duration =
                     Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
                 let elapsed = iteration_start.elapsed();
@@ -441,6 +349,148 @@ impl Scheduler {
     }
 }
 
+/// Receive requests from the channel.
+/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
+async fn receive_requests(
+    state: &mut SchedulerState,
+    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
+    cancel_token: &CancellationToken,
+) -> Option<()> {
+    if cancel_token.is_cancelled() {
+        return None;
+    }
+
+    if state.is_empty() {
+        // Fully idle - block until new request arrives
+        tokio::select! {
+            biased;
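+            // `biased` polls the branches in the order written, so the
+            // cancellation branch is always checked before a new request.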
+            _ = cancel_token.cancelled() => {
+                return None;
+            }
+            Some(request) = request_rx.recv() => {
+                state.receive(request);
+                return Some(());
+            }
+        }
+    }
+
+    // Has active/waiting work - collect any pending requests without blocking
+    while let Ok(request) = request_rx.try_recv() {
+        state.receive(request);
+    }
+
+    Some(())
+}
+
+/// Simulate prefill phase for all pending prefill requests.
+/// Returns the total prefill compute time.
+fn simulate_prefill(
+    state: &mut SchedulerState,
+    kv_manager: &mut KvManager,
+    perf_model: &PerfModel,
+    worker_type: WorkerType,
+) -> Duration {
+    let mut total_time = Duration::ZERO;
+
+    while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
+        state.try_prefill(perf_model)
+    {
+        // NOTE: Prefill cost/time is always incremented for new blocks, even if they
+        // could be cached by other requests in the same batch. This matches vLLM behavior.
+        // For decode workers, skip adding prefill compute time
+        if worker_type != WorkerType::Decode {
+            total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
+        }
+
+        if let Some(creation_signal) = maybe_creation_signal
+            && !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
+        {
+            panic!("Block allocation for prefilling cannot fail.");
+        }
+
+        // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
+        if !is_full_prefill {
+            break;
+        }
+    }
+
+    total_time
+}
+
+/// Simulate decode phase for all active decode requests.
+/// Returns the total decode compute time.
+fn simulate_decode(
+    state: &mut SchedulerState,
+    kv_manager: &mut KvManager,
+    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
+    perf_model: &PerfModel,
+    block_size: usize,
+) -> Duration {
+    // Compute decode timing
+    let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
+    // Compute average context length across all active decode requests
+    let (total_length, count) = state
+        .decode
+        .keys()
+        .filter_map(|uuid| state.requests.get(uuid))
+        .fold((0usize, 0usize), |(sum, cnt), req| {
+            if let Request::Active(seq) = req {
+                (sum + seq.len(), cnt + 1)
+            } else {
+                (sum, cnt)
+            }
+        });
+    let context_length = if count > 0 { total_length / count } else { 0 };
+    let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
+    let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
+
+    state.reset_active_tokens();
+
+    // Process decoding
+    let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
+    for uuid in uuids {
+        let Some(sequence) = state.run(uuid) else {
+            continue;
+        };
+        let signals = sequence.generate();
+
+        // Process all signals with the KvManager
+        // Handling of preemption on failure
+        if !process_signals(kv_manager, &signals) {
+            sequence.pop(); // revert the failed generation op
+            for signal in state.preempt() {
+                kv_manager.process(&signal);
+            }
+            continue;
+        }
+
+        // Check completion and send notification
+        let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
+        let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
+
+        let send_failed = should_output
+            && output_tx.as_ref().is_some_and(|tx| {
+                tx.send(OutputSignal {
+                    uuid,
+                    completed: is_complete,
+                })
+                .is_err()
+            });
+
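+        // A failed send means the output receiver was dropped; release the
+        // sequence's KV blocks before retiring the request below.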
+        if send_failed {
+            for signal in &sequence.free_signal() {
+                kv_manager.process(signal);
+            }
+        }
+
+        if send_failed || is_complete {
+            state.complete(&uuid);
+        }
+    }
+
+    total_time
+}
+
 /// Calculate forward pass metrics from current state
 fn get_fwd_pass_metrics(
     state: &SchedulerState,