Skip to content

Commit 377711d

Browse files
authored
chore: de-spaghetti mocker scheduler (#4789)
Signed-off-by: PeaBrane <[email protected]>
1 parent 8f0ac73 commit 377711d

File tree

1 file changed

+171
-121
lines changed

1 file changed

+171
-121
lines changed

lib/llm/src/mocker/scheduler.rs

Lines changed: 171 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -280,27 +280,11 @@ impl Scheduler {
280280

281281
loop {
282282
// 1. Receive requests
283-
if state.is_empty() {
284-
// Fully idle - block until new request arrives
285-
tokio::select! {
286-
biased;
287-
Some(request) = request_rx.recv() => {
288-
state.receive(request);
289-
}
290-
_ = cancel_token_clone.cancelled() => {
291-
break;
292-
}
293-
}
294-
} else {
295-
// Has active/waiting work - collect any pending requests without blocking
296-
while let Ok(request) = request_rx.try_recv() {
297-
state.receive(request);
298-
}
299-
300-
// Check for cancellation
301-
if cancel_token_clone.is_cancelled() {
302-
break;
303-
}
283+
if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
284+
.await
285+
.is_none()
286+
{
287+
break;
304288
}
305289

306290
// Start timing for this forward pass (schedule + simulate)
@@ -310,106 +294,30 @@ impl Scheduler {
310294
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
311295

312296
// 3. Simulate prefill + decode
313-
let mut total_time = Duration::ZERO;
314-
315-
// Process prefilling
316-
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
317-
state.try_prefill(&args.perf_model)
318-
{
319-
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
320-
// could be cached by other requests in the same batch. This matches vLLM behavior.
321-
// For decode workers, skip adding prefill compute time
322-
if args.worker_type != WorkerType::Decode {
323-
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
324-
}
325-
326-
if let Some(creation_signal) = maybe_creation_signal
327-
&& !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
328-
{
329-
panic!("Block allocation for prefilling cannot fail.");
330-
}
331-
332-
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
333-
if !is_full_prefill {
334-
break;
335-
}
336-
}
337-
338-
// Compute decode timing
339-
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
340-
// Compute average context length across all active decode requests
341-
let (total_length, count) = state
342-
.decode
343-
.keys()
344-
.filter_map(|uuid| state.requests.get(uuid))
345-
.fold((0usize, 0usize), |(sum, cnt), req| {
346-
if let Request::Active(seq) = req {
347-
(sum + seq.len(), cnt + 1)
348-
} else {
349-
(sum, cnt)
350-
}
351-
});
352-
let context_length = if count > 0 { total_length / count } else { 0 };
353-
let decoding_time = args
354-
.perf_model
355-
.predict_decode_time(active_kv_tokens, context_length);
356-
total_time += Duration::from_secs_f64(decoding_time / 1000.0);
357-
358-
state.reset_active_tokens();
359-
360-
// Process decoding
361-
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
362-
for uuid in uuids {
363-
let Some(sequence) = state.run(uuid) else {
364-
continue;
365-
};
366-
let signals = sequence.generate();
367-
368-
// Process all signals with the KvManager
369-
// Handling of preemption on failure
370-
if !process_signals(&mut kv_manager, &signals) {
371-
sequence.pop(); // revert the failed generation op
372-
for signal in state.preempt() {
373-
kv_manager.process(&signal);
374-
}
375-
continue;
376-
}
377-
378-
// Check completion and send notification
379-
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
380-
let should_output =
381-
sequence.generated_tokens() > sequence.already_generated_tokens();
382-
383-
let mut send_failed = false;
384-
if should_output {
385-
send_failed = output_tx.as_ref().is_some_and(|tx| {
386-
tx.send(OutputSignal {
387-
uuid,
388-
completed: is_complete,
389-
})
390-
.is_err()
391-
});
392-
}
393-
394-
if send_failed {
395-
for signal in &sequence.free_signal() {
396-
kv_manager.process(signal);
397-
}
398-
}
399-
400-
if send_failed || is_complete {
401-
state.complete(&uuid);
402-
continue;
403-
}
404-
}
405-
406-
// Send metrics once per forward pass (after all prefill and decode processing)
407-
{
408-
let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank);
409-
let _ = metrics_tx.send(metrics);
410-
}
411-
412-
// 4. Sleep to maintain target iteration timing
297+
let prefill_time = simulate_prefill(
298+
&mut state,
299+
&mut kv_manager,
300+
&args.perf_model,
301+
args.worker_type,
302+
);
303+
let decode_time = simulate_decode(
304+
&mut state,
305+
&mut kv_manager,
306+
&output_tx,
307+
&args.perf_model,
308+
args.block_size,
309+
);
310+
let total_time = prefill_time + decode_time;
311+
312+
// 4. Send metrics once per forward pass (after all prefill and decode processing)
313+
let _ = metrics_tx.send(get_fwd_pass_metrics(
314+
&state,
315+
&kv_manager,
316+
&hit_rates,
317+
dp_rank,
318+
));
319+
320+
// 5. Sleep to maintain target iteration timing
413321
let target_duration =
414322
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
415323
let elapsed = iteration_start.elapsed();
@@ -441,6 +349,148 @@ impl Scheduler {
441349
}
442350
}
443351

352+
/// Receive requests from the channel.
353+
/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
354+
async fn receive_requests(
355+
state: &mut SchedulerState,
356+
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
357+
cancel_token: &CancellationToken,
358+
) -> Option<()> {
359+
if cancel_token.is_cancelled() {
360+
return None;
361+
}
362+
363+
if state.is_empty() {
364+
// Fully idle - block until new request arrives
365+
tokio::select! {
366+
biased;
367+
_ = cancel_token.cancelled() => {
368+
return None;
369+
}
370+
Some(request) = request_rx.recv() => {
371+
state.receive(request);
372+
return Some(());
373+
}
374+
}
375+
}
376+
377+
// Has active/waiting work - collect any pending requests without blocking
378+
while let Ok(request) = request_rx.try_recv() {
379+
state.receive(request);
380+
}
381+
382+
Some(())
383+
}
384+
385+
/// Simulate prefill phase for all pending prefill requests.
386+
/// Returns the total prefill compute time.
387+
fn simulate_prefill(
388+
state: &mut SchedulerState,
389+
kv_manager: &mut KvManager,
390+
perf_model: &PerfModel,
391+
worker_type: WorkerType,
392+
) -> Duration {
393+
let mut total_time = Duration::ZERO;
394+
395+
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
396+
state.try_prefill(perf_model)
397+
{
398+
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
399+
// could be cached by other requests in the same batch. This matches vLLM behavior.
400+
// For decode workers, skip adding prefill compute time
401+
if worker_type != WorkerType::Decode {
402+
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
403+
}
404+
405+
if let Some(creation_signal) = maybe_creation_signal
406+
&& !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
407+
{
408+
panic!("Block allocation for prefilling cannot fail.");
409+
}
410+
411+
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
412+
if !is_full_prefill {
413+
break;
414+
}
415+
}
416+
417+
total_time
418+
}
419+
420+
/// Simulate decode phase for all active decode requests.
421+
/// Returns the total decode compute time.
422+
fn simulate_decode(
423+
state: &mut SchedulerState,
424+
kv_manager: &mut KvManager,
425+
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
426+
perf_model: &PerfModel,
427+
block_size: usize,
428+
) -> Duration {
429+
// Compute decode timing
430+
let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
431+
// Compute average context length across all active decode requests
432+
let (total_length, count) = state
433+
.decode
434+
.keys()
435+
.filter_map(|uuid| state.requests.get(uuid))
436+
.fold((0usize, 0usize), |(sum, cnt), req| {
437+
if let Request::Active(seq) = req {
438+
(sum + seq.len(), cnt + 1)
439+
} else {
440+
(sum, cnt)
441+
}
442+
});
443+
let context_length = if count > 0 { total_length / count } else { 0 };
444+
let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
445+
let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
446+
447+
state.reset_active_tokens();
448+
449+
// Process decoding
450+
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
451+
for uuid in uuids {
452+
let Some(sequence) = state.run(uuid) else {
453+
continue;
454+
};
455+
let signals = sequence.generate();
456+
457+
// Process all signals with the KvManager
458+
// Handling of preemption on failure
459+
if !process_signals(kv_manager, &signals) {
460+
sequence.pop(); // revert the failed generation op
461+
for signal in state.preempt() {
462+
kv_manager.process(&signal);
463+
}
464+
continue;
465+
}
466+
467+
// Check completion and send notification
468+
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
469+
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
470+
471+
let send_failed = should_output
472+
&& output_tx.as_ref().is_some_and(|tx| {
473+
tx.send(OutputSignal {
474+
uuid,
475+
completed: is_complete,
476+
})
477+
.is_err()
478+
});
479+
480+
if send_failed {
481+
for signal in &sequence.free_signal() {
482+
kv_manager.process(signal);
483+
}
484+
}
485+
486+
if send_failed || is_complete {
487+
state.complete(&uuid);
488+
}
489+
}
490+
491+
total_time
492+
}
493+
444494
/// Calculate forward pass metrics from current state
445495
fn get_fwd_pass_metrics(
446496
state: &SchedulerState,

0 commit comments

Comments
 (0)