diff --git a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs index 7ed5ea6e3d..2b9d880fc2 100644 --- a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs +++ b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs @@ -391,3 +391,218 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer { true } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline::error::PipelineError; + use async_trait::async_trait; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::time::Duration; + use tokio::time::Instant; + + /// Mock handler that simulates slow request processing for testing + struct SlowMockHandler { + /// Tracks if a request is currently being processed + request_in_flight: Arc, + /// Notifies when request processing starts + request_started: Arc, + /// Notifies when request processing completes + request_completed: Arc, + /// Duration to simulate request processing + processing_duration: Duration, + } + + impl SlowMockHandler { + fn new(processing_duration: Duration) -> Self { + Self { + request_in_flight: Arc::new(AtomicBool::new(false)), + request_started: Arc::new(Notify::new()), + request_completed: Arc::new(Notify::new()), + processing_duration, + } + } + } + + #[async_trait] + impl PushWorkHandler for SlowMockHandler { + async fn handle_payload(&self, _payload: Bytes) -> Result<(), PipelineError> { + self.request_in_flight.store(true, Ordering::SeqCst); + self.request_started.notify_one(); + + tracing::debug!( + "SlowMockHandler: Request started, sleeping for {:?}", + self.processing_duration + ); + + // Simulate slow request processing + tokio::time::sleep(self.processing_duration).await; + + tracing::debug!("SlowMockHandler: Request completed"); + + self.request_in_flight.store(false, Ordering::SeqCst); + self.request_completed.notify_one(); + Ok(()) + } + + fn add_metrics( + &self, + _endpoint: &crate::component::Endpoint, + _metrics_labels: Option<&[(&str, &str)]>, + ) -> Result<()> { + Ok(()) + } + } + + #[tokio::test] + async fn test_graceful_shutdown_waits_for_inflight_tcp_requests() { + // Initialize tracing for test debugging + let _ = tracing_subscriber::fmt() + .with_test_writer() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + let cancellation_token = CancellationToken::new(); + let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + + // Create SharedTcpServer + let server = SharedTcpServer::new(bind_addr, cancellation_token.clone()); + + // Create a handler that takes 1s to process requests + let handler = Arc::new(SlowMockHandler::new(Duration::from_secs(1))); + let request_started = handler.request_started.clone(); + let request_completed = handler.request_completed.clone(); + let request_in_flight = handler.request_in_flight.clone(); + + // Register endpoint + let endpoint_path = "test_endpoint".to_string(); + let system_health = Arc::new(Mutex::new(SystemHealth::new( + crate::HealthStatus::Ready, + vec![], + "/health".to_string(), + "/live".to_string(), + ))); + + server + .register_endpoint( + endpoint_path.clone(), + handler.clone() as Arc, + 1, + "test_namespace".to_string(), + "test_component".to_string(), + "test_endpoint".to_string(), + system_health, + ) + .await + .expect("Failed to register endpoint"); + + tracing::debug!("Endpoint registered"); + + // Get the endpoint handler to simulate request processing + let endpoint_handler = server + .handlers + .get(&endpoint_path) + .expect("Handler should be registered") + .clone(); + + // Spawn a task that simulates an inflight request + let request_task = tokio::spawn({ + let handler = handler.clone(); + async move { + let payload = Bytes::from("test payload"); + handler.handle_payload(payload).await + } + }); + + // Increment inflight counter manually to simulate the request being tracked + endpoint_handler.inflight.fetch_add(1, Ordering::SeqCst); + + // Wait for request to start processing + tokio::select! { + _ = request_started.notified() => { + tracing::debug!("Request processing started"); + } + _ = tokio::time::sleep(Duration::from_secs(2)) => { + panic!("Timeout waiting for request to start"); + } + } + + // Verify request is in flight + assert!( + request_in_flight.load(Ordering::SeqCst), + "Request should be in flight" + ); + + // Now unregister the endpoint while request is inflight + let unregister_start = Instant::now(); + tracing::debug!("Starting unregister_endpoint with inflight request"); + + // Spawn unregister in a separate task so we can monitor its behavior + let unregister_task = tokio::spawn({ + let server = server.clone(); + let endpoint_path = endpoint_path.clone(); + async move { + server + .unregister_endpoint(&endpoint_path, "test_endpoint") + .await; + Instant::now() + } + }); + + // Give unregister a moment to remove handler and start waiting + tokio::time::sleep(Duration::from_millis(50)).await; + + // Verify that unregister_endpoint hasn't returned yet (it should be waiting) + assert!( + !unregister_task.is_finished(), + "unregister_endpoint should still be waiting for inflight request" + ); + + tracing::debug!("Verified unregister is waiting, now waiting for request to complete"); + + // Wait for the request to complete + tokio::select! { + _ = request_completed.notified() => { + tracing::debug!("Request completed"); + } + _ = tokio::time::sleep(Duration::from_secs(2)) => { + panic!("Timeout waiting for request to complete"); + } + } + + // Decrement inflight counter and notify (simulating what the real code does) + endpoint_handler.inflight.fetch_sub(1, Ordering::SeqCst); + endpoint_handler.notify.notify_one(); + + // Now wait for unregister to complete + let unregister_end = tokio::time::timeout(Duration::from_secs(2), unregister_task) + .await + .expect("unregister_endpoint should complete after inflight request finishes") + .expect("unregister task should not panic"); + + let unregister_duration = unregister_end - unregister_start; + + tracing::debug!("unregister_endpoint completed in {:?}", unregister_duration); + + // Verify unregister_endpoint waited for the inflight request + assert!( + unregister_duration >= Duration::from_secs(1), + "unregister_endpoint should have waited ~1s for inflight request, but only took {:?}", + unregister_duration + ); + + // Verify request completed successfully + assert!( + !request_in_flight.load(Ordering::SeqCst), + "Request should have completed" + ); + + // Wait for request task to finish + request_task + .await + .expect("Request task should complete") + .expect("Request should succeed"); + + tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request"); + } +}