Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion lib/llm/src/mocker/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ impl MockVllmEngine {
}

pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().runtime().child_token();
// Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
// child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1.
// primary_token() is only cancelled in Phase 3, after waiting for inflight requests.
let cancel_token = component.drt().primary_token();

// Simulate engine startup time if configured
if let Some(startup_time_secs) = self.engine_args.startup_time {
Expand Down Expand Up @@ -143,6 +147,11 @@ impl MockVllmEngine {
}
}
_ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
break;
}
}
Expand Down
22 changes: 21 additions & 1 deletion lib/runtime/src/pipeline/network/ingress/http_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,27 @@ impl SharedHttpServer {
.system_health
.lock()
.set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
tracing::debug!("Unregistered endpoint handler for subject: {}", subject);
tracing::debug!(
endpoint_name = %endpoint_name,
subject = %subject,
"Unregistered HTTP endpoint handler"
);

let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight HTTP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight HTTP requests completed"
);
}
}
}

Expand Down
22 changes: 21 additions & 1 deletion lib/runtime/src/pipeline/network/ingress/nats_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer {

/// Bookkeeping for one registered endpoint's spawned handler task.
///
/// An instance is stored per endpoint name when the endpoint is registered and
/// is taken back out on unregistration to stop the task and wait for it.
struct EndpointTask {
/// Token cancelled during unregistration to signal the endpoint task to stop.
cancel_token: CancellationToken,
/// Handle of the spawned endpoint task; awaited during unregistration (after the
/// token is cancelled) so the caller only returns once the task has finished.
join_handle: tokio::task::JoinHandle<()>,
/// Name the endpoint was registered under.
/// NOTE(review): the underscore prefix suggests this is retained but never read — confirm.
_endpoint_name: String,
}

Expand Down Expand Up @@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
// Spawn task to handle this endpoint using PushEndpoint
// Note: PushEndpoint::start() is a blocking loop that runs until cancelled
let endpoint_name_clone = endpoint_name.clone();
tokio::spawn(async move {
let join_handle = tokio::spawn(async move {
if let Err(e) = push_endpoint
.start(
service_endpoint,
Expand Down Expand Up @@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name.clone(),
EndpointTask {
cancel_token: endpoint_cancel,
join_handle,
_endpoint_name: endpoint_name,
},
);
Expand All @@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name = %endpoint_name,
"Unregistering NATS endpoint"
);
// Cancel the token to trigger graceful shutdown
task.cancel_token.cancel();

// Wait for the endpoint task to complete (which includes waiting for inflight requests)
tracing::debug!(
endpoint_name = %endpoint_name,
"Waiting for NATS endpoint task to complete"
);
if let Err(e) = task.join_handle.await {
tracing::warn!(
endpoint_name = %endpoint_name,
error = %e,
"NATS endpoint task panicked during shutdown"
);
}
tracing::info!(
endpoint_name = %endpoint_name,
"NATS endpoint unregistration complete"
);
}
Ok(())
}
Expand Down
26 changes: 18 additions & 8 deletions lib/runtime/src/pipeline/network/ingress/push_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,26 @@ impl PushEndpoint {

// Wait for all inflight requests to complete when graceful shutdown is enabled
if self.graceful_shutdown {
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
let inflight_count = inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
inflight_count = inflight_count,
"Waiting for inflight NATS requests to complete"
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"All inflight NATS requests completed"
);
}
tracing::info!("All inflight requests completed");
} else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"Skipping graceful shutdown, not waiting for inflight requests"
);
}

Ok(())
Expand Down
32 changes: 27 additions & 5 deletions lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,33 @@ impl SharedTcpServer {
}

pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) {
self.handlers.remove(endpoint_path);
tracing::info!(
"Unregistered endpoint '{}' from shared TCP server",
endpoint_name
);
if let Some((_, handler)) = self.handlers.remove(endpoint_path) {
handler
.system_health
.lock()
.set_endpoint_health_status(endpoint_name, crate::HealthStatus::NotReady);
tracing::info!(
endpoint_name = %endpoint_name,
endpoint_path = %endpoint_path,
"Unregistered TCP endpoint handler"
);

let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight TCP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight TCP requests completed"
);
}
}
}

pub async fn start(self: Arc<Self>) -> Result<()> {
Expand Down
Loading