Skip to content

Commit 663fd5c

Browse files
authored
chore(volo): combine rustls and native-tls (#627)
In previous implementation, `volo::net::Conn` has 4 variants: `Tcp`, `Unix`, `Rustls` and `NativeTls`, but we think it's not elegant to put `Rustls` and `NativeTls` here. And what's more, in most cases, we only need to be aware of TLS and do not need to be aware of which library is used behind it. This commit wraps `Rustls` and `NativeTls` into `Tls` prefixed types and makes `Conn` cleaner. Signed-off-by: Yu Li <[email protected]>
1 parent 10e73ea commit 663fd5c

File tree

9 files changed

+390
-222
lines changed

9 files changed

+390
-222
lines changed

volo-grpc/src/server/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use motore::{
2020
pub use service::ServiceBuilder;
2121
use tower::util::BoxCloneService;
2222
#[cfg(feature = "__tls")]
23-
use volo::net::tls::{Acceptor, ServerTlsConfig};
23+
use volo::net::tls::ServerTlsConfig;
2424
use volo::{
2525
net::{conn::Conn, incoming::Incoming},
2626
spawn,

volo-http/src/client/transport/tls.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use http::uri::Scheme;
22
use motore::service::UnaryService;
33
use volo::net::{
4-
conn::{Conn, ConnStream},
5-
tls::{Connector, TlsConnector},
4+
conn::{Conn, ConnInfo, ConnStream},
5+
tls::TlsConnector,
66
};
77

88
use super::{connector::PeerInfo, plain::PlainMakeConnection};
@@ -46,7 +46,12 @@ where
4646
_ => unreachable!(),
4747
};
4848
match self.tls_connector.connect(&target_name, tcp_stream).await {
49-
Ok(conn) => Ok(conn),
49+
Ok(stream) => Ok(Conn {
50+
stream,
51+
info: ConnInfo {
52+
peer_addr: Some(req.address),
53+
},
54+
}),
5055
Err(err) => {
5156
tracing::warn!("[Volo-HTTP] failed to make tls connection, error: {err}");
5257
Err(request_error(err))

volo-http/src/server/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use scopeguard::defer;
2929
use tokio::sync::Notify;
3030
use tracing::Instrument;
3131
#[cfg(feature = "__tls")]
32-
use volo::net::{conn::ConnStream, tls::Acceptor, tls::ServerTlsConfig};
32+
use volo::net::{conn::ConnStream, tls::ServerTlsConfig};
3333
use volo::{
3434
context::Context,
3535
net::{Address, MakeIncoming, conn::Conn, incoming::Incoming},

volo/src/net/conn.rs

Lines changed: 46 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,22 @@ pub trait DynStream: AsyncRead + AsyncWrite + Send + 'static {}
2323

2424
impl<T> DynStream for T where T: AsyncRead + AsyncWrite + Send + 'static {}
2525

26-
#[allow(clippy::large_enum_variant)]
2726
#[pin_project(project = IoStreamProj)]
2827
pub enum ConnStream {
2928
Tcp(#[pin] TcpStream),
3029
#[cfg(target_family = "unix")]
3130
Unix(#[pin] UnixStream),
32-
#[cfg(feature = "rustls")]
33-
Rustls(#[pin] tokio_rustls::TlsStream<TcpStream>),
34-
#[cfg(feature = "native-tls")]
35-
NativeTls(#[pin] tokio_native_tls::TlsStream<TcpStream>),
31+
#[cfg(feature = "__tls")]
32+
Tls(#[pin] super::tls::TlsStream),
3633
}
3734

38-
#[cfg(feature = "rustls")]
39-
type RustlsWriteHalf = tokio::io::WriteHalf<tokio_rustls::TlsStream<TcpStream>>;
40-
41-
#[cfg(feature = "native-tls")]
42-
type NativeTlsWriteHalf = tokio::io::WriteHalf<tokio_native_tls::TlsStream<TcpStream>>;
43-
4435
#[pin_project(project = OwnedWriteHalfProj)]
4536
pub enum OwnedWriteHalf {
4637
Tcp(#[pin] tcp::OwnedWriteHalf),
4738
#[cfg(target_family = "unix")]
4839
Unix(#[pin] unix::OwnedWriteHalf),
49-
#[cfg(feature = "rustls")]
50-
Rustls(#[pin] RustlsWriteHalf),
51-
#[cfg(feature = "native-tls")]
52-
NativeTls(#[pin] NativeTlsWriteHalf),
40+
#[cfg(feature = "__tls")]
41+
Tls(#[pin] super::tls::OwnedWriteHalf),
5342
}
5443

5544
impl AsyncWrite for OwnedWriteHalf {
@@ -63,10 +52,8 @@ impl AsyncWrite for OwnedWriteHalf {
6352
OwnedWriteHalfProj::Tcp(half) => half.poll_write(cx, buf),
6453
#[cfg(target_family = "unix")]
6554
OwnedWriteHalfProj::Unix(half) => half.poll_write(cx, buf),
66-
#[cfg(feature = "rustls")]
67-
OwnedWriteHalfProj::Rustls(half) => half.poll_write(cx, buf),
68-
#[cfg(feature = "native-tls")]
69-
OwnedWriteHalfProj::NativeTls(half) => half.poll_write(cx, buf),
55+
#[cfg(feature = "__tls")]
56+
OwnedWriteHalfProj::Tls(half) => half.poll_write(cx, buf),
7057
}
7158
}
7259

@@ -76,10 +63,8 @@ impl AsyncWrite for OwnedWriteHalf {
7663
OwnedWriteHalfProj::Tcp(half) => half.poll_flush(cx),
7764
#[cfg(target_family = "unix")]
7865
OwnedWriteHalfProj::Unix(half) => half.poll_flush(cx),
79-
#[cfg(feature = "rustls")]
80-
OwnedWriteHalfProj::Rustls(half) => half.poll_flush(cx),
81-
#[cfg(feature = "native-tls")]
82-
OwnedWriteHalfProj::NativeTls(half) => half.poll_flush(cx),
66+
#[cfg(feature = "__tls")]
67+
OwnedWriteHalfProj::Tls(half) => half.poll_flush(cx),
8368
}
8469
}
8570

@@ -89,10 +74,8 @@ impl AsyncWrite for OwnedWriteHalf {
8974
OwnedWriteHalfProj::Tcp(half) => half.poll_shutdown(cx),
9075
#[cfg(target_family = "unix")]
9176
OwnedWriteHalfProj::Unix(half) => half.poll_shutdown(cx),
92-
#[cfg(feature = "rustls")]
93-
OwnedWriteHalfProj::Rustls(half) => half.poll_shutdown(cx),
94-
#[cfg(feature = "native-tls")]
95-
OwnedWriteHalfProj::NativeTls(half) => half.poll_shutdown(cx),
77+
#[cfg(feature = "__tls")]
78+
OwnedWriteHalfProj::Tls(half) => half.poll_shutdown(cx),
9679
}
9780
}
9881

@@ -106,10 +89,8 @@ impl AsyncWrite for OwnedWriteHalf {
10689
OwnedWriteHalfProj::Tcp(half) => half.poll_write_vectored(cx, bufs),
10790
#[cfg(target_family = "unix")]
10891
OwnedWriteHalfProj::Unix(half) => half.poll_write_vectored(cx, bufs),
109-
#[cfg(feature = "rustls")]
110-
OwnedWriteHalfProj::Rustls(half) => half.poll_write_vectored(cx, bufs),
111-
#[cfg(feature = "native-tls")]
112-
OwnedWriteHalfProj::NativeTls(half) => half.poll_write_vectored(cx, bufs),
92+
#[cfg(feature = "__tls")]
93+
OwnedWriteHalfProj::Tls(half) => half.poll_write_vectored(cx, bufs),
11394
}
11495
}
11596

@@ -119,29 +100,19 @@ impl AsyncWrite for OwnedWriteHalf {
119100
Self::Tcp(half) => half.is_write_vectored(),
120101
#[cfg(target_family = "unix")]
121102
Self::Unix(half) => half.is_write_vectored(),
122-
#[cfg(feature = "rustls")]
123-
Self::Rustls(half) => half.is_write_vectored(),
124-
#[cfg(feature = "native-tls")]
125-
Self::NativeTls(half) => half.is_write_vectored(),
103+
#[cfg(feature = "__tls")]
104+
Self::Tls(half) => half.is_write_vectored(),
126105
}
127106
}
128107
}
129108

130-
#[cfg(feature = "rustls")]
131-
type RustlsReadHalf = tokio::io::ReadHalf<tokio_rustls::TlsStream<TcpStream>>;
132-
133-
#[cfg(feature = "native-tls")]
134-
type NativeTlsReadHalf = tokio::io::ReadHalf<tokio_native_tls::TlsStream<TcpStream>>;
135-
136109
#[pin_project(project = OwnedReadHalfProj)]
137110
pub enum OwnedReadHalf {
138111
Tcp(#[pin] tcp::OwnedReadHalf),
139112
#[cfg(target_family = "unix")]
140113
Unix(#[pin] unix::OwnedReadHalf),
141-
#[cfg(feature = "rustls")]
142-
Rustls(#[pin] RustlsReadHalf),
143-
#[cfg(feature = "native-tls")]
144-
NativeTls(#[pin] NativeTlsReadHalf),
114+
#[cfg(feature = "__tls")]
115+
Tls(#[pin] super::tls::OwnedReadHalf),
145116
}
146117

147118
impl AsyncRead for OwnedReadHalf {
@@ -155,10 +126,8 @@ impl AsyncRead for OwnedReadHalf {
155126
OwnedReadHalfProj::Tcp(half) => half.poll_read(cx, buf),
156127
#[cfg(target_family = "unix")]
157128
OwnedReadHalfProj::Unix(half) => half.poll_read(cx, buf),
158-
#[cfg(feature = "rustls")]
159-
OwnedReadHalfProj::Rustls(half) => half.poll_read(cx, buf),
160-
#[cfg(feature = "native-tls")]
161-
OwnedReadHalfProj::NativeTls(half) => half.poll_read(cx, buf),
129+
#[cfg(feature = "__tls")]
130+
OwnedReadHalfProj::Tls(half) => half.poll_read(cx, buf),
162131
}
163132
}
164133
}
@@ -175,27 +144,18 @@ impl ConnStream {
175144
let (rh, wh) = stream.into_split();
176145
(OwnedReadHalf::Unix(rh), OwnedWriteHalf::Unix(wh))
177146
}
178-
#[cfg(feature = "rustls")]
179-
Self::Rustls(stream) => {
180-
let (rh, wh) = tokio::io::split(stream);
181-
(OwnedReadHalf::Rustls(rh), OwnedWriteHalf::Rustls(wh))
182-
}
183-
#[cfg(feature = "native-tls")]
184-
Self::NativeTls(stream) => {
185-
let (rh, wh) = tokio::io::split(stream);
186-
(OwnedReadHalf::NativeTls(rh), OwnedWriteHalf::NativeTls(wh))
147+
#[cfg(feature = "__tls")]
148+
Self::Tls(stream) => {
149+
let (rh, wh) = stream.into_split();
150+
(OwnedReadHalf::Tls(rh), OwnedWriteHalf::Tls(wh))
187151
}
188152
}
189153
}
190154

191155
pub fn negotiated_alpn(&self) -> Option<Vec<u8>> {
192156
match self {
193-
#[cfg(feature = "rustls")]
194-
Self::Rustls(tokio_rustls::TlsStream::Client(stream)) => {
195-
stream.get_ref().1.alpn_protocol().map(ToOwned::to_owned)
196-
}
197-
#[cfg(feature = "native-tls")]
198-
Self::NativeTls(stream) => stream.get_ref().negotiated_alpn().unwrap_or_default(),
157+
#[cfg(feature = "__tls")]
158+
Self::Tls(stream) => stream.negotiated_alpn(),
199159
_ => None,
200160
}
201161
}
@@ -217,19 +177,14 @@ impl From<UnixStream> for ConnStream {
217177
}
218178
}
219179

220-
#[cfg(feature = "rustls")]
221-
impl From<tokio_rustls::TlsStream<TcpStream>> for ConnStream {
222-
#[inline]
223-
fn from(s: tokio_rustls::TlsStream<TcpStream>) -> Self {
224-
Self::Rustls(s)
225-
}
226-
}
227-
228-
#[cfg(feature = "native-tls")]
229-
impl From<tokio_native_tls::TlsStream<TcpStream>> for ConnStream {
180+
#[cfg(feature = "__tls")]
181+
impl<T> From<T> for ConnStream
182+
where
183+
T: Into<super::tls::TlsStream>,
184+
{
230185
#[inline]
231-
fn from(s: tokio_native_tls::TlsStream<TcpStream>) -> Self {
232-
Self::NativeTls(s)
186+
fn from(s: T) -> Self {
187+
Self::Tls(s.into())
233188
}
234189
}
235190

@@ -244,10 +199,8 @@ impl AsyncRead for ConnStream {
244199
IoStreamProj::Tcp(s) => s.poll_read(cx, buf),
245200
#[cfg(target_family = "unix")]
246201
IoStreamProj::Unix(s) => s.poll_read(cx, buf),
247-
#[cfg(feature = "rustls")]
248-
IoStreamProj::Rustls(s) => s.poll_read(cx, buf),
249-
#[cfg(feature = "native-tls")]
250-
IoStreamProj::NativeTls(s) => s.poll_read(cx, buf),
202+
#[cfg(feature = "__tls")]
203+
IoStreamProj::Tls(s) => s.poll_read(cx, buf),
251204
}
252205
}
253206
}
@@ -263,10 +216,8 @@ impl AsyncWrite for ConnStream {
263216
IoStreamProj::Tcp(s) => s.poll_write(cx, buf),
264217
#[cfg(target_family = "unix")]
265218
IoStreamProj::Unix(s) => s.poll_write(cx, buf),
266-
#[cfg(feature = "rustls")]
267-
IoStreamProj::Rustls(s) => s.poll_write(cx, buf),
268-
#[cfg(feature = "native-tls")]
269-
IoStreamProj::NativeTls(s) => s.poll_write(cx, buf),
219+
#[cfg(feature = "__tls")]
220+
IoStreamProj::Tls(s) => s.poll_write(cx, buf),
270221
}
271222
}
272223

@@ -276,10 +227,8 @@ impl AsyncWrite for ConnStream {
276227
IoStreamProj::Tcp(s) => s.poll_flush(cx),
277228
#[cfg(target_family = "unix")]
278229
IoStreamProj::Unix(s) => s.poll_flush(cx),
279-
#[cfg(feature = "rustls")]
280-
IoStreamProj::Rustls(s) => s.poll_flush(cx),
281-
#[cfg(feature = "native-tls")]
282-
IoStreamProj::NativeTls(s) => s.poll_flush(cx),
230+
#[cfg(feature = "__tls")]
231+
IoStreamProj::Tls(s) => s.poll_flush(cx),
283232
}
284233
}
285234

@@ -289,10 +238,8 @@ impl AsyncWrite for ConnStream {
289238
IoStreamProj::Tcp(s) => s.poll_shutdown(cx),
290239
#[cfg(target_family = "unix")]
291240
IoStreamProj::Unix(s) => s.poll_shutdown(cx),
292-
#[cfg(feature = "rustls")]
293-
IoStreamProj::Rustls(s) => s.poll_shutdown(cx),
294-
#[cfg(feature = "native-tls")]
295-
IoStreamProj::NativeTls(s) => s.poll_shutdown(cx),
241+
#[cfg(feature = "__tls")]
242+
IoStreamProj::Tls(s) => s.poll_shutdown(cx),
296243
}
297244
}
298245

@@ -306,10 +253,8 @@ impl AsyncWrite for ConnStream {
306253
IoStreamProj::Tcp(s) => s.poll_write_vectored(cx, bufs),
307254
#[cfg(target_family = "unix")]
308255
IoStreamProj::Unix(s) => s.poll_write_vectored(cx, bufs),
309-
#[cfg(feature = "rustls")]
310-
IoStreamProj::Rustls(s) => s.poll_write_vectored(cx, bufs),
311-
#[cfg(feature = "native-tls")]
312-
IoStreamProj::NativeTls(s) => s.poll_write_vectored(cx, bufs),
256+
#[cfg(feature = "__tls")]
257+
IoStreamProj::Tls(s) => s.poll_write_vectored(cx, bufs),
313258
}
314259
}
315260

@@ -319,10 +264,8 @@ impl AsyncWrite for ConnStream {
319264
Self::Tcp(s) => s.is_write_vectored(),
320265
#[cfg(target_family = "unix")]
321266
Self::Unix(s) => s.is_write_vectored(),
322-
#[cfg(feature = "rustls")]
323-
Self::Rustls(s) => s.is_write_vectored(),
324-
#[cfg(feature = "native-tls")]
325-
Self::NativeTls(s) => s.is_write_vectored(),
267+
#[cfg(feature = "__tls")]
268+
Self::Tls(s) => s.is_write_vectored(),
326269
}
327270
}
328271
}
@@ -334,19 +277,12 @@ impl ConnStream {
334277
Self::Tcp(s) => s.peer_addr().map(Address::from).ok(),
335278
#[cfg(target_family = "unix")]
336279
Self::Unix(s) => s.peer_addr().map(Address::from).ok(),
337-
#[cfg(feature = "rustls")]
338-
Self::Rustls(s) => s.get_ref().0.peer_addr().map(Address::from).ok(),
339-
#[cfg(feature = "native-tls")]
340-
Self::NativeTls(s) => s
341-
.get_ref()
342-
.get_ref()
343-
.get_ref()
344-
.peer_addr()
345-
.map(Address::from)
346-
.ok(),
280+
#[cfg(feature = "__tls")]
281+
Self::Tls(s) => s.peer_addr().map(Address::from).ok(),
347282
}
348283
}
349284
}
285+
350286
pub struct Conn {
351287
pub stream: ConnStream,
352288
pub info: ConnInfo,

volo/src/net/dial.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ impl DefaultMakeTransport {
7878

7979
impl MakeTransport for DefaultMakeTransport {
8080
type ReadHalf = OwnedReadHalf;
81-
8281
type WriteHalf = OwnedWriteHalf;
8382

8483
async fn make_transport(&self, addr: Address) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {

volo/src/net/ready.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ impl AsyncReady for OwnedReadHalf {
2020
OwnedReadHalf::Tcp(half) => half.ready(interest).await,
2121
#[cfg(target_family = "unix")]
2222
OwnedReadHalf::Unix(half) => half.ready(interest).await,
23-
#[cfg(feature = "rustls")]
24-
OwnedReadHalf::Rustls(_) => todo!(),
25-
#[cfg(feature = "native-tls")]
26-
OwnedReadHalf::NativeTls(_) => todo!(),
23+
#[cfg(feature = "__tls")]
24+
OwnedReadHalf::Tls(_) => {
25+
unimplemented!("AsyncReady is not supported for TLS connection")
26+
}
2727
}
2828
}
2929
}
@@ -34,10 +34,10 @@ impl AsyncReady for OwnedWriteHalf {
3434
OwnedWriteHalf::Tcp(half) => half.ready(interest).await,
3535
#[cfg(target_family = "unix")]
3636
OwnedWriteHalf::Unix(half) => half.ready(interest).await,
37-
#[cfg(feature = "rustls")]
38-
OwnedWriteHalf::Rustls(_) => todo!(),
39-
#[cfg(feature = "native-tls")]
40-
OwnedWriteHalf::NativeTls(_) => todo!(),
37+
#[cfg(feature = "__tls")]
38+
OwnedWriteHalf::Tls(_) => {
39+
unimplemented!("AsyncReady is not supported for TLS connection")
40+
}
4141
}
4242
}
4343
}

0 commit comments

Comments
 (0)