Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion volo-grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use motore::{
pub use service::ServiceBuilder;
use tower::util::BoxCloneService;
#[cfg(feature = "__tls")]
use volo::net::tls::{Acceptor, ServerTlsConfig};
use volo::net::tls::ServerTlsConfig;
use volo::{
net::{conn::Conn, incoming::Incoming},
spawn,
Expand Down
11 changes: 8 additions & 3 deletions volo-http/src/client/transport/tls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use http::uri::Scheme;
use motore::service::UnaryService;
use volo::net::{
conn::{Conn, ConnStream},
tls::{Connector, TlsConnector},
conn::{Conn, ConnInfo, ConnStream},
tls::TlsConnector,
};

use super::{connector::PeerInfo, plain::PlainMakeConnection};
Expand Down Expand Up @@ -46,7 +46,12 @@ where
_ => unreachable!(),
};
match self.tls_connector.connect(&target_name, tcp_stream).await {
Ok(conn) => Ok(conn),
Ok(stream) => Ok(Conn {
stream,
info: ConnInfo {
peer_addr: Some(req.address),
},
}),
Err(err) => {
tracing::warn!("[Volo-HTTP] failed to make tls connection, error: {err}");
Err(request_error(err))
Expand Down
2 changes: 1 addition & 1 deletion volo-http/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use scopeguard::defer;
use tokio::sync::Notify;
use tracing::Instrument;
#[cfg(feature = "__tls")]
use volo::net::{conn::ConnStream, tls::Acceptor, tls::ServerTlsConfig};
use volo::net::{conn::ConnStream, tls::ServerTlsConfig};
use volo::{
context::Context,
net::{Address, MakeIncoming, conn::Conn, incoming::Incoming},
Expand Down
156 changes: 46 additions & 110 deletions volo/src/net/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,22 @@ pub trait DynStream: AsyncRead + AsyncWrite + Send + 'static {}

impl<T> DynStream for T where T: AsyncRead + AsyncWrite + Send + 'static {}

#[allow(clippy::large_enum_variant)]
#[pin_project(project = IoStreamProj)]
pub enum ConnStream {
Tcp(#[pin] TcpStream),
#[cfg(target_family = "unix")]
Unix(#[pin] UnixStream),
#[cfg(feature = "rustls")]
Rustls(#[pin] tokio_rustls::TlsStream<TcpStream>),
#[cfg(feature = "native-tls")]
NativeTls(#[pin] tokio_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "__tls")]
Tls(#[pin] super::tls::TlsStream),
}

#[cfg(feature = "rustls")]
type RustlsWriteHalf = tokio::io::WriteHalf<tokio_rustls::TlsStream<TcpStream>>;

#[cfg(feature = "native-tls")]
type NativeTlsWriteHalf = tokio::io::WriteHalf<tokio_native_tls::TlsStream<TcpStream>>;

#[pin_project(project = OwnedWriteHalfProj)]
pub enum OwnedWriteHalf {
Tcp(#[pin] tcp::OwnedWriteHalf),
#[cfg(target_family = "unix")]
Unix(#[pin] unix::OwnedWriteHalf),
#[cfg(feature = "rustls")]
Rustls(#[pin] RustlsWriteHalf),
#[cfg(feature = "native-tls")]
NativeTls(#[pin] NativeTlsWriteHalf),
#[cfg(feature = "__tls")]
Tls(#[pin] super::tls::OwnedWriteHalf),
}

impl AsyncWrite for OwnedWriteHalf {
Expand All @@ -63,10 +52,8 @@ impl AsyncWrite for OwnedWriteHalf {
OwnedWriteHalfProj::Tcp(half) => half.poll_write(cx, buf),
#[cfg(target_family = "unix")]
OwnedWriteHalfProj::Unix(half) => half.poll_write(cx, buf),
#[cfg(feature = "rustls")]
OwnedWriteHalfProj::Rustls(half) => half.poll_write(cx, buf),
#[cfg(feature = "native-tls")]
OwnedWriteHalfProj::NativeTls(half) => half.poll_write(cx, buf),
#[cfg(feature = "__tls")]
OwnedWriteHalfProj::Tls(half) => half.poll_write(cx, buf),
}
}

Expand All @@ -76,10 +63,8 @@ impl AsyncWrite for OwnedWriteHalf {
OwnedWriteHalfProj::Tcp(half) => half.poll_flush(cx),
#[cfg(target_family = "unix")]
OwnedWriteHalfProj::Unix(half) => half.poll_flush(cx),
#[cfg(feature = "rustls")]
OwnedWriteHalfProj::Rustls(half) => half.poll_flush(cx),
#[cfg(feature = "native-tls")]
OwnedWriteHalfProj::NativeTls(half) => half.poll_flush(cx),
#[cfg(feature = "__tls")]
OwnedWriteHalfProj::Tls(half) => half.poll_flush(cx),
}
}

Expand All @@ -89,10 +74,8 @@ impl AsyncWrite for OwnedWriteHalf {
OwnedWriteHalfProj::Tcp(half) => half.poll_shutdown(cx),
#[cfg(target_family = "unix")]
OwnedWriteHalfProj::Unix(half) => half.poll_shutdown(cx),
#[cfg(feature = "rustls")]
OwnedWriteHalfProj::Rustls(half) => half.poll_shutdown(cx),
#[cfg(feature = "native-tls")]
OwnedWriteHalfProj::NativeTls(half) => half.poll_shutdown(cx),
#[cfg(feature = "__tls")]
OwnedWriteHalfProj::Tls(half) => half.poll_shutdown(cx),
}
}

Expand All @@ -106,10 +89,8 @@ impl AsyncWrite for OwnedWriteHalf {
OwnedWriteHalfProj::Tcp(half) => half.poll_write_vectored(cx, bufs),
#[cfg(target_family = "unix")]
OwnedWriteHalfProj::Unix(half) => half.poll_write_vectored(cx, bufs),
#[cfg(feature = "rustls")]
OwnedWriteHalfProj::Rustls(half) => half.poll_write_vectored(cx, bufs),
#[cfg(feature = "native-tls")]
OwnedWriteHalfProj::NativeTls(half) => half.poll_write_vectored(cx, bufs),
#[cfg(feature = "__tls")]
OwnedWriteHalfProj::Tls(half) => half.poll_write_vectored(cx, bufs),
}
}

Expand All @@ -119,29 +100,19 @@ impl AsyncWrite for OwnedWriteHalf {
Self::Tcp(half) => half.is_write_vectored(),
#[cfg(target_family = "unix")]
Self::Unix(half) => half.is_write_vectored(),
#[cfg(feature = "rustls")]
Self::Rustls(half) => half.is_write_vectored(),
#[cfg(feature = "native-tls")]
Self::NativeTls(half) => half.is_write_vectored(),
#[cfg(feature = "__tls")]
Self::Tls(half) => half.is_write_vectored(),
}
}
}

#[cfg(feature = "rustls")]
type RustlsReadHalf = tokio::io::ReadHalf<tokio_rustls::TlsStream<TcpStream>>;

#[cfg(feature = "native-tls")]
type NativeTlsReadHalf = tokio::io::ReadHalf<tokio_native_tls::TlsStream<TcpStream>>;

#[pin_project(project = OwnedReadHalfProj)]
pub enum OwnedReadHalf {
Tcp(#[pin] tcp::OwnedReadHalf),
#[cfg(target_family = "unix")]
Unix(#[pin] unix::OwnedReadHalf),
#[cfg(feature = "rustls")]
Rustls(#[pin] RustlsReadHalf),
#[cfg(feature = "native-tls")]
NativeTls(#[pin] NativeTlsReadHalf),
#[cfg(feature = "__tls")]
Tls(#[pin] super::tls::OwnedReadHalf),
}

impl AsyncRead for OwnedReadHalf {
Expand All @@ -155,10 +126,8 @@ impl AsyncRead for OwnedReadHalf {
OwnedReadHalfProj::Tcp(half) => half.poll_read(cx, buf),
#[cfg(target_family = "unix")]
OwnedReadHalfProj::Unix(half) => half.poll_read(cx, buf),
#[cfg(feature = "rustls")]
OwnedReadHalfProj::Rustls(half) => half.poll_read(cx, buf),
#[cfg(feature = "native-tls")]
OwnedReadHalfProj::NativeTls(half) => half.poll_read(cx, buf),
#[cfg(feature = "__tls")]
OwnedReadHalfProj::Tls(half) => half.poll_read(cx, buf),
}
}
}
Expand All @@ -175,27 +144,18 @@ impl ConnStream {
let (rh, wh) = stream.into_split();
(OwnedReadHalf::Unix(rh), OwnedWriteHalf::Unix(wh))
}
#[cfg(feature = "rustls")]
Self::Rustls(stream) => {
let (rh, wh) = tokio::io::split(stream);
(OwnedReadHalf::Rustls(rh), OwnedWriteHalf::Rustls(wh))
}
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) => {
let (rh, wh) = tokio::io::split(stream);
(OwnedReadHalf::NativeTls(rh), OwnedWriteHalf::NativeTls(wh))
#[cfg(feature = "__tls")]
Self::Tls(stream) => {
let (rh, wh) = stream.into_split();
(OwnedReadHalf::Tls(rh), OwnedWriteHalf::Tls(wh))
}
}
}

pub fn negotiated_alpn(&self) -> Option<Vec<u8>> {
match self {
#[cfg(feature = "rustls")]
Self::Rustls(tokio_rustls::TlsStream::Client(stream)) => {
stream.get_ref().1.alpn_protocol().map(ToOwned::to_owned)
}
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) => stream.get_ref().negotiated_alpn().unwrap_or_default(),
#[cfg(feature = "__tls")]
Self::Tls(stream) => stream.negotiated_alpn(),
_ => None,
}
}
Expand All @@ -217,19 +177,14 @@ impl From<UnixStream> for ConnStream {
}
}

#[cfg(feature = "rustls")]
impl From<tokio_rustls::TlsStream<TcpStream>> for ConnStream {
#[inline]
fn from(s: tokio_rustls::TlsStream<TcpStream>) -> Self {
Self::Rustls(s)
}
}

#[cfg(feature = "native-tls")]
impl From<tokio_native_tls::TlsStream<TcpStream>> for ConnStream {
#[cfg(feature = "__tls")]
impl<T> From<T> for ConnStream
where
T: Into<super::tls::TlsStream>,
{
#[inline]
fn from(s: tokio_native_tls::TlsStream<TcpStream>) -> Self {
Self::NativeTls(s)
fn from(s: T) -> Self {
Self::Tls(s.into())
}
}

Expand All @@ -244,10 +199,8 @@ impl AsyncRead for ConnStream {
IoStreamProj::Tcp(s) => s.poll_read(cx, buf),
#[cfg(target_family = "unix")]
IoStreamProj::Unix(s) => s.poll_read(cx, buf),
#[cfg(feature = "rustls")]
IoStreamProj::Rustls(s) => s.poll_read(cx, buf),
#[cfg(feature = "native-tls")]
IoStreamProj::NativeTls(s) => s.poll_read(cx, buf),
#[cfg(feature = "__tls")]
IoStreamProj::Tls(s) => s.poll_read(cx, buf),
}
}
}
Expand All @@ -263,10 +216,8 @@ impl AsyncWrite for ConnStream {
IoStreamProj::Tcp(s) => s.poll_write(cx, buf),
#[cfg(target_family = "unix")]
IoStreamProj::Unix(s) => s.poll_write(cx, buf),
#[cfg(feature = "rustls")]
IoStreamProj::Rustls(s) => s.poll_write(cx, buf),
#[cfg(feature = "native-tls")]
IoStreamProj::NativeTls(s) => s.poll_write(cx, buf),
#[cfg(feature = "__tls")]
IoStreamProj::Tls(s) => s.poll_write(cx, buf),
}
}

Expand All @@ -276,10 +227,8 @@ impl AsyncWrite for ConnStream {
IoStreamProj::Tcp(s) => s.poll_flush(cx),
#[cfg(target_family = "unix")]
IoStreamProj::Unix(s) => s.poll_flush(cx),
#[cfg(feature = "rustls")]
IoStreamProj::Rustls(s) => s.poll_flush(cx),
#[cfg(feature = "native-tls")]
IoStreamProj::NativeTls(s) => s.poll_flush(cx),
#[cfg(feature = "__tls")]
IoStreamProj::Tls(s) => s.poll_flush(cx),
}
}

Expand All @@ -289,10 +238,8 @@ impl AsyncWrite for ConnStream {
IoStreamProj::Tcp(s) => s.poll_shutdown(cx),
#[cfg(target_family = "unix")]
IoStreamProj::Unix(s) => s.poll_shutdown(cx),
#[cfg(feature = "rustls")]
IoStreamProj::Rustls(s) => s.poll_shutdown(cx),
#[cfg(feature = "native-tls")]
IoStreamProj::NativeTls(s) => s.poll_shutdown(cx),
#[cfg(feature = "__tls")]
IoStreamProj::Tls(s) => s.poll_shutdown(cx),
}
}

Expand All @@ -306,10 +253,8 @@ impl AsyncWrite for ConnStream {
IoStreamProj::Tcp(s) => s.poll_write_vectored(cx, bufs),
#[cfg(target_family = "unix")]
IoStreamProj::Unix(s) => s.poll_write_vectored(cx, bufs),
#[cfg(feature = "rustls")]
IoStreamProj::Rustls(s) => s.poll_write_vectored(cx, bufs),
#[cfg(feature = "native-tls")]
IoStreamProj::NativeTls(s) => s.poll_write_vectored(cx, bufs),
#[cfg(feature = "__tls")]
IoStreamProj::Tls(s) => s.poll_write_vectored(cx, bufs),
}
}

Expand All @@ -319,10 +264,8 @@ impl AsyncWrite for ConnStream {
Self::Tcp(s) => s.is_write_vectored(),
#[cfg(target_family = "unix")]
Self::Unix(s) => s.is_write_vectored(),
#[cfg(feature = "rustls")]
Self::Rustls(s) => s.is_write_vectored(),
#[cfg(feature = "native-tls")]
Self::NativeTls(s) => s.is_write_vectored(),
#[cfg(feature = "__tls")]
Self::Tls(s) => s.is_write_vectored(),
}
}
}
Expand All @@ -334,19 +277,12 @@ impl ConnStream {
Self::Tcp(s) => s.peer_addr().map(Address::from).ok(),
#[cfg(target_family = "unix")]
Self::Unix(s) => s.peer_addr().map(Address::from).ok(),
#[cfg(feature = "rustls")]
Self::Rustls(s) => s.get_ref().0.peer_addr().map(Address::from).ok(),
#[cfg(feature = "native-tls")]
Self::NativeTls(s) => s
.get_ref()
.get_ref()
.get_ref()
.peer_addr()
.map(Address::from)
.ok(),
#[cfg(feature = "__tls")]
Self::Tls(s) => s.peer_addr().map(Address::from).ok(),
}
}
}

pub struct Conn {
pub stream: ConnStream,
pub info: ConnInfo,
Expand Down
1 change: 0 additions & 1 deletion volo/src/net/dial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ impl DefaultMakeTransport {

impl MakeTransport for DefaultMakeTransport {
type ReadHalf = OwnedReadHalf;

type WriteHalf = OwnedWriteHalf;

async fn make_transport(&self, addr: Address) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {
Expand Down
16 changes: 8 additions & 8 deletions volo/src/net/ready.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ impl AsyncReady for OwnedReadHalf {
OwnedReadHalf::Tcp(half) => half.ready(interest).await,
#[cfg(target_family = "unix")]
OwnedReadHalf::Unix(half) => half.ready(interest).await,
#[cfg(feature = "rustls")]
OwnedReadHalf::Rustls(_) => todo!(),
#[cfg(feature = "native-tls")]
OwnedReadHalf::NativeTls(_) => todo!(),
#[cfg(feature = "__tls")]
OwnedReadHalf::Tls(_) => {
unimplemented!("AsyncReady is not supported for TLS connection")
}
}
}
}
Expand All @@ -34,10 +34,10 @@ impl AsyncReady for OwnedWriteHalf {
OwnedWriteHalf::Tcp(half) => half.ready(interest).await,
#[cfg(target_family = "unix")]
OwnedWriteHalf::Unix(half) => half.ready(interest).await,
#[cfg(feature = "rustls")]
OwnedWriteHalf::Rustls(_) => todo!(),
#[cfg(feature = "native-tls")]
OwnedWriteHalf::NativeTls(_) => todo!(),
#[cfg(feature = "__tls")]
OwnedWriteHalf::Tls(_) => {
unimplemented!("AsyncReady is not supported for TLS connection")
}
}
}
}
Loading
Loading