Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tokio-util/src/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ macro_rules! cfg_rt {
}
}

macro_rules! cfg_not_rt {
($($item:item)*) => {
$(
#[cfg(not(feature = "rt"))]
$item
)*
}
}

macro_rules! cfg_time {
($($item:item)*) => {
$(
Expand Down
1 change: 1 addition & 0 deletions tokio-util/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod copy_to_bytes;
mod inspect;
mod read_buf;
mod reader_stream;
pub mod simplex;
mod sink_writer;
mod stream_reader;

Expand Down
322 changes: 322 additions & 0 deletions tokio-util/src/io/simplex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
//! Unidirectional byte-oriented channel.

use crate::util::poll_proceed_and_make_progress;

use bytes::Buf;
use bytes::BytesMut;
use futures_core::ready;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::io::IoSlice;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

type IoResult<T> = Result<T, IoError>;

const CLOSED_ERROR_MSG: &str = "simplex has been closed";

#[derive(Debug)]
struct Inner {
/// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached
backpressure_boundary: usize,

/// either [`Sender`] or [`Receiver`] is closed
is_closed: bool,

/// Waker used to wake the [`Receiver`]
receiver_waker: Option<Waker>,

/// Waker used to wake the [`Sender`]
sender_waker: Option<Waker>,

/// Buffer used to read and write data
buf: BytesMut,
}

impl Inner {
fn with_capacity(capacity: usize) -> Self {
Self {
backpressure_boundary: capacity,
is_closed: false,
receiver_waker: None,
sender_waker: None,
buf: BytesMut::with_capacity(capacity),
}
}

fn register_receiver_waker(&mut self, waker: &Waker) {
match self.receiver_waker.as_mut() {
Some(old) if old.will_wake(waker) => {}
Some(old) => old.clone_from(waker),
None => self.receiver_waker = Some(waker.clone()),
Comment on lines +52 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually avoid dropping wakers while holding a lock. Perhaps we should return the old waker to the caller so they can drop it after releasing the lock?

}
}

fn register_sender_waker(&mut self, waker: &Waker) {
match self.sender_waker.as_mut() {
Some(old) if old.will_wake(waker) => {}
Some(old) => old.clone_from(waker),
None => self.sender_waker = Some(waker.clone()),
}
}

fn take_receiver_waker(&mut self) -> Option<Waker> {
self.receiver_waker.take()
}

fn take_sender_waker(&mut self) -> Option<Waker> {
self.sender_waker.take()
}

fn is_closed(&self) -> bool {
self.is_closed
}

fn close_receiver(&mut self) -> Option<Waker> {
self.is_closed = true;
self.take_sender_waker()
}

fn close_sender(&mut self) -> Option<Waker> {
self.is_closed = true;
self.take_receiver_waker()
}
}

/// Receiver of the simplex channel.
///
/// You can still read the remaining data from the buffer
/// even if the write half has been dropped.
/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
#[derive(Debug)]
pub struct Receiver {
inner: Arc<Mutex<Inner>>,
}

impl Drop for Receiver {
/// This also wakes up the [`Sender`].
fn drop(&mut self) {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_receiver()
};

if let Some(waker) = maybe_waker {
waker.wake();
}
}
}

impl AsyncRead for Receiver {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
let mut inner = self.inner.lock().unwrap();

let to_read = buf.remaining().min(inner.buf.remaining());
if to_read == 0 {
if inner.is_closed() || buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}

inner.register_receiver_waker(cx.waker());
let maybe_waker = inner.take_sender_waker();
drop(inner); // unlock before waking up
if let Some(waker) = maybe_waker {
waker.wake();
}
return Poll::Pending;
}

ready!(poll_proceed_and_make_progress(cx));

buf.put_slice(&inner.buf[..to_read]);
inner.buf.advance(to_read);
let waker = inner.take_sender_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(Ok(()))
}
}

/// Sender of the simplex channel.
///
/// ## Shutdown
///
/// See [`Sender::poll_shutdown`].
#[derive(Debug)]
pub struct Sender {
inner: Arc<Mutex<Inner>>,
}

impl Drop for Sender {
/// This also wakes up the [`Receiver`].
fn drop(&mut self) {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_sender()
};

if let Some(waker) = maybe_waker {
waker.wake();
}
}
}

impl AsyncWrite for Sender {
/// # Errors
///
/// This method will return [`IoErrorKind::BrokenPipe`]
/// if the channel has been closed.
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
let mut inner = self.inner.lock().unwrap();

if inner.is_closed() {
return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
}

let free = inner
.backpressure_boundary
.checked_sub(inner.buf.len())
.expect("backpressure boundary overflow");
let to_write = buf.len().min(free);
if to_write == 0 {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}

inner.register_sender_waker(cx.waker());
let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
return Poll::Pending;
}

// this is to avoid starving other tasks
ready!(poll_proceed_and_make_progress(cx));

inner.buf.extend_from_slice(&buf[..to_write]);
let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(Ok(to_write))
}

/// # Errors
///
/// This method will return [`IoErrorKind::BrokenPipe`]
/// if the channel has been closed.
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let inner = self.inner.lock().unwrap();
if inner.is_closed() {
Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
} else {
Poll::Ready(Ok(()))
}
}

/// After returns [`Poll::Ready`], all the following call to
/// [`Sender::poll_write`] and [`Sender::poll_flush`]
/// will return error.
///
/// The [`Receiver`] can still be used to read remaining data
/// until all bytes have been consumed.
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let maybe_waker = {
let mut inner = self.inner.lock().unwrap();
inner.close_sender()
};

if let Some(waker) = maybe_waker {
waker.wake();
}

Poll::Ready(Ok(()))
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, IoError>> {
let mut inner = self.inner.lock().unwrap();
if inner.is_closed() {
return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
}

let free = inner
.backpressure_boundary
.checked_sub(inner.buf.len())
.expect("backpressure boundary overflow");
if free == 0 {
inner.register_sender_waker(cx.waker());
let maybe_waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = maybe_waker {
waker.wake();
}
return Poll::Pending;
}

ready!(poll_proceed_and_make_progress(cx));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally you should poll_proceed() at the very beginning of the function, and then call make_progress() only if you made progress (i.e. did not return pending). Doing both at the same time is almost always wrong.


let mut rem = free;
for buf in bufs {
if rem == 0 {
break;
}

let to_write = buf.len().min(rem);
if to_write == 0 {
assert_ne!(rem, 0);
assert_eq!(buf.len(), 0);
continue;
}

inner.buf.extend_from_slice(&buf[..to_write]);
rem -= to_write;
}

let waker = inner.take_receiver_waker();
drop(inner); // unlock before waking up
if let Some(waker) = waker {
waker.wake();
}

Poll::Ready(Ok(free - rem))
}
}

/// Create a simplex channel.
///
/// The `capacity` parameter specifies the maximum number of bytes that can be
/// stored in the channel without making the [`Sender::poll_write`]
/// return [`Poll::Pending`].
///
/// # Panics
///
/// This function will panic if `capacity` is zero.
pub fn new(capacity: usize) -> (Sender, Receiver) {
assert_ne!(capacity, 0, "capacity must be greater than zero");

let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
let tx = Sender {
inner: Arc::clone(&inner),
};
let rx = Receiver { inner };
(tx, rx)
}
21 changes: 21 additions & 0 deletions tokio-util/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,24 @@ pub(crate) use maybe_dangling::MaybeDangling;
#[cfg(any(feature = "io", feature = "codec"))]
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub use poll_buf::{poll_read_buf, poll_write_buf};

cfg_rt! {
use std::task::{Context, Poll};
use tokio::task::coop::poll_proceed;
use futures_core::ready;

#[cfg_attr(not(feature = "io"), allow(unused))]
pub(crate) fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> {
ready!(poll_proceed(cx)).made_progress();
Poll::Ready(())
}
}

cfg_not_rt! {
use std::task::{Context, Poll};

#[cfg_attr(not(feature = "io"), allow(unused))]
pub(crate) fn poll_proceed_and_make_progress(_cx: &mut Context<'_>) -> Poll<()> {
Poll::Ready(())
}
}
Loading
Loading