Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,218 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
true
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::error::PipelineError;
use async_trait::async_trait;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::time::Instant;

/// Test double whose request handling takes a configurable amount of time.
///
/// The shared flag and the two notifiers let a test observe, from the outside, when a simulated
/// request begins and when it ends.
struct SlowMockHandler {
    /// How long each simulated request sleeps before finishing
    processing_duration: Duration,
    /// Set to `true` when processing begins and back to `false` when it ends
    request_in_flight: Arc<AtomicBool>,
    /// Signalled once processing has begun
    request_started: Arc<Notify>,
    /// Signalled once processing has ended
    request_completed: Arc<Notify>,
}

impl SlowMockHandler {
    /// Creates an idle handler (nothing in flight, no pending notifications) whose simulated
    /// requests each take `processing_duration`.
    fn new(processing_duration: Duration) -> Self {
        let request_in_flight = Arc::new(AtomicBool::new(false));
        let request_started = Arc::new(Notify::new());
        let request_completed = Arc::new(Notify::new());

        Self {
            request_in_flight,
            request_started,
            request_completed,
            processing_duration,
        }
    }
}

#[async_trait]
impl PushWorkHandler for SlowMockHandler {
    /// Pretends to process a request: flags it as in flight, signals `request_started`, sleeps
    /// for `processing_duration`, clears the flag, and signals `request_completed`.
    ///
    /// The payload is ignored and the call always returns `Ok(())`.
    async fn handle_payload(&self, _payload: Bytes) -> Result<(), PipelineError> {
        let delay = self.processing_duration;

        // Publish the "started" state before yielding to the timer.
        self.request_in_flight.store(true, Ordering::SeqCst);
        self.request_started.notify_one();
        tracing::debug!(
            "SlowMockHandler: Request started, sleeping for {:?}",
            delay
        );

        // The sleep stands in for real (slow) request processing.
        tokio::time::sleep(delay).await;
        tracing::debug!("SlowMockHandler: Request completed");

        // Publish the "finished" state only once the simulated work is done.
        self.request_in_flight.store(false, Ordering::SeqCst);
        self.request_completed.notify_one();

        Ok(())
    }

    /// No-op: this mock records no metrics and always reports success.
    fn add_metrics(
        &self,
        _endpoint: &crate::component::Endpoint,
        _metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        Ok(())
    }
}

/// Checks that `unregister_endpoint` does not return while a request is still in flight.
///
/// The request is simulated rather than sent over a socket: the test spawns
/// `SlowMockHandler::handle_payload` directly and adjusts the endpoint's `inflight` counter by
/// hand. It then starts `unregister_endpoint` in a background task and verifies that
///
/// 1. the unregister task is still pending while the simulated request is being processed, and
/// 2. the unregister task finishes once the counter is decremented and `notify` is signalled,
///    no earlier than one processing duration after the request was spawned.
///
/// # Panics
///
/// Fails the test if any wait exceeds its 2 second timeout or one of the checks above is
/// violated.
#[tokio::test]
async fn test_graceful_shutdown_waits_for_inflight_tcp_requests() {
    // Initialize tracing for test debugging; the result is ignored so that a subscriber
    // installed by an earlier test does not fail this one.
    let _ = tracing_subscriber::fmt()
        .with_test_writer()
        .with_max_level(tracing::Level::DEBUG)
        .try_init();

    let cancellation_token = CancellationToken::new();
    let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();

    // Create SharedTcpServer
    let server = SharedTcpServer::new(bind_addr, cancellation_token.clone());

    // Single source of truth for how long the simulated request takes; it is reused by the
    // timing assertion at the end so the two can never drift apart.
    let processing_duration = Duration::from_secs(1);

    // Create a handler that takes `processing_duration` to process requests
    let handler = Arc::new(SlowMockHandler::new(processing_duration));
    let request_started = handler.request_started.clone();
    let request_completed = handler.request_completed.clone();
    let request_in_flight = handler.request_in_flight.clone();

    // Register endpoint
    let endpoint_path = "test_endpoint".to_string();
    let system_health = Arc::new(Mutex::new(SystemHealth::new(
        crate::HealthStatus::Ready,
        vec![],
        "/health".to_string(),
        "/live".to_string(),
    )));

    server
        .register_endpoint(
            endpoint_path.clone(),
            handler.clone() as Arc<dyn PushWorkHandler>,
            1,
            "test_namespace".to_string(),
            "test_component".to_string(),
            "test_endpoint".to_string(),
            system_health,
        )
        .await
        .expect("Failed to register endpoint");

    tracing::debug!("Endpoint registered");

    // Get the endpoint handler to simulate request processing
    let endpoint_handler = server
        .handlers
        .get(&endpoint_path)
        .expect("Handler should be registered")
        .clone();

    // Reference point for the final timing assertion. It is taken BEFORE the request is
    // spawned, so the handler's sleep deadline is always at least `processing_duration` after
    // it. Measuring from a timestamp taken after the request has started (as
    // `unregister_start` is) can legitimately yield slightly less than `processing_duration`.
    let request_spawned_at = Instant::now();

    // Spawn a task that simulates an inflight request
    let request_task = tokio::spawn({
        let handler = handler.clone();
        async move {
            let payload = Bytes::from("test payload");
            handler.handle_payload(payload).await
        }
    });

    // Increment inflight counter manually to simulate the request being tracked
    endpoint_handler.inflight.fetch_add(1, Ordering::SeqCst);

    // Wait for request to start processing
    tokio::select! {
        _ = request_started.notified() => {
            tracing::debug!("Request processing started");
        }
        _ = tokio::time::sleep(Duration::from_secs(2)) => {
            panic!("Timeout waiting for request to start");
        }
    }

    // Verify request is in flight
    assert!(
        request_in_flight.load(Ordering::SeqCst),
        "Request should be in flight"
    );

    // Now unregister the endpoint while request is inflight
    let unregister_start = Instant::now();
    tracing::debug!("Starting unregister_endpoint with inflight request");

    // Spawn unregister in a separate task so we can monitor its behavior; the task reports
    // the instant at which `unregister_endpoint` returned.
    let unregister_task = tokio::spawn({
        let server = server.clone();
        let endpoint_path = endpoint_path.clone();
        async move {
            server
                .unregister_endpoint(&endpoint_path, "test_endpoint")
                .await;
            Instant::now()
        }
    });

    // Give unregister a moment to remove handler and start waiting
    tokio::time::sleep(Duration::from_millis(50)).await;

    // Verify that unregister_endpoint hasn't returned yet (it should be waiting)
    assert!(
        !unregister_task.is_finished(),
        "unregister_endpoint should still be waiting for inflight request"
    );

    tracing::debug!("Verified unregister is waiting, now waiting for request to complete");

    // Wait for the request to complete
    tokio::select! {
        _ = request_completed.notified() => {
            tracing::debug!("Request completed");
        }
        _ = tokio::time::sleep(Duration::from_secs(2)) => {
            panic!("Timeout waiting for request to complete");
        }
    }

    // Decrement inflight counter and notify (simulating what the real code does)
    endpoint_handler.inflight.fetch_sub(1, Ordering::SeqCst);
    endpoint_handler.notify.notify_one();

    // Now wait for unregister to complete
    let unregister_end = tokio::time::timeout(Duration::from_secs(2), unregister_task)
        .await
        .expect("unregister_endpoint should complete after inflight request finishes")
        .expect("unregister task should not panic");

    let unregister_duration = unregister_end - unregister_start;
    let since_request_spawned = unregister_end - request_spawned_at;

    tracing::debug!("unregister_endpoint completed in {:?}", unregister_duration);

    // Verify unregister_endpoint waited for the inflight request: it must not have returned
    // before the simulated processing time had fully elapsed.
    assert!(
        since_request_spawned >= processing_duration,
        "unregister_endpoint should have waited ~1s for inflight request, but returned only {:?} after the request was spawned",
        since_request_spawned
    );

    // Verify request completed successfully
    assert!(
        !request_in_flight.load(Ordering::SeqCst),
        "Request should have completed"
    );

    // Wait for request task to finish
    request_task
        .await
        .expect("Request task should complete")
        .expect("Request should succeed");

    tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request");
}
}
Loading