diff --git a/Cargo.toml b/Cargo.toml index fbc855c..9f54f70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ repository = "https://github.com/compio-rs/cyper" [workspace.dependencies] compio = { version = "0.19.0-rc.1", default-features = false } +compio-log = "0.1.0" + cyper-core = { path = "./cyper-core", default-features = false, version = "0.9.0-rc.2" } cyper-axum = { path = "./cyper-axum", default-features = false, version = "0.9.0-rc.1" } cyper-hickory = { path = "./cyper-hickory", default-features = false, version = "0.1.0-rc.1" } diff --git a/cyper-axum/Cargo.toml b/cyper-axum/Cargo.toml index c764de3..a5ecfe9 100644 --- a/cyper-axum/Cargo.toml +++ b/cyper-axum/Cargo.toml @@ -11,7 +11,7 @@ repository = { workspace = true } [dependencies] compio = { workspace = true, features = ["net", "time"] } -compio-log = "0.1.0" +compio-log = { workspace = true } cyper-core = { workspace = true } socket2 = { workspace = true } diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index dc095a8..1d6d991 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -11,11 +11,13 @@ repository.workspace = true [dependencies] compio = { workspace = true, features = ["net", "time", "io-compat"] } +compio-log = { workspace = true } hickory-net = { workspace = true } hickory-resolver = { workspace = true } async-trait = "0.1" +futures-channel = { workspace = true, optional = true } futures-util = { workspace = true } send_wrapper = { workspace = true, features = ["futures"] } @@ -28,6 +30,7 @@ hyper-util = { workspace = true, optional = true, features = [ "http1", "http2", ] } +quinn = { version = "0.11", default-features = false } tower-service = { workspace = true, optional = true } [dev-dependencies] @@ -46,17 +49,36 @@ dnssec = ["hickory-resolver/dnssec-ring", "hickory-server/dnssec-ring"] __ring = ["compio/ring", "compio/rustls"] __tls = ["compio/tls", "__ring"] tls = ["__tls", "hickory-resolver/tls-ring", "hickory-server/tls-ring"] +__http = ["dep:http"] https = [ - "__tls", + "tls", + "__http", "hickory-resolver/https-ring", "hickory-server/https-ring", "dep:cyper-core", - "dep:http", "dep:http-body-util", "dep:hyper", "dep:hyper-util", "dep:tower-service", ] -all = ["dnssec", "tls", "https"] +__quic = ["__ring", "compio/quic"] +quic = [ + "__quic", + "tls", + "compio/bytes", + "hickory-resolver/quic-ring", + "hickory-server/quic-ring", +] +h3 = [ + "quic", + "__http", + "compio/h3", + "dep:futures-channel", + "hickory-resolver/h3-ring", + "hickory-server/h3-ring", +] +all = ["dnssec", "tls", "https", "quic", "h3"] + +enable_log = ["compio-log/enable_log"] nightly = ["compio/nightly", "cyper-core?/nightly"] diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs new file mode 100644 index 0000000..e267ac3 --- /dev/null +++ b/cyper-hickory/src/h3.rs @@ -0,0 +1,201 @@ +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use compio::{ + bytes::{Buf, Bytes}, + rustls::ClientConfig, +}; +use compio_log::{debug, warn}; +use futures_channel::mpsc::Sender; +use futures_util::{FutureExt, SinkExt, Stream}; +use hickory_net::{ + NetError, + proto::op::{DnsRequest, DnsResponse}, + xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, +}; +use send_wrapper::SendWrapper; + +use crate::CompioRuntimeProvider; + +const H3_ALPN: &[u8] = b"h3"; + +pub async fn connect_h3( + server_name: Arc, + path: Arc, + remote_addr: SocketAddr, + bind_addr: Option, + config: ClientConfig, + enable_grease: bool, + timeout: Duration, +) -> Result, NetError> { + let conn = crate::connect_quic( + server_name.clone(), + remote_addr, + bind_addr, + config, + timeout, + H3_ALPN, + ) + .await?; + + let (mut driver, send_request) = compio::quic::h3::client::builder() + .send_grease(enable_grease) + .build(conn) + .await + .map_err(|e| NetError::from(format!("h3 client error: {e}")))?; + + let (tx, mut rx) = futures_channel::mpsc::channel::<()>(1); + + compio::runtime::spawn(async move { + futures_util::select! { + error = driver.wait_idle().fuse() => { + if !error.is_h3_no_error() { + warn!("h3 connection failed to close: {}", error); + } + } + _ = rx.recv().fuse() => { + debug!("h3 connection is shutting down: {}", remote_addr); + } + } + }) + .detach(); + + let stream = H3RequestSender::new(send_request, server_name, path, tx); + let (exchange, bg) = DnsExchange::from_stream(stream); + compio::runtime::spawn(bg).detach(); + Ok(exchange) +} + +type SendRequest = compio::quic::h3::client::SendRequest; + +struct H3RequestSender { + send_request: SendWrapper, + server_name: Arc, + path: Arc, + tx: Sender<()>, + is_shutdown: bool, +} + +impl H3RequestSender { + fn new( + send_request: SendRequest, + server_name: Arc, + path: Arc, + tx: Sender<()>, + ) -> Self { + Self { + send_request: SendWrapper::new(send_request), + server_name, + path, + tx, + is_shutdown: false, + } + } + + async fn inner_send( + mut send_request: SendWrapper, + server_name: Arc, + path: Arc, + mut request: DnsRequest, + ) -> Result { + request.metadata.id = 0; + let bytes = request.to_vec()?; + + let request = crate::build_request(&server_name, &path, bytes.len())?; + + let mut stream = send_request + .send_request(request) + .await + .map_err(|e| NetError::from(format!("h3 send request error: {e}")))?; + + stream + .send_data(Bytes::from(bytes)) + .await + .map_err(|e| NetError::from(format!("h3 send data error: {e}")))?; + + stream + .finish() + .await + .map_err(|e| NetError::from(format!("h3 finish error: {e}")))?; + + let resp = stream + .recv_response() + .await + .map_err(|e| NetError::from(format!("h3 recv response error: {e}")))?; + let (resp, ()) = resp.into_parts(); + + let content_length = crate::get_content_length(&resp.headers)?; + + let mut response_bytes = + Vec::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096)); + while let Some(chunk) = stream + .recv_data() + .await + .map_err(|e| NetError::from(format!("h3 recv data error: {e}")))? + { + response_bytes.extend_from_slice(chunk.chunk()); + + if let Some(content_length) = content_length + && response_bytes.len() >= content_length + { + break; + } + } + + crate::build_response(resp, content_length, response_bytes) + } +} + +impl DnsRequestSender for H3RequestSender { + fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream { + if self.is_shutdown { + panic!("can not send messages after stream is shutdown") + } + + Box::pin(SendWrapper::new(Self::inner_send( + self.send_request.clone(), + self.server_name.clone(), + self.path.clone(), + request, + ))) + .into() + } + + fn shutdown(&mut self) { + self.is_shutdown = true; + compio::runtime::spawn({ + let mut tx = self.tx.clone(); + async move { + let _ = tx.send(()).await; + } + }) + .detach(); + } + + fn is_shutdown(&self) -> bool { + self.is_shutdown + } +} + +impl Stream for H3RequestSender { + type Item = Result<(), NetError>; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_shutdown { + return Poll::Ready(None); + } + + if self.tx.is_closed() { + return Poll::Ready(Some(Err(NetError::from( + "h3 connection is already shutdown", + )))); + } + + Poll::Ready(Some(Ok(()))) + } +} diff --git a/cyper-hickory/src/https.rs b/cyper-hickory/src/https.rs index fa19765..2e5a72b 100644 --- a/cyper-hickory/src/https.rs +++ b/cyper-hickory/src/https.rs @@ -2,7 +2,6 @@ use std::{ io, net::SocketAddr, pin::Pin, - str::FromStr, sync::Arc, task::{Context, Poll}, time::Duration, @@ -20,7 +19,7 @@ use hickory_net::{ proto::op::{DnsRequest, DnsResponse}, xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, }; -use http::{Request, Uri, uri}; +use http::{Request, Uri}; use http_body_util::{BodyExt, Full}; use hyper::{body::Bytes, rt::ReadBufCursor}; use hyper_util::client::legacy::{ @@ -30,7 +29,7 @@ use hyper_util::client::legacy::{ use send_wrapper::SendWrapper; use tower_service::Service; -use crate::{CompioRuntimeProvider, connect_tcp}; +use crate::CompioRuntimeProvider; pub async fn connect_https( server_name: Arc, @@ -53,8 +52,6 @@ pub async fn connect_https( Ok(exchange) } -const MIME_APPLICATION_DNS: &str = "application/dns-message"; - struct RequestSender { client: Arc>>, server_name: Arc, @@ -85,28 +82,10 @@ impl RequestSender { request.metadata.id = 0; let bytes = request.to_vec()?; - let mut parts = uri::Parts::default(); - parts.path_and_query = Some( - uri::PathAndQuery::from_str(&path) - .map_err(|e| NetError::from(format!("invalid DoH path: {e:?}")))?, - ); - parts.scheme = Some(uri::Scheme::HTTPS); - parts.authority = Some( - uri::Authority::from_str(&server_name) - .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, - ); - - let url = Uri::from_parts(parts) - .map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; - - let request = Request::builder() - .method("POST") - .uri(url) - .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) - .header(http::header::ACCEPT, MIME_APPLICATION_DNS) - .header(http::header::CONTENT_LENGTH, bytes.len()) - .body(Full::new(Bytes::from(bytes))) - .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; + let request = crate::build_request(&server_name, &path, bytes.len())?; + + let (parts, ()) = request.into_parts(); + let request = Request::from_parts(parts, Full::new(Bytes::from(bytes))); let response = client .request(request) @@ -114,59 +93,16 @@ impl RequestSender { .map_err(|e| NetError::from(format!("request error: {e:?}")))?; let (response, body) = response.into_parts(); - let content_length = response - .headers - .get(http::header::CONTENT_LENGTH) - .map(|v| v.to_str()) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? - .map(usize::from_str) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))?; + let content_length = crate::get_content_length(&response.headers)?; let response_bytes = body .collect() .await .map_err(|e| NetError::from(format!("read response body error: {e:?}")))? - .to_bytes(); - - if let Some(content_length) = content_length - && response_bytes.len() != content_length - { - return Err(NetError::from(format!( - "expected byte length: {}, got: {}", - content_length, - response_bytes.len() - ))); - } - - if !response.status.is_success() { - let error_string = String::from_utf8_lossy(response_bytes.as_ref()); - - return Err(NetError::from(format!( - "http unsuccessful code: {}, message: {}", - response.status, error_string - ))); - } else { - let content_type = response - .headers - .get(http::header::CONTENT_TYPE) - .map(|h| { - h.to_str().map_err(|err| { - NetError::from(format!("ContentType header not a string: {err}")) - }) - }) - .unwrap_or(Ok(MIME_APPLICATION_DNS))?; - - if content_type != MIME_APPLICATION_DNS { - return Err(NetError::from(format!( - "ContentType unsupported (must be '{}'): '{}'", - MIME_APPLICATION_DNS, content_type - ))); - } - } + .to_bytes() + .to_vec(); - DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) + crate::build_response(response, content_length, response_bytes) } } @@ -249,7 +185,7 @@ impl Service for Connector { let tls = self.tls.clone(); let timeout = self.timeout; Box::pin(SendWrapper::new(async move { - let stream = connect_tcp(remote_addr, bind_addr, Some(timeout)).await?; + let stream = crate::connect_tcp(remote_addr, bind_addr, Some(timeout)).await?; let stream = TlsConnector::from(tls) .connect(&server_name, stream) .await?; diff --git a/cyper-hickory/src/lib.rs b/cyper-hickory/src/lib.rs index 21d76ce..f614be8 100644 --- a/cyper-hickory/src/lib.rs +++ b/cyper-hickory/src/lib.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use compio::{ BufResult, io::{compat::AsyncStream, util::Splittable}, - net::{TcpSocket, TcpStream, UdpSocket}, + net::{TcpStream, UdpSocket}, runtime::Runtime, }; use futures_util::{AsyncRead, AsyncWrite}; @@ -37,6 +37,15 @@ mod tls; #[cfg(feature = "https")] mod https; +#[cfg(feature = "quic")] +mod quic; + +#[cfg(feature = "h3")] +mod h3; + +mod util; +pub(crate) use util::*; + /// [`RuntimeProvider`] implementation for [`compio`]. It should not be used /// directly. Instead, use [`CompioConnectionProvider`] which wraps this /// provider and implements [`ConnectionProvider`] for hickory. @@ -328,8 +337,43 @@ impl ConnectionProvider for CompioConnectionProvider { .await }))) } - #[allow(unreachable_patterns)] - _ => Err(NetError::from("protocol config not supported")), + #[cfg(feature = "quic")] + ProtocolConfig::Quic { server_name } => { + let server_name = server_name.clone(); + let remote_addr = SocketAddr::new(ip, config.port); + let bind_addr = config.bind_addr; + let tls = cx.tls.clone(); + let timeout = cx.options.timeout; + Ok(Box::pin(SendWrapper::new(async move { + quic::connect_quic(server_name, remote_addr, bind_addr, tls, timeout).await + }))) + } + #[cfg(feature = "h3")] + ProtocolConfig::H3 { + server_name, + path, + disable_grease, + } => { + let server_name = server_name.clone(); + let path = path.clone(); + let remote_addr = SocketAddr::new(ip, config.port); + let bind_addr = config.bind_addr; + let tls = cx.tls.clone(); + let timeout = cx.options.timeout; + let enable_grease = !disable_grease; + Ok(Box::pin(SendWrapper::new(async move { + h3::connect_h3( + server_name, + path, + remote_addr, + bind_addr, + tls, + enable_grease, + timeout, + ) + .await + }))) + } } } @@ -337,31 +381,3 @@ impl ConnectionProvider for CompioConnectionProvider { &self.provider } } - -async fn connect_tcp( - server_addr: SocketAddr, - bind_addr: Option, - timeout: Option, -) -> io::Result { - let fut = async move { - if let Some(bind_addr) = bind_addr { - let socket = if bind_addr.is_ipv4() { - TcpSocket::new_v4().await? - } else { - TcpSocket::new_v6().await? - }; - socket.bind(bind_addr).await?; - socket.connect(server_addr).await - } else { - TcpStream::connect(server_addr).await - } - }; - if let Some(timeout) = timeout { - compio::time::timeout(timeout, fut) - .await - .map_err(|_| io::ErrorKind::TimedOut.into()) - .flatten() - } else { - fut.await - } -} diff --git a/cyper-hickory/src/quic.rs b/cyper-hickory/src/quic.rs new file mode 100644 index 0000000..ddbf22f --- /dev/null +++ b/cyper-hickory/src/quic.rs @@ -0,0 +1,143 @@ +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use compio::{bytes::Bytes, quic::Connection, rustls::ClientConfig}; +use compio_log::debug; +use futures_util::Stream; +use hickory_net::{ + NetError, + proto::{ + ProtoError, + op::{DnsRequest, DnsResponse, Message}, + }, + quic::DoqErrorCode, + xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, +}; +use send_wrapper::SendWrapper; + +use crate::{CompioRuntimeProvider, ToNetError}; + +const DOQ_ALPN: &[u8] = b"doq"; + +pub async fn connect_quic( + server_name: Arc, + remote_addr: SocketAddr, + bind_addr: Option, + config: ClientConfig, + timeout: Duration, +) -> Result, NetError> { + let conn = crate::connect_quic( + server_name, + remote_addr, + bind_addr, + config, + timeout, + DOQ_ALPN, + ) + .await?; + + let stream = CompioQuicClientStream::new(conn); + let (exchange, bg) = DnsExchange::from_stream(stream); + compio::runtime::spawn(bg).detach(); + Ok(exchange) +} + +struct CompioQuicClientStream { + conn: SendWrapper, + is_shutdown: bool, +} + +impl CompioQuicClientStream { + fn new(conn: Connection) -> Self { + Self { + conn: SendWrapper::new(conn), + is_shutdown: false, + } + } + + async fn inner_send( + conn: SendWrapper, + request: DnsRequest, + ) -> Result { + let (send, recv) = conn.open_bi().map_err(ToNetError::to_net_error)?; + + let mut send = send.into_compat(); + let mut recv = recv.into_compat(); + + let mut message = request.into_parts().0; + message.metadata.id = 0; + + let bytes = Bytes::from(message.to_vec()?); + let len = u16::try_from(bytes.len()) + .map_err(|_| NetError::from(ProtoError::MaxBufferSizeExceeded(bytes.len())))?; + + let len_bytes = Bytes::from(len.to_be_bytes().to_vec()); + + send.write_all_chunks(&mut [len_bytes, bytes]) + .await + .map_err(ToNetError::to_net_error)?; + + send.finish()?; + + let mut len_buf = [0u8; 2]; + recv.read_exact(&mut len_buf[..]) + .await + .map_err(ToNetError::to_net_error)?; + let response_len = u16::from_be_bytes(len_buf) as usize; + + let mut msg_buf = vec![0u8; response_len]; + recv.read_exact(&mut msg_buf[..]) + .await + .map_err(ToNetError::to_net_error)?; + + let message = Message::from_vec(&msg_buf)?; + if message.id != 0 { + if let Err(_e) = send.reset(DoqErrorCode::ProtocolError.into()) { + debug!("failed to reset stream: {_e:?}"); + } + return Err(NetError::QuicMessageIdNot0(message.id)); + } + + Ok(DnsResponse::from_buffer(msg_buf)?) + } +} + +impl DnsRequestSender for CompioQuicClientStream { + fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream { + if self.is_shutdown { + panic!("can not send messages after stream is shutdown") + } + + Box::pin(SendWrapper::new(Self::inner_send( + self.conn.clone(), + request, + ))) + .into() + } + + fn shutdown(&mut self) { + self.is_shutdown = true; + self.conn.close(DoqErrorCode::NoError.into(), b"shutdown"); + } + + fn is_shutdown(&self) -> bool { + self.is_shutdown + } +} + +impl Stream for CompioQuicClientStream { + type Item = Result<(), NetError>; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_shutdown { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } +} diff --git a/cyper-hickory/src/util/http.rs b/cyper-hickory/src/util/http.rs new file mode 100644 index 0000000..ac4c7b4 --- /dev/null +++ b/cyper-hickory/src/util/http.rs @@ -0,0 +1,89 @@ +use std::str::FromStr; + +use hickory_net::{NetError, proto::op::DnsResponse}; +use http::uri::{Authority, Parts, PathAndQuery, Scheme, Uri}; + +pub const MIME_APPLICATION_DNS: &str = "application/dns-message"; + +pub fn build_request( + server_name: &str, + path: &str, + len: usize, +) -> Result, NetError> { + let mut parts = Parts::default(); + parts.path_and_query = Some( + PathAndQuery::from_str(path).map_err(|e| NetError::from(format!("invalid path: {e:?}")))?, + ); + parts.scheme = Some(Scheme::HTTPS); + parts.authority = Some( + Authority::from_str(server_name) + .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, + ); + + let url = + Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; + + let request = http::Request::builder() + .method("POST") + .uri(url) + .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) + .header(http::header::ACCEPT, MIME_APPLICATION_DNS) + .header(http::header::CONTENT_LENGTH, len) + .body(()) + .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; + + Ok(request) +} + +pub fn get_content_length(headers: &http::HeaderMap) -> Result, NetError> { + headers + .get(http::header::CONTENT_LENGTH) + .map(|v| v.to_str()) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? + .map(usize::from_str) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}"))) +} + +pub fn build_response( + response: http::response::Parts, + content_length: Option, + response_bytes: Vec, +) -> Result { + if let Some(content_length) = content_length + && response_bytes.len() != content_length + { + return Err(NetError::from(format!( + "expected byte length: {}, got: {}", + content_length, + response_bytes.len() + ))); + } + + if !response.status.is_success() { + let error_string = String::from_utf8_lossy(response_bytes.as_ref()); + + return Err(NetError::from(format!( + "http unsuccessful code: {}, message: {}", + response.status, error_string + ))); + } + let content_type = response + .headers + .get(http::header::CONTENT_TYPE) + .map(|h| { + h.to_str() + .map_err(|err| NetError::from(format!("ContentType header not a string: {err}"))) + }) + .unwrap_or(Ok(MIME_APPLICATION_DNS))?; + + if content_type != MIME_APPLICATION_DNS { + return Err(NetError::from(format!( + "ContentType unsupported (must be '{}'): '{}'", + MIME_APPLICATION_DNS, content_type + ))); + } + + DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) +} diff --git a/cyper-hickory/src/util/mod.rs b/cyper-hickory/src/util/mod.rs new file mode 100644 index 0000000..7fcd8f6 --- /dev/null +++ b/cyper-hickory/src/util/mod.rs @@ -0,0 +1,12 @@ +mod tcp; +pub use tcp::*; + +#[cfg(feature = "__http")] +mod http; +#[cfg(feature = "__http")] +pub use http::*; + +#[cfg(feature = "__quic")] +mod quic; +#[cfg(feature = "__quic")] +pub use quic::*; diff --git a/cyper-hickory/src/util/quic.rs b/cyper-hickory/src/util/quic.rs new file mode 100644 index 0000000..73d265b --- /dev/null +++ b/cyper-hickory/src/util/quic.rs @@ -0,0 +1,135 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + time::Duration, +}; + +use compio::{ + quic::{ + ClientBuilder, Connection, ConnectionError, OpenStreamError, ReadError, ReadExactError, + WriteError, + }, + rustls::ClientConfig, +}; +use hickory_net::NetError; + +pub async fn connect_quic( + server_name: std::sync::Arc, + remote_addr: SocketAddr, + bind_addr: Option, + mut config: ClientConfig, + timeout: Duration, + alpn: &[u8], +) -> Result { + if config.alpn_protocols.is_empty() { + config.alpn_protocols = vec![alpn.to_vec()]; + } + let enable_early_data = config.enable_early_data; + + let bind = bind_addr.unwrap_or_else(|| { + if remote_addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }); + + let endpoint = ClientBuilder::new_with_rustls_client_config(config) + .bind(bind) + .await?; + + let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; + if enable_early_data { + match connecting.into_0rtt() { + Ok(conn) => return Ok(conn), + Err(f) => connecting = f, + } + } + + let conn = compio::time::timeout(timeout, connecting) + .await + .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? + .map_err(ToNetError::to_net_error)?; + Ok(conn) +} + +pub trait ToNetError { + type QuinnError; + + fn to_net_error(self) -> Self::QuinnError; +} + +impl ToNetError for OpenStreamError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + match self { + Self::ConnectionLost(err) => err.to_net_error().into(), + Self::StreamsExhausted => NetError::from(self.to_string()), + } + } +} + +impl ToNetError for ConnectionError { + type QuinnError = quinn::ConnectionError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ConnectionError as QuinnConnectionError; + + match self { + Self::VersionMismatch => QuinnConnectionError::VersionMismatch, + Self::TransportError(err) => QuinnConnectionError::TransportError(err), + Self::ConnectionClosed(err) => QuinnConnectionError::ConnectionClosed(err), + Self::ApplicationClosed(err) => QuinnConnectionError::ApplicationClosed(err), + Self::Reset => QuinnConnectionError::Reset, + Self::TimedOut => QuinnConnectionError::TimedOut, + Self::LocallyClosed => QuinnConnectionError::LocallyClosed, + Self::CidsExhausted => QuinnConnectionError::CidsExhausted, + } + } +} + +impl ToNetError for WriteError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::WriteError as QuinnWriteError; + + let err = match self { + Self::Stopped(code) => QuinnWriteError::Stopped(code), + Self::ConnectionLost(err) => QuinnWriteError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnWriteError::ClosedStream, + Self::ZeroRttRejected => QuinnWriteError::ZeroRttRejected, + _ => return NetError::from(self.to_string()), + }; + NetError::QuinnWriteError(err) + } +} + +impl ToNetError for ReadExactError { + type QuinnError = quinn::ReadExactError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadExactError as QuinnReadExactError; + + match self { + Self::FinishedEarly(len) => QuinnReadExactError::FinishedEarly(len), + Self::ReadError(err) => QuinnReadExactError::ReadError(err.to_net_error()), + } + } +} + +impl ToNetError for ReadError { + type QuinnError = quinn::ReadError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadError as QuinnReadError; + + match self { + Self::Reset(code) => QuinnReadError::Reset(code), + Self::ConnectionLost(err) => QuinnReadError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnReadError::ClosedStream, + Self::IllegalOrderedRead => QuinnReadError::IllegalOrderedRead, + Self::ZeroRttRejected => QuinnReadError::ZeroRttRejected, + } + } +} diff --git a/cyper-hickory/src/util/tcp.rs b/cyper-hickory/src/util/tcp.rs new file mode 100644 index 0000000..a9ca00c --- /dev/null +++ b/cyper-hickory/src/util/tcp.rs @@ -0,0 +1,31 @@ +use std::{io, net::SocketAddr, time::Duration}; + +use compio::net::{TcpSocket, TcpStream}; + +pub async fn connect_tcp( + server_addr: SocketAddr, + bind_addr: Option, + timeout: Option, +) -> io::Result { + let fut = async move { + if let Some(bind_addr) = bind_addr { + let socket = if bind_addr.is_ipv4() { + TcpSocket::new_v4().await? + } else { + TcpSocket::new_v6().await? + }; + socket.bind(bind_addr).await?; + socket.connect(server_addr).await + } else { + TcpStream::connect(server_addr).await + } + }; + if let Some(timeout) = timeout { + compio::time::timeout(timeout, fut) + .await + .map_err(|_| io::ErrorKind::TimedOut.into()) + .flatten() + } else { + fut.await + } +} diff --git a/cyper-hickory/tests/resolve.rs b/cyper-hickory/tests/resolve.rs index c41f581..c458926 100644 --- a/cyper-hickory/tests/resolve.rs +++ b/cyper-hickory/tests/resolve.rs @@ -33,9 +33,13 @@ enum ServerType { Tcp, Udp, #[cfg(feature = "tls")] - Tls(compio::rustls::ServerConfig), + Tls(ServerConfig), #[cfg(feature = "https")] - Https(compio::rustls::ServerConfig, String, String), + Https(ServerConfig, String, String), + #[cfg(feature = "quic")] + Quic(ServerConfig), + #[cfg(feature = "h3")] + H3(ServerConfig), } const IP_RESPONSE: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); @@ -116,6 +120,35 @@ async fn spawn_dns_server(ty: ServerType) -> (SocketAddr, CancellationToken, Joi ) .unwrap(); } + #[cfg(feature = "quic")] + ServerType::Quic(mut config) => { + config.alpn_protocols = vec![b"doq".to_vec()]; + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + tx.send((addr, token)).unwrap(); + server + .register_quic_listener_and_tls_config( + socket, + Duration::from_secs(5), + Arc::new(config), + ) + .unwrap(); + } + #[cfg(feature = "h3")] + ServerType::H3(mut config) => { + config.alpn_protocols = vec![b"h3".to_vec()]; + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + tx.send((addr, token)).unwrap(); + server + .register_h3_listener_with_tls_config( + socket, + Duration::from_secs(5), + Arc::new(config), + Some("dns.compio.rs".to_string()), + ) + .unwrap(); + } } server.block_until_done().await.unwrap(); }); @@ -258,3 +291,47 @@ async fn resolve_https() { token.cancel(); handle.join().unwrap_or_else(|e| resume_unwind(e)) } + +#[compio::test] +#[cfg(feature = "quic")] +async fn resolve_quic() { + let (server_config, client_config) = rcgen(); + let (addr, token, handle) = spawn_dns_server(ServerType::Quic(server_config)).await; + let group = ServerGroup { + ips: &[addr.ip()], + server_name: "dns.compio.rs", + path: "/dns-query", + }; + let mut config = ResolverConfig::quic(&group); + update_port(&mut config, addr.port()); + let resolver = Resolver::builder_with_config(config, CompioConnectionProvider::default()) + .with_tls_config(client_config) + .build() + .unwrap(); + + test_resolve(resolver).await; + token.cancel(); + handle.join().unwrap_or_else(|e| resume_unwind(e)) +} + +#[compio::test] +#[cfg(feature = "h3")] +async fn resolve_h3() { + let (server_config, client_config) = rcgen(); + let (addr, token, handle) = spawn_dns_server(ServerType::H3(server_config)).await; + let group = ServerGroup { + ips: &[addr.ip()], + server_name: "dns.compio.rs", + path: "/dns-query", + }; + let mut config = ResolverConfig::h3(&group); + update_port(&mut config, addr.port()); + let resolver = Resolver::builder_with_config(config, CompioConnectionProvider::default()) + .with_tls_config(client_config) + .build() + .unwrap(); + + test_resolve(resolver).await; + token.cancel(); + handle.join().unwrap_or_else(|e| resume_unwind(e)) +}