Skip to content

Commit 70b2b63

Browse files
committed
Fix: KCP deadlock issue in reliable_transport.rs
1 parent be09410 commit 70b2b63

File tree

1 file changed

+67
-67
lines changed

1 file changed

+67
-67
lines changed

src/reliable_transport.rs

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,23 @@ use tokio::time::{Duration, Instant};
3636
///
3737
/// Provides AsyncRead/AsyncWrite interface over unreliable transport
3838
pub struct ReliableTransport<T> {
39+
/// Underlying unreliable transport
40+
transport: Arc<Mutex<T>>,
3941
/// KCP protocol state
40-
kcp: Arc<Mutex<Kcp<KcpOutput<T>>>>,
42+
kcp: Arc<Mutex<Kcp<KcpOutput>>>,
4143
/// Channel receiver for decoded data from KCP
4244
read_rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
43-
/// Channel sender for decoded data (used by background task)
44-
_read_tx: Arc<tokio::sync::mpsc::UnboundedSender<Vec<u8>>>,
45+
/// Buffer for KCP output packets
46+
write_buf: Arc<Mutex<VecDeque<Vec<u8>>>>,
4547
/// Last update time
4648
last_update: Arc<Mutex<Instant>>,
4749
/// Buffer for partially read data
4850
partial_buf: Vec<u8>,
4951
partial_pos: usize,
5052
}
5153

52-
/// Output callback for KCP - sends packets to underlying transport
53-
struct KcpOutput<T> {
54-
transport: Arc<Mutex<T>>,
54+
/// Output callback for KCP - sends packets to a buffer
55+
struct KcpOutput {
5556
write_buf: Arc<Mutex<VecDeque<Vec<u8>>>>,
5657
}
5758

@@ -71,7 +72,6 @@ where
7172
let write_buf = Arc::new(Mutex::new(VecDeque::new()));
7273

7374
let output = KcpOutput {
74-
transport: Arc::clone(&transport),
7575
write_buf: Arc::clone(&write_buf),
7676
};
7777

@@ -116,37 +116,6 @@ where
116116
}
117117
});
118118

119-
// Start write task - send buffered KCP output to underlying transport
120-
let transport_write = Arc::clone(&transport);
121-
let write_buf_task = Arc::clone(&write_buf);
122-
tokio::spawn(async move {
123-
use tokio::io::AsyncWriteExt;
124-
125-
loop {
126-
tokio::time::sleep(Duration::from_millis(5)).await;
127-
128-
// Check if there's data to write
129-
let packet = {
130-
let mut buf = write_buf_task.lock().await;
131-
buf.pop_front()
132-
};
133-
134-
if let Some(data) = packet {
135-
log::debug!("Write task: dequeued {} bytes, writing to transport", data.len());
136-
// Write to underlying transport
137-
let mut transport = transport_write.lock().await;
138-
if let Err(e) = transport.write_all(&data).await {
139-
log::error!("Failed to write KCP output to transport: {}", e);
140-
} else {
141-
log::debug!("Write task: successfully wrote to transport");
142-
}
143-
if let Err(e) = transport.flush().await {
144-
log::error!("Failed to flush transport: {}", e);
145-
}
146-
}
147-
}
148-
});
149-
150119
// Start read task - feed transport data into KCP
151120
let transport_read = Arc::clone(&transport);
152121
let kcp_input = Arc::clone(&kcp);
@@ -225,9 +194,10 @@ where
225194
});
226195

227196
Ok(Self {
197+
transport,
228198
kcp,
229199
read_rx,
230-
_read_tx: read_tx,
200+
write_buf,
231201
last_update,
232202
partial_buf: Vec::new(),
233203
partial_pos: 0,
@@ -270,10 +240,7 @@ where
270240
}
271241
}
272242

273-
impl<T> std::io::Write for KcpOutput<T>
274-
where
275-
T: AsyncWrite + Unpin + Send,
276-
{
243+
impl std::io::Write for KcpOutput {
277244
fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
278245
// Queue packet to write buffer - will be written by async task
279246
log::debug!("KCP output callback: writing {} bytes to write_buf", data.len());
@@ -358,7 +325,7 @@ where
358325
{
359326
fn poll_write(
360327
self: Pin<&mut Self>,
361-
_cx: &mut Context<'_>,
328+
cx: &mut Context<'_>,
362329
buf: &[u8],
363330
) -> Poll<io::Result<usize>> {
364331
let this = self.get_mut();
@@ -367,10 +334,8 @@ where
367334
let mut kcp = match this.kcp.try_lock() {
368335
Ok(guard) => guard,
369336
Err(_) => {
370-
return Poll::Ready(Err(io::Error::new(
371-
io::ErrorKind::WouldBlock,
372-
"KCP lock busy",
373-
)));
337+
cx.waker().wake_by_ref();
338+
return Poll::Pending;
374339
}
375340
};
376341

@@ -385,44 +350,79 @@ where
385350
}
386351
};
387352

388-
// Flush immediately to trigger output callback
353+
// Flush immediately to trigger output callback, which populates write_buf
389354
if let Err(e) = kcp.flush() {
390355
return Poll::Ready(Err(io::Error::new(
391356
io::ErrorKind::Other,
392357
format!("KCP flush after send error: {:?}", e),
393358
)));
394359
}
395360

361+
// Try to drain the write buffer and write to the transport
362+
// This is the core of the deadlock fix: we write directly here
363+
// instead of in a separate task.
364+
let mut transport = match this.transport.try_lock() {
365+
Ok(guard) => guard,
366+
Err(_) => {
367+
// If we can't get the transport lock, it's likely held by the read task.
368+
// We must return Pending and let the runtime poll us again later.
369+
cx.waker().wake_by_ref();
370+
return Poll::Pending;
371+
}
372+
};
373+
374+
let mut write_buf = match this.write_buf.try_lock() {
375+
Ok(guard) => guard,
376+
Err(_) => {
377+
// This should be rare, but if the KCP output callback is somehow
378+
// running concurrently, we might not get the lock.
379+
cx.waker().wake_by_ref();
380+
return Poll::Pending;
381+
}
382+
};
383+
384+
while let Some(packet) = write_buf.pop_front() {
385+
if let Poll::Ready(Err(e)) = Pin::new(&mut *transport).poll_write(cx, &packet) {
386+
// If write fails (e.g., WouldBlock), re-queue the packet and return
387+
write_buf.push_front(packet);
388+
return Poll::Ready(Err(e));
389+
}
390+
}
391+
396392
Poll::Ready(Ok(result))
397393
}
398394

399-
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
395+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
400396
let this = self.get_mut();
401397

402-
// Try to get lock on KCP
403-
let mut kcp = match this.kcp.try_lock() {
398+
let mut transport = match this.transport.try_lock() {
404399
Ok(guard) => guard,
405400
Err(_) => {
406-
return Poll::Ready(Err(io::Error::new(
407-
io::ErrorKind::WouldBlock,
408-
"KCP lock busy",
409-
)));
401+
cx.waker().wake_by_ref();
402+
return Poll::Pending;
410403
}
411404
};
412405

413-
// Flush KCP buffers
414-
match kcp.flush() {
415-
Ok(_) => Poll::Ready(Ok(())),
416-
Err(e) => Poll::Ready(Err(io::Error::new(
417-
io::ErrorKind::Other,
418-
format!("KCP flush error: {:?}", e),
419-
))),
420-
}
406+
Pin::new(&mut *transport).poll_flush(cx)
421407
}
422408

423-
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
424-
// KCP doesn't have explicit shutdown
425-
Poll::Ready(Ok(()))
409+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
410+
let this = self.get_mut();
411+
412+
// Try to flush any remaining data before shutting down
413+
if let Poll::Pending = Pin::new(&mut *this).poll_flush(cx) {
414+
return Poll::Pending;
415+
}
416+
417+
let mut transport = match this.transport.try_lock() {
418+
Ok(guard) => guard,
419+
Err(_) => {
420+
cx.waker().wake_by_ref();
421+
return Poll::Pending;
422+
}
423+
};
424+
425+
Pin::new(&mut *transport).poll_shutdown(cx)
426426
}
427427
}
428428

0 commit comments

Comments
 (0)