Skip to content

Commit 63844b3

Browse files
committed
test: graceful shutdown for tcp req plane
1 parent 5250303 commit 63844b3

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed

lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,216 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
391391
true
392392
}
393393
}
394+
395+
#[cfg(test)]
396+
mod tests {
397+
use super::*;
398+
use crate::pipeline::error::PipelineError;
399+
use async_trait::async_trait;
400+
use std::sync::atomic::{AtomicBool, Ordering};
401+
use std::time::Duration;
402+
use tokio::time::Instant;
403+
404+
/// Mock handler that simulates slow request processing for testing
405+
struct SlowMockHandler {
406+
/// Tracks if a request is currently being processed
407+
request_in_flight: Arc<AtomicBool>,
408+
/// Notifies when request processing starts
409+
request_started: Arc<Notify>,
410+
/// Notifies when request processing completes
411+
request_completed: Arc<Notify>,
412+
/// Duration to simulate request processing
413+
processing_duration: Duration,
414+
}
415+
416+
impl SlowMockHandler {
417+
fn new(processing_duration: Duration) -> Self {
418+
Self {
419+
request_in_flight: Arc::new(AtomicBool::new(false)),
420+
request_started: Arc::new(Notify::new()),
421+
request_completed: Arc::new(Notify::new()),
422+
processing_duration,
423+
}
424+
}
425+
}
426+
427+
#[async_trait]
428+
impl PushWorkHandler for SlowMockHandler {
429+
async fn handle_payload(&self, _payload: Bytes) -> Result<(), PipelineError> {
430+
self.request_in_flight.store(true, Ordering::SeqCst);
431+
self.request_started.notify_one();
432+
433+
tracing::debug!(
434+
"SlowMockHandler: Request started, sleeping for {:?}",
435+
self.processing_duration
436+
);
437+
438+
// Simulate slow request processing
439+
tokio::time::sleep(self.processing_duration).await;
440+
441+
tracing::debug!("SlowMockHandler: Request completed");
442+
443+
self.request_in_flight.store(false, Ordering::SeqCst);
444+
self.request_completed.notify_one();
445+
Ok(())
446+
}
447+
448+
fn add_metrics(
449+
&self,
450+
_endpoint: &crate::component::Endpoint,
451+
_metrics_labels: Option<&[(&str, &str)]>,
452+
) -> Result<()> {
453+
Ok(())
454+
}
455+
}
456+
457+
#[tokio::test]
458+
async fn test_graceful_shutdown_waits_for_inflight_tcp_requests() {
459+
// Initialize tracing for test debugging
460+
let _ = tracing_subscriber::fmt()
461+
.with_test_writer()
462+
.with_max_level(tracing::Level::DEBUG)
463+
.try_init();
464+
465+
let cancellation_token = CancellationToken::new();
466+
let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
467+
468+
// Create SharedTcpServer
469+
let server = SharedTcpServer::new(bind_addr, cancellation_token.clone());
470+
471+
// Create a handler that takes 1s to process requests
472+
let handler = Arc::new(SlowMockHandler::new(Duration::from_secs(1)));
473+
let request_started = handler.request_started.clone();
474+
let request_completed = handler.request_completed.clone();
475+
let request_in_flight = handler.request_in_flight.clone();
476+
477+
// Register endpoint
478+
let endpoint_path = "test_endpoint".to_string();
479+
let system_health = Arc::new(Mutex::new(SystemHealth::new(
480+
crate::HealthStatus::Ready,
481+
vec![],
482+
"/health".to_string(),
483+
"/live".to_string(),
484+
)));
485+
486+
server
487+
.register_endpoint(
488+
endpoint_path.clone(),
489+
handler.clone() as Arc<dyn PushWorkHandler>,
490+
1,
491+
"test_namespace".to_string(),
492+
"test_component".to_string(),
493+
"test_endpoint".to_string(),
494+
system_health,
495+
)
496+
.await
497+
.expect("Failed to register endpoint");
498+
499+
tracing::debug!("Endpoint registered");
500+
501+
// Get the endpoint handler to simulate request processing
502+
let endpoint_handler = server
503+
.handlers
504+
.get(&endpoint_path)
505+
.expect("Handler should be registered")
506+
.clone();
507+
508+
// Spawn a task that simulates an inflight request
509+
let request_task = tokio::spawn({
510+
let handler = handler.clone();
511+
async move {
512+
let payload = Bytes::from("test payload");
513+
handler.handle_payload(payload).await
514+
}
515+
});
516+
517+
// Increment inflight counter manually to simulate the request being tracked
518+
endpoint_handler.inflight.fetch_add(1, Ordering::SeqCst);
519+
520+
// Wait for request to start processing
521+
tokio::select! {
522+
_ = request_started.notified() => {
523+
tracing::debug!("Request processing started");
524+
}
525+
_ = tokio::time::sleep(Duration::from_secs(2)) => {
526+
panic!("Timeout waiting for request to start");
527+
}
528+
}
529+
530+
// Verify request is in flight
531+
assert!(
532+
request_in_flight.load(Ordering::SeqCst),
533+
"Request should be in flight"
534+
);
535+
536+
// Now unregister the endpoint while request is inflight
537+
let unregister_start = Instant::now();
538+
tracing::debug!("Starting unregister_endpoint with inflight request");
539+
540+
// Spawn unregister in a separate task so we can monitor its behavior
541+
let unregister_task = tokio::spawn({
542+
let server = server.clone();
543+
let endpoint_path = endpoint_path.clone();
544+
async move {
545+
server.unregister_endpoint(&endpoint_path, "test_endpoint").await;
546+
Instant::now()
547+
}
548+
});
549+
550+
// Give unregister a moment to remove handler and start waiting
551+
tokio::time::sleep(Duration::from_millis(50)).await;
552+
553+
// Verify that unregister_endpoint hasn't returned yet (it should be waiting)
554+
assert!(
555+
!unregister_task.is_finished(),
556+
"unregister_endpoint should still be waiting for inflight request"
557+
);
558+
559+
tracing::debug!("Verified unregister is waiting, now waiting for request to complete");
560+
561+
// Wait for the request to complete
562+
tokio::select! {
563+
_ = request_completed.notified() => {
564+
tracing::debug!("Request completed");
565+
}
566+
_ = tokio::time::sleep(Duration::from_secs(2)) => {
567+
panic!("Timeout waiting for request to complete");
568+
}
569+
}
570+
571+
// Decrement inflight counter and notify (simulating what the real code does)
572+
endpoint_handler.inflight.fetch_sub(1, Ordering::SeqCst);
573+
endpoint_handler.notify.notify_one();
574+
575+
// Now wait for unregister to complete
576+
let unregister_end = tokio::time::timeout(Duration::from_secs(2), unregister_task)
577+
.await
578+
.expect("unregister_endpoint should complete after inflight request finishes")
579+
.expect("unregister task should not panic");
580+
581+
let unregister_duration = unregister_end - unregister_start;
582+
583+
tracing::debug!("unregister_endpoint completed in {:?}", unregister_duration);
584+
585+
// Verify unregister_endpoint waited for the inflight request
586+
assert!(
587+
unregister_duration >= Duration::from_secs(1),
588+
"unregister_endpoint should have waited ~1s for inflight request, but only took {:?}",
589+
unregister_duration
590+
);
591+
592+
// Verify request completed successfully
593+
assert!(
594+
!request_in_flight.load(Ordering::SeqCst),
595+
"Request should have completed"
596+
);
597+
598+
// Wait for request task to finish
599+
request_task
600+
.await
601+
.expect("Request task should complete")
602+
.expect("Request should succeed");
603+
604+
tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request");
605+
}
606+
}

0 commit comments

Comments
 (0)