@@ -36,22 +36,23 @@ use tokio::time::{Duration, Instant};
3636///
3737/// Provides AsyncRead/AsyncWrite interface over unreliable transport
3838pub struct ReliableTransport < T > {
39+ /// Underlying unreliable transport; shared with the background read task and locked by the `AsyncWrite` poll methods
40+ transport : Arc < Mutex < T > > ,
3941 /// KCP protocol state; shared with the background read task and taken with `try_lock` in `poll_write`
40- kcp : Arc < Mutex < Kcp < KcpOutput < T > > > > ,
42+ kcp : Arc < Mutex < Kcp < KcpOutput > > > ,
4143 /// Channel receiver for decoded data from KCP
4244 read_rx : tokio:: sync:: mpsc:: UnboundedReceiver < Vec < u8 > > ,
43- /// Channel sender for decoded data (used by background task)
44- _read_tx : Arc < tokio :: sync :: mpsc :: UnboundedSender < Vec < u8 > > > ,
45+ /// Queue of packets emitted by the KCP output callback; `poll_write` drains it into the transport
46+ write_buf : Arc < Mutex < VecDeque < Vec < u8 > > > > ,
4547 /// Last update time
4648 last_update : Arc < Mutex < Instant > > ,
4749 /// Buffer for partially read data
4850 partial_buf : Vec < u8 > ,
4951 partial_pos : usize , // presumably the read offset into `partial_buf` - TODO confirm against poll_read
5052}
5153
52- /// Output callback for KCP - sends packets to underlying transport
53- struct KcpOutput < T > {
54- transport : Arc < Mutex < T > > ,
54+ /// Output callback for KCP - queues each outgoing packet in `write_buf` instead of writing to the transport itself
55+ struct KcpOutput {
5556 write_buf : Arc < Mutex < VecDeque < Vec < u8 > > > > , // same queue as `ReliableTransport::write_buf` (shared via `Arc::clone`)
5657}
5758
7172 let write_buf = Arc :: new ( Mutex :: new ( VecDeque :: new ( ) ) ) ;
7273
7374 let output = KcpOutput {
74- transport : Arc :: clone ( & transport) ,
7575 write_buf : Arc :: clone ( & write_buf) ,
7676 } ;
7777
@@ -116,37 +116,6 @@ where
116116 }
117117 } ) ;
118118
119- // Start write task - send buffered KCP output to underlying transport
120- let transport_write = Arc :: clone ( & transport) ;
121- let write_buf_task = Arc :: clone ( & write_buf) ;
122- tokio:: spawn ( async move {
123- use tokio:: io:: AsyncWriteExt ;
124-
125- loop {
126- tokio:: time:: sleep ( Duration :: from_millis ( 5 ) ) . await ;
127-
128- // Check if there's data to write
129- let packet = {
130- let mut buf = write_buf_task. lock ( ) . await ;
131- buf. pop_front ( )
132- } ;
133-
134- if let Some ( data) = packet {
135- log:: debug!( "Write task: dequeued {} bytes, writing to transport" , data. len( ) ) ;
136- // Write to underlying transport
137- let mut transport = transport_write. lock ( ) . await ;
138- if let Err ( e) = transport. write_all ( & data) . await {
139- log:: error!( "Failed to write KCP output to transport: {}" , e) ;
140- } else {
141- log:: debug!( "Write task: successfully wrote to transport" ) ;
142- }
143- if let Err ( e) = transport. flush ( ) . await {
144- log:: error!( "Failed to flush transport: {}" , e) ;
145- }
146- }
147- }
148- } ) ;
149-
150119 // Start read task - feed transport data into KCP
151120 let transport_read = Arc :: clone ( & transport) ;
152121 let kcp_input = Arc :: clone ( & kcp) ;
@@ -225,9 +194,10 @@ where
225194 } ) ;
226195
227196 Ok ( Self {
197+ transport,
228198 kcp,
229199 read_rx,
230- _read_tx : read_tx ,
200+ write_buf ,
231201 last_update,
232202 partial_buf : Vec :: new ( ) ,
233203 partial_pos : 0 ,
@@ -270,10 +240,7 @@ where
270240 }
271241}
272242
273- impl < T > std:: io:: Write for KcpOutput < T >
274- where
275- T : AsyncWrite + Unpin + Send ,
276- {
243+ impl std:: io:: Write for KcpOutput {
277244 fn write ( & mut self , data : & [ u8 ] ) -> std:: io:: Result < usize > {
278245 // Queue packet to write buffer - will be written by async task
279246 log:: debug!( "KCP output callback: writing {} bytes to write_buf" , data. len( ) ) ;
@@ -358,7 +325,7 @@ where
358325{
359326 fn poll_write (
360327 self : Pin < & mut Self > ,
361- _cx : & mut Context < ' _ > ,
328+ cx : & mut Context < ' _ > ,
362329 buf : & [ u8 ] ,
363330 ) -> Poll < io:: Result < usize > > {
364331 let this = self . get_mut ( ) ;
@@ -367,10 +334,8 @@ where
367334 let mut kcp = match this. kcp . try_lock ( ) {
368335 Ok ( guard) => guard,
369336 Err ( _) => {
370- return Poll :: Ready ( Err ( io:: Error :: new (
371- io:: ErrorKind :: WouldBlock ,
372- "KCP lock busy" ,
373- ) ) ) ;
337+ cx. waker ( ) . wake_by_ref ( ) ;
338+ return Poll :: Pending ;
374339 }
375340 } ;
376341
@@ -385,44 +350,79 @@ where
385350 }
386351 } ;
387352
388- // Flush immediately to trigger output callback
353+ // Flush immediately to trigger output callback, which populates write_buf
389354 if let Err ( e) = kcp. flush ( ) {
390355 return Poll :: Ready ( Err ( io:: Error :: new (
391356 io:: ErrorKind :: Other ,
392357 format ! ( "KCP flush after send error: {:?}" , e) ,
393358 ) ) ) ;
394359 }
395360
361+ // Try to drain the write buffer and write to the transport
362+ // This is the core of the deadlock fix: we write directly here
363+ // instead of in a separate task.
364+ let mut transport = match this. transport . try_lock ( ) {
365+ Ok ( guard) => guard,
366+ Err ( _) => {
367+ // If we can't get the transport lock, it's likely held by the read task.
368+ // We must return Pending and let the runtime poll us again later.
369+ cx. waker ( ) . wake_by_ref ( ) ;
370+ return Poll :: Pending ;
371+ }
372+ } ;
373+
374+ let mut write_buf = match this. write_buf . try_lock ( ) {
375+ Ok ( guard) => guard,
376+ Err ( _) => {
377+ // This should be rare, but if the KCP output callback is somehow
378+ // running concurrently, we might not get the lock.
379+ cx. waker ( ) . wake_by_ref ( ) ;
380+ return Poll :: Pending ;
381+ }
382+ } ;
383+
384+ while let Some ( packet) = write_buf. pop_front ( ) {
385+ if let Poll :: Ready ( Err ( e) ) = Pin :: new ( & mut * transport) . poll_write ( cx, & packet) {
386+ // If write fails (e.g., WouldBlock), re-queue the packet and return
387+ write_buf. push_front ( packet) ;
388+ return Poll :: Ready ( Err ( e) ) ;
389+ }
390+ }
391+
396392 Poll :: Ready ( Ok ( result) )
397393 }
398394
399- fn poll_flush ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
395+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
400396 let this = self . get_mut ( ) ;
401397
402- // Try to get lock on KCP
403- let mut kcp = match this. kcp . try_lock ( ) {
398+ let mut transport = match this. transport . try_lock ( ) {
404399 Ok ( guard) => guard,
405400 Err ( _) => {
406- return Poll :: Ready ( Err ( io:: Error :: new (
407- io:: ErrorKind :: WouldBlock ,
408- "KCP lock busy" ,
409- ) ) ) ;
401+ cx. waker ( ) . wake_by_ref ( ) ;
402+ return Poll :: Pending ;
410403 }
411404 } ;
412405
413- // Flush KCP buffers
414- match kcp. flush ( ) {
415- Ok ( _) => Poll :: Ready ( Ok ( ( ) ) ) ,
416- Err ( e) => Poll :: Ready ( Err ( io:: Error :: new (
417- io:: ErrorKind :: Other ,
418- format ! ( "KCP flush error: {:?}" , e) ,
419- ) ) ) ,
420- }
406+ Pin :: new ( & mut * transport) . poll_flush ( cx)
421407 }
422408
423- fn poll_shutdown ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
424- // KCP doesn't have explicit shutdown
425- Poll :: Ready ( Ok ( ( ) ) )
409+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
410+ let this = self . get_mut ( ) ;
411+
412+ // Try to flush any remaining data before shutting down
413+ if let Poll :: Pending = Pin :: new ( & mut * this) . poll_flush ( cx) {
414+ return Poll :: Pending ;
415+ }
416+
417+ let mut transport = match this. transport . try_lock ( ) {
418+ Ok ( guard) => guard,
419+ Err ( _) => {
420+ cx. waker ( ) . wake_by_ref ( ) ;
421+ return Poll :: Pending ;
422+ }
423+ } ;
424+
425+ Pin :: new ( & mut * transport) . poll_shutdown ( cx)
426426 }
427427}
428428
0 commit comments