@@ -391,3 +391,216 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
391391 true
392392 }
393393}
394+
395+ #[ cfg( test) ]
396+ mod tests {
397+ use super :: * ;
398+ use crate :: pipeline:: error:: PipelineError ;
399+ use async_trait:: async_trait;
400+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
401+ use std:: time:: Duration ;
402+ use tokio:: time:: Instant ;
403+
404+ /// Mock handler that simulates slow request processing for testing
405+ struct SlowMockHandler {
406+ /// Tracks if a request is currently being processed
407+ request_in_flight : Arc < AtomicBool > ,
408+ /// Notifies when request processing starts
409+ request_started : Arc < Notify > ,
410+ /// Notifies when request processing completes
411+ request_completed : Arc < Notify > ,
412+ /// Duration to simulate request processing
413+ processing_duration : Duration ,
414+ }
415+
416+ impl SlowMockHandler {
417+ fn new ( processing_duration : Duration ) -> Self {
418+ Self {
419+ request_in_flight : Arc :: new ( AtomicBool :: new ( false ) ) ,
420+ request_started : Arc :: new ( Notify :: new ( ) ) ,
421+ request_completed : Arc :: new ( Notify :: new ( ) ) ,
422+ processing_duration,
423+ }
424+ }
425+ }
426+
427+ #[ async_trait]
428+ impl PushWorkHandler for SlowMockHandler {
429+ async fn handle_payload ( & self , _payload : Bytes ) -> Result < ( ) , PipelineError > {
430+ self . request_in_flight . store ( true , Ordering :: SeqCst ) ;
431+ self . request_started . notify_one ( ) ;
432+
433+ tracing:: debug!(
434+ "SlowMockHandler: Request started, sleeping for {:?}" ,
435+ self . processing_duration
436+ ) ;
437+
438+ // Simulate slow request processing
439+ tokio:: time:: sleep ( self . processing_duration ) . await ;
440+
441+ tracing:: debug!( "SlowMockHandler: Request completed" ) ;
442+
443+ self . request_in_flight . store ( false , Ordering :: SeqCst ) ;
444+ self . request_completed . notify_one ( ) ;
445+ Ok ( ( ) )
446+ }
447+
448+ fn add_metrics (
449+ & self ,
450+ _endpoint : & crate :: component:: Endpoint ,
451+ _metrics_labels : Option < & [ ( & str , & str ) ] > ,
452+ ) -> Result < ( ) > {
453+ Ok ( ( ) )
454+ }
455+ }
456+
457+ #[ tokio:: test]
458+ async fn test_graceful_shutdown_waits_for_inflight_tcp_requests ( ) {
459+ // Initialize tracing for test debugging
460+ let _ = tracing_subscriber:: fmt ( )
461+ . with_test_writer ( )
462+ . with_max_level ( tracing:: Level :: DEBUG )
463+ . try_init ( ) ;
464+
465+ let cancellation_token = CancellationToken :: new ( ) ;
466+ let bind_addr: SocketAddr = "127.0.0.1:0" . parse ( ) . unwrap ( ) ;
467+
468+ // Create SharedTcpServer
469+ let server = SharedTcpServer :: new ( bind_addr, cancellation_token. clone ( ) ) ;
470+
471+ // Create a handler that takes 1s to process requests
472+ let handler = Arc :: new ( SlowMockHandler :: new ( Duration :: from_secs ( 1 ) ) ) ;
473+ let request_started = handler. request_started . clone ( ) ;
474+ let request_completed = handler. request_completed . clone ( ) ;
475+ let request_in_flight = handler. request_in_flight . clone ( ) ;
476+
477+ // Register endpoint
478+ let endpoint_path = "test_endpoint" . to_string ( ) ;
479+ let system_health = Arc :: new ( Mutex :: new ( SystemHealth :: new (
480+ crate :: HealthStatus :: Ready ,
481+ vec ! [ ] ,
482+ "/health" . to_string ( ) ,
483+ "/live" . to_string ( ) ,
484+ ) ) ) ;
485+
486+ server
487+ . register_endpoint (
488+ endpoint_path. clone ( ) ,
489+ handler. clone ( ) as Arc < dyn PushWorkHandler > ,
490+ 1 ,
491+ "test_namespace" . to_string ( ) ,
492+ "test_component" . to_string ( ) ,
493+ "test_endpoint" . to_string ( ) ,
494+ system_health,
495+ )
496+ . await
497+ . expect ( "Failed to register endpoint" ) ;
498+
499+ tracing:: debug!( "Endpoint registered" ) ;
500+
501+ // Get the endpoint handler to simulate request processing
502+ let endpoint_handler = server
503+ . handlers
504+ . get ( & endpoint_path)
505+ . expect ( "Handler should be registered" )
506+ . clone ( ) ;
507+
508+ // Spawn a task that simulates an inflight request
509+ let request_task = tokio:: spawn ( {
510+ let handler = handler. clone ( ) ;
511+ async move {
512+ let payload = Bytes :: from ( "test payload" ) ;
513+ handler. handle_payload ( payload) . await
514+ }
515+ } ) ;
516+
517+ // Increment inflight counter manually to simulate the request being tracked
518+ endpoint_handler. inflight . fetch_add ( 1 , Ordering :: SeqCst ) ;
519+
520+ // Wait for request to start processing
521+ tokio:: select! {
522+ _ = request_started. notified( ) => {
523+ tracing:: debug!( "Request processing started" ) ;
524+ }
525+ _ = tokio:: time:: sleep( Duration :: from_secs( 2 ) ) => {
526+ panic!( "Timeout waiting for request to start" ) ;
527+ }
528+ }
529+
530+ // Verify request is in flight
531+ assert ! (
532+ request_in_flight. load( Ordering :: SeqCst ) ,
533+ "Request should be in flight"
534+ ) ;
535+
536+ // Now unregister the endpoint while request is inflight
537+ let unregister_start = Instant :: now ( ) ;
538+ tracing:: debug!( "Starting unregister_endpoint with inflight request" ) ;
539+
540+ // Spawn unregister in a separate task so we can monitor its behavior
541+ let unregister_task = tokio:: spawn ( {
542+ let server = server. clone ( ) ;
543+ let endpoint_path = endpoint_path. clone ( ) ;
544+ async move {
545+ server. unregister_endpoint ( & endpoint_path, "test_endpoint" ) . await ;
546+ Instant :: now ( )
547+ }
548+ } ) ;
549+
550+ // Give unregister a moment to remove handler and start waiting
551+ tokio:: time:: sleep ( Duration :: from_millis ( 50 ) ) . await ;
552+
553+ // Verify that unregister_endpoint hasn't returned yet (it should be waiting)
554+ assert ! (
555+ !unregister_task. is_finished( ) ,
556+ "unregister_endpoint should still be waiting for inflight request"
557+ ) ;
558+
559+ tracing:: debug!( "Verified unregister is waiting, now waiting for request to complete" ) ;
560+
561+ // Wait for the request to complete
562+ tokio:: select! {
563+ _ = request_completed. notified( ) => {
564+ tracing:: debug!( "Request completed" ) ;
565+ }
566+ _ = tokio:: time:: sleep( Duration :: from_secs( 2 ) ) => {
567+ panic!( "Timeout waiting for request to complete" ) ;
568+ }
569+ }
570+
571+ // Decrement inflight counter and notify (simulating what the real code does)
572+ endpoint_handler. inflight . fetch_sub ( 1 , Ordering :: SeqCst ) ;
573+ endpoint_handler. notify . notify_one ( ) ;
574+
575+ // Now wait for unregister to complete
576+ let unregister_end = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , unregister_task)
577+ . await
578+ . expect ( "unregister_endpoint should complete after inflight request finishes" )
579+ . expect ( "unregister task should not panic" ) ;
580+
581+ let unregister_duration = unregister_end - unregister_start;
582+
583+ tracing:: debug!( "unregister_endpoint completed in {:?}" , unregister_duration) ;
584+
585+ // Verify unregister_endpoint waited for the inflight request
586+ assert ! (
587+ unregister_duration >= Duration :: from_secs( 1 ) ,
588+ "unregister_endpoint should have waited ~1s for inflight request, but only took {:?}" ,
589+ unregister_duration
590+ ) ;
591+
592+ // Verify request completed successfully
593+ assert ! (
594+ !request_in_flight. load( Ordering :: SeqCst ) ,
595+ "Request should have completed"
596+ ) ;
597+
598+ // Wait for request task to finish
599+ request_task
600+ . await
601+ . expect ( "Request task should complete" )
602+ . expect ( "Request should succeed" ) ;
603+
604+ tracing:: info!( "Test passed: unregister_endpoint properly waited for inflight TCP request" ) ;
605+ }
606+ }
0 commit comments