Skip to content

Commit c8845b4

Browse files
fix: worker shuts down gracefully after finishing in-flight requests (#4838)
Signed-off-by: hongkuanz <[email protected]> Co-authored-by: Biswa Panda <[email protected]>
1 parent 69817c2 commit c8845b4

File tree

5 files changed

+97
-16
lines changed

5 files changed

+97
-16
lines changed

lib/llm/src/mocker/engine.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ impl MockVllmEngine {
6060
}
6161

6262
pub async fn start(&self, component: Component) -> Result<()> {
63-
let cancel_token = component.drt().runtime().child_token();
63+
// Use primary_token() instead of child_token() so the mocker continues running
64+
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
65+
// child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1.
66+
// primary_token() is only cancelled in Phase 3, after waiting for inflight requests.
67+
let cancel_token = component.drt().primary_token();
6468

6569
// Simulate engine startup time if configured
6670
if let Some(startup_time_secs) = self.engine_args.startup_time {
@@ -143,6 +147,11 @@ impl MockVllmEngine {
143147
}
144148
}
145149
_ = cancel_token_cloned.cancelled() => {
150+
tracing::info!("Scheduler output task cancelled, clearing active requests");
151+
// Clear all active requests to unblock waiting request handlers
152+
// This will cause their request_rx.recv() to return None
153+
let mut active = active_requests_clone.lock().await;
154+
active.clear();
146155
break;
147156
}
148157
}

lib/runtime/src/pipeline/network/ingress/http_endpoint.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,27 @@ impl SharedHttpServer {
105105
.system_health
106106
.lock()
107107
.set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
108-
tracing::debug!("Unregistered endpoint handler for subject: {}", subject);
108+
tracing::debug!(
109+
endpoint_name = %endpoint_name,
110+
subject = %subject,
111+
"Unregistered HTTP endpoint handler"
112+
);
113+
114+
let inflight_count = handler.inflight.load(Ordering::SeqCst);
115+
if inflight_count > 0 {
116+
tracing::info!(
117+
endpoint_name = %endpoint_name,
118+
inflight_count = inflight_count,
119+
"Waiting for inflight HTTP requests to complete"
120+
);
121+
while handler.inflight.load(Ordering::SeqCst) > 0 {
122+
handler.notify.notified().await;
123+
}
124+
tracing::info!(
125+
endpoint_name = %endpoint_name,
126+
"All inflight HTTP requests completed"
127+
);
128+
}
109129
}
110130
}
111131

lib/runtime/src/pipeline/network/ingress/nats_server.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer {
3232

3333
struct EndpointTask {
3434
cancel_token: CancellationToken,
35+
join_handle: tokio::task::JoinHandle<()>,
3536
_endpoint_name: String,
3637
}
3738

@@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
145146
// Spawn task to handle this endpoint using PushEndpoint
146147
// Note: PushEndpoint::start() is a blocking loop that runs until cancelled
147148
let endpoint_name_clone = endpoint_name.clone();
148-
tokio::spawn(async move {
149+
let join_handle = tokio::spawn(async move {
149150
if let Err(e) = push_endpoint
150151
.start(
151152
service_endpoint,
@@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
180181
endpoint_name.clone(),
181182
EndpointTask {
182183
cancel_token: endpoint_cancel,
184+
join_handle,
183185
_endpoint_name: endpoint_name,
184186
},
185187
);
@@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
193195
endpoint_name = %endpoint_name,
194196
"Unregistering NATS endpoint"
195197
);
198+
// Cancel the token to trigger graceful shutdown
196199
task.cancel_token.cancel();
200+
201+
// Wait for the endpoint task to complete (which includes waiting for inflight requests)
202+
tracing::debug!(
203+
endpoint_name = %endpoint_name,
204+
"Waiting for NATS endpoint task to complete"
205+
);
206+
if let Err(e) = task.join_handle.await {
207+
tracing::warn!(
208+
endpoint_name = %endpoint_name,
209+
error = %e,
210+
"NATS endpoint task panicked during shutdown"
211+
);
212+
}
213+
tracing::info!(
214+
endpoint_name = %endpoint_name,
215+
"NATS endpoint unregistration complete"
216+
);
197217
}
198218
Ok(())
199219
}

lib/runtime/src/pipeline/network/ingress/push_endpoint.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,26 @@ impl PushEndpoint {
135135

136136
// wait for all inflight requests to complete if graceful shutdown is enabled
137137
if self.graceful_shutdown {
138-
tracing::info!(
139-
"Waiting for {} inflight requests to complete",
140-
inflight.load(Ordering::SeqCst)
141-
);
142-
while inflight.load(Ordering::SeqCst) > 0 {
143-
notify.notified().await;
138+
let inflight_count = inflight.load(Ordering::SeqCst);
139+
if inflight_count > 0 {
140+
tracing::info!(
141+
endpoint_name = endpoint_name_local.as_str(),
142+
inflight_count = inflight_count,
143+
"Waiting for inflight NATS requests to complete"
144+
);
145+
while inflight.load(Ordering::SeqCst) > 0 {
146+
notify.notified().await;
147+
}
148+
tracing::info!(
149+
endpoint_name = endpoint_name_local.as_str(),
150+
"All inflight NATS requests completed"
151+
);
144152
}
145-
tracing::info!("All inflight requests completed");
146153
} else {
147-
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
154+
tracing::info!(
155+
endpoint_name = endpoint_name_local.as_str(),
156+
"Skipping graceful shutdown, not waiting for inflight requests"
157+
);
148158
}
149159

150160
Ok(())

lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,33 @@ impl SharedTcpServer {
100100
}
101101

102102
pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) {
103-
self.handlers.remove(endpoint_path);
104-
tracing::info!(
105-
"Unregistered endpoint '{}' from shared TCP server",
106-
endpoint_name
107-
);
103+
if let Some((_, handler)) = self.handlers.remove(endpoint_path) {
104+
handler
105+
.system_health
106+
.lock()
107+
.set_endpoint_health_status(endpoint_name, crate::HealthStatus::NotReady);
108+
tracing::info!(
109+
endpoint_name = %endpoint_name,
110+
endpoint_path = %endpoint_path,
111+
"Unregistered TCP endpoint handler"
112+
);
113+
114+
let inflight_count = handler.inflight.load(Ordering::SeqCst);
115+
if inflight_count > 0 {
116+
tracing::info!(
117+
endpoint_name = %endpoint_name,
118+
inflight_count = inflight_count,
119+
"Waiting for inflight TCP requests to complete"
120+
);
121+
while handler.inflight.load(Ordering::SeqCst) > 0 {
122+
handler.notify.notified().await;
123+
}
124+
tracing::info!(
125+
endpoint_name = %endpoint_name,
126+
"All inflight TCP requests completed"
127+
);
128+
}
129+
}
108130
}
109131

110132
pub async fn start(self: Arc<Self>) -> Result<()> {

0 commit comments

Comments
 (0)