From 59dfc6c0667d89ccbf7fdcf841331292930cfd23 Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 00:10:51 +0800 Subject: [PATCH 1/7] feat(dns): support DoQ & DoH3 Co-authored-by: DeepSeek --- Cargo.toml | 2 + cyper-axum/Cargo.toml | 2 +- cyper-hickory/Cargo.toml | 19 ++- cyper-hickory/src/h3.rs | 255 +++++++++++++++++++++++++++++++++ cyper-hickory/src/https.rs | 4 +- cyper-hickory/src/lib.rs | 48 ++++++- cyper-hickory/src/quic.rs | 174 ++++++++++++++++++++++ cyper-hickory/tests/resolve.rs | 81 ++++++++++- 8 files changed, 576 insertions(+), 9 deletions(-) create mode 100644 cyper-hickory/src/h3.rs create mode 100644 cyper-hickory/src/quic.rs diff --git a/Cargo.toml b/Cargo.toml index fbc855c..9f54f70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ repository = "https://github.com/compio-rs/cyper" [workspace.dependencies] compio = { version = "0.19.0-rc.1", default-features = false } +compio-log = "0.1.0" + cyper-core = { path = "./cyper-core", default-features = false, version = "0.9.0-rc.2" } cyper-axum = { path = "./cyper-axum", default-features = false, version = "0.9.0-rc.1" } cyper-hickory = { path = "./cyper-hickory", default-features = false, version = "0.1.0-rc.1" } diff --git a/cyper-axum/Cargo.toml b/cyper-axum/Cargo.toml index c764de3..a5ecfe9 100644 --- a/cyper-axum/Cargo.toml +++ b/cyper-axum/Cargo.toml @@ -11,7 +11,7 @@ repository = { workspace = true } [dependencies] compio = { workspace = true, features = ["net", "time"] } -compio-log = "0.1.0" +compio-log = { workspace = true } cyper-core = { workspace = true } socket2 = { workspace = true } diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index dc095a8..e4000a5 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -11,6 +11,7 @@ repository.workspace = true [dependencies] compio = { workspace = true, features = ["net", "time", "io-compat"] } +compio-log = { workspace = true } hickory-net = { workspace = true } hickory-resolver = { workspace = true } @@ -57,6 +58,22 @@ https = [ "dep:hyper-util", "dep:tower-service", ] -all = ["dnssec", "tls", "https"] +__quic = ["__ring", "compio/quic"] +quic = [ + "__quic", + "compio/bytes", + "hickory-resolver/quic-ring", + "hickory-server/quic-ring", +] +h3 = [ + "__quic", + "compio/h3", + "dep:http", + "hickory-resolver/h3-ring", + "hickory-server/h3-ring", +] +all = ["dnssec", "tls", "https", "quic", "h3"] + +enable_log = ["compio-log/enable_log"] nightly = ["compio/nightly", "cyper-core?/nightly"] diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs new file mode 100644 index 0000000..f340290 --- /dev/null +++ b/cyper-hickory/src/h3.rs @@ -0,0 +1,255 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + str::FromStr, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use compio::{bytes::Bytes, quic::ClientBuilder, rustls::ClientConfig}; +use compio_log::debug; +use futures_util::Stream; +use hickory_net::{ + NetError, + proto::op::{DnsRequest, DnsResponse}, + xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, +}; +use http::{Request, Uri, uri}; +use hyper::body::Buf; +use send_wrapper::SendWrapper; + +use crate::{CompioRuntimeProvider, MIME_APPLICATION_DNS}; + +const H3_ALPN: &[u8] = b"h3"; + +pub async fn connect_h3( + server_name: Arc, + path: Arc, + remote_addr: SocketAddr, + bind_addr: Option, + mut config: ClientConfig, + enable_grease: bool, + timeout: Duration, +) -> Result, NetError> { + config.alpn_protocols = vec![H3_ALPN.to_vec()]; + let enable_early_data = config.enable_early_data; + + let bind = bind_addr.unwrap_or_else(|| { + if remote_addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }); + + let endpoint = ClientBuilder::new_with_rustls_client_config(config) + .bind(bind) + .await?; + + let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; + let conn = async { + if enable_early_data { + match connecting.into_0rtt() { + Ok(conn) => return Ok(conn), + Err(f) => connecting = f, + } + } + compio::time::timeout(timeout, connecting) + .await + .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? + .map_err(|e| NetError::from(format!("quic connection error: {e}"))) + } + .await?; + + let (mut driver, send_request) = compio::quic::h3::client::builder() + .send_grease(enable_grease) + .build(conn) + .await + .map_err(|e| NetError::from(format!("h3 client error: {e}")))?; + + compio::runtime::spawn(async move { + driver.wait_idle().await; + }) + .detach(); + + let stream = H3RequestSender::new(send_request, server_name, path); + let (exchange, bg) = DnsExchange::from_stream(stream); + compio::runtime::spawn(bg).detach(); + Ok(exchange) +} + +type SendRequest = compio::quic::h3::client::SendRequest; + +struct H3RequestSender { + send_request: SendWrapper, + server_name: Arc, + path: Arc, + is_shutdown: bool, +} + +impl H3RequestSender { + fn new(send_request: SendRequest, server_name: Arc, path: Arc) -> Self { + Self { + send_request: SendWrapper::new(send_request), + server_name, + path, + is_shutdown: false, + } + } + + async fn inner_send( + mut send_request: SendWrapper, + server_name: Arc, + path: Arc, + mut request: DnsRequest, + ) -> Result { + request.metadata.id = 0; + let bytes = request.to_vec()?; + + let mut parts = uri::Parts::default(); + parts.path_and_query = Some( + uri::PathAndQuery::from_str(&path) + .map_err(|e| NetError::from(format!("invalid DoH3 path: {e}")))?, + ); + parts.scheme = Some(uri::Scheme::HTTPS); + parts.authority = Some( + uri::Authority::from_str(&server_name) + .map_err(|e| NetError::from(format!("invalid authority: {e}")))?, + ); + + let url = + Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e}")))?; + + let request = Request::builder() + .method("POST") + .uri(url) + .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) + .header(http::header::ACCEPT, MIME_APPLICATION_DNS) + .header(http::header::CONTENT_LENGTH, bytes.len()) + .body(()) + .map_err(|e| NetError::from(format!("build h3 request error: {e}")))?; + + let mut stream = send_request + .send_request(request) + .await + .map_err(|e| NetError::from(format!("h3 send request error: {e}")))?; + + stream + .send_data(Bytes::from(bytes)) + .await + .map_err(|e| NetError::from(format!("h3 send data error: {e}")))?; + + stream + .finish() + .await + .map_err(|e| NetError::from(format!("h3 finish error: {e}")))?; + + let resp = stream + .recv_response() + .await + .map_err(|e| NetError::from(format!("h3 recv response error: {e}")))?; + + debug!("got response: {:#?}", resp); + + let content_length = resp + .headers() + .get(http::header::CONTENT_LENGTH) + .map(|v| v.to_str()) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e}")))? + .map(usize::from_str) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e}")))?; + + let mut response_bytes = + Vec::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096)); + while let Some(chunk) = stream + .recv_data() + .await + .map_err(|e| NetError::from(format!("h3 recv data error: {e}")))? + { + response_bytes.extend_from_slice(chunk.chunk()); + + if let Some(content_length) = content_length + && response_bytes.len() >= content_length + { + break; + } + } + + if let Some(content_length) = content_length + && response_bytes.len() != content_length + { + return Err(NetError::from(format!( + "expected byte length: {}, got: {}", + content_length, + response_bytes.len() + ))); + } + + if !resp.status().is_success() { + let error_string = String::from_utf8_lossy(response_bytes.as_ref()); + + return Err(NetError::from(format!( + "http unsuccessful code: {}, message: {}", + resp.status(), + error_string + ))); + } + + let content_type = resp + .headers() + .get(http::header::CONTENT_TYPE) + .map(|h| { + h.to_str() + .map_err(|e| NetError::from(format!("ContentType header not a string: {e}"))) + }) + .unwrap_or(Ok(MIME_APPLICATION_DNS))?; + + if content_type != MIME_APPLICATION_DNS { + return Err(NetError::from(format!( + "ContentType unsupported (must be '{}'): '{}'", + MIME_APPLICATION_DNS, content_type + ))); + } + + DnsResponse::from_buffer(response_bytes).map_err(NetError::from) + } +} + +impl DnsRequestSender for H3RequestSender { + fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream { + if self.is_shutdown { + panic!("can not send messages after stream is shutdown") + } + + Box::pin(SendWrapper::new(Self::inner_send( + self.send_request.clone(), + self.server_name.clone(), + self.path.clone(), + request, + ))) + .into() + } + + fn shutdown(&mut self) { + self.is_shutdown = true; + } + + fn is_shutdown(&self) -> bool { + self.is_shutdown + } +} + +impl Stream for H3RequestSender { + type Item = Result<(), NetError>; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_shutdown { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } +} diff --git a/cyper-hickory/src/https.rs b/cyper-hickory/src/https.rs index fa19765..1de8ab8 100644 --- a/cyper-hickory/src/https.rs +++ b/cyper-hickory/src/https.rs @@ -30,7 +30,7 @@ use hyper_util::client::legacy::{ use send_wrapper::SendWrapper; use tower_service::Service; -use crate::{CompioRuntimeProvider, connect_tcp}; +use crate::{CompioRuntimeProvider, MIME_APPLICATION_DNS, connect_tcp}; pub async fn connect_https( server_name: Arc, @@ -53,8 +53,6 @@ pub async fn connect_https( Ok(exchange) } -const MIME_APPLICATION_DNS: &str = "application/dns-message"; - struct RequestSender { client: Arc>>, server_name: Arc, diff --git a/cyper-hickory/src/lib.rs b/cyper-hickory/src/lib.rs index 21d76ce..9e9a249 100644 --- a/cyper-hickory/src/lib.rs +++ b/cyper-hickory/src/lib.rs @@ -37,6 +37,12 @@ mod tls; #[cfg(feature = "https")] mod https; +#[cfg(feature = "quic")] +mod quic; + +#[cfg(feature = "h3")] +mod h3; + /// [`RuntimeProvider`] implementation for [`compio`]. It should not be used /// directly. Instead, use [`CompioConnectionProvider`] which wraps this /// provider and implements [`ConnectionProvider`] for hickory. @@ -328,8 +334,43 @@ impl ConnectionProvider for CompioConnectionProvider { .await }))) } - #[allow(unreachable_patterns)] - _ => Err(NetError::from("protocol config not supported")), + #[cfg(feature = "quic")] + ProtocolConfig::Quic { server_name } => { + let server_name = server_name.clone(); + let remote_addr = SocketAddr::new(ip, config.port); + let bind_addr = config.bind_addr; + let tls = cx.tls.clone(); + let timeout = cx.options.timeout; + Ok(Box::pin(SendWrapper::new(async move { + quic::connect_quic(server_name, remote_addr, bind_addr, tls, timeout).await + }))) + } + #[cfg(feature = "h3")] + ProtocolConfig::H3 { + server_name, + path, + disable_grease, + } => { + let server_name = server_name.clone(); + let path = path.clone(); + let remote_addr = SocketAddr::new(ip, config.port); + let bind_addr = config.bind_addr; + let tls = cx.tls.clone(); + let timeout = cx.options.timeout; + let enable_grease = !disable_grease; + Ok(Box::pin(SendWrapper::new(async move { + h3::connect_h3( + server_name, + path, + remote_addr, + bind_addr, + tls, + enable_grease, + timeout, + ) + .await + }))) + } } } @@ -365,3 +406,6 @@ async fn connect_tcp( fut.await } } + +#[cfg(any(feature = "https", feature = "h3"))] +const MIME_APPLICATION_DNS: &str = "application/dns-message"; diff --git a/cyper-hickory/src/quic.rs b/cyper-hickory/src/quic.rs new file mode 100644 index 0000000..bf90d4c --- /dev/null +++ b/cyper-hickory/src/quic.rs @@ -0,0 +1,174 @@ +//! TODO: error types + +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use compio::{ + bytes::Bytes, + quic::{ClientBuilder, Connection}, + rustls::ClientConfig, +}; +use compio_log::debug; +use futures_util::Stream; +use hickory_net::{ + NetError, + proto::{ + ProtoError, + op::{DnsRequest, DnsResponse, Message}, + }, + quic::DoqErrorCode, + xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, +}; +use send_wrapper::SendWrapper; + +use crate::CompioRuntimeProvider; + +const DOQ_ALPN: &[u8] = b"doq"; + +pub async fn connect_quic( + server_name: Arc, + remote_addr: SocketAddr, + bind_addr: Option, + mut config: ClientConfig, + timeout: Duration, +) -> Result, NetError> { + if config.alpn_protocols.is_empty() { + config.alpn_protocols = vec![DOQ_ALPN.to_vec()]; + } + let enable_early_data = config.enable_early_data; + + let bind = bind_addr.unwrap_or_else(|| { + if remote_addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }); + + let endpoint = ClientBuilder::new_with_rustls_client_config(config) + .bind(bind) + .await?; + + let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; + let conn = async { + if enable_early_data { + match connecting.into_0rtt() { + Ok(conn) => return Ok(conn), + Err(f) => connecting = f, + } + } + compio::time::timeout(timeout, connecting) + .await + .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? + .map_err(|e| NetError::from(format!("quic connection error: {e}"))) + } + .await?; + + let stream = CompioQuicClientStream::new(conn); + let (exchange, bg) = DnsExchange::from_stream(stream); + compio::runtime::spawn(bg).detach(); + Ok(exchange) +} + +struct CompioQuicClientStream { + conn: SendWrapper, + is_shutdown: bool, +} + +impl CompioQuicClientStream { + fn new(conn: Connection) -> Self { + Self { + conn: SendWrapper::new(conn), + is_shutdown: false, + } + } + + async fn inner_send( + conn: SendWrapper, + request: DnsRequest, + ) -> Result { + let (send, recv) = conn + .open_bi() + .map_err(|e| NetError::from(format!("open_bi error: {e}")))?; + + let mut send = send.into_compat(); + let mut recv = recv.into_compat(); + + let mut message = request.into_parts().0; + message.metadata.id = 0; + + let bytes = Bytes::from(message.to_vec()?); + let len = u16::try_from(bytes.len()) + .map_err(|_| NetError::from(ProtoError::MaxBufferSizeExceeded(bytes.len())))?; + + let len_bytes = Bytes::from(len.to_be_bytes().to_vec()); + + send.write_all_chunks(&mut [len_bytes, bytes]) + .await + .map_err(|e| NetError::from(format!("quic write error: {e}")))?; + + send.finish() + .map_err(|e| NetError::from(format!("quic finish error: {e}")))?; + + let mut len_buf = [0u8; 2]; + recv.read_exact(&mut len_buf[..]) + .await + .map_err(|e| NetError::from(format!("quic read length error: {e}")))?; + let response_len = u16::from_be_bytes(len_buf) as usize; + + let mut msg_buf = vec![0u8; response_len]; + recv.read_exact(&mut msg_buf[..]) + .await + .map_err(|e| NetError::from(format!("quic read message error: {e}")))?; + + let message = Message::from_vec(&msg_buf)?; + if message.id != 0 { + if let Err(_e) = send.reset(DoqErrorCode::ProtocolError.into()) { + debug!("failed to reset stream: {_e:?}"); + } + return Err(NetError::QuicMessageIdNot0(message.id)); + } + + Ok(DnsResponse::from_buffer(msg_buf)?) + } +} + +impl DnsRequestSender for CompioQuicClientStream { + fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream { + if self.is_shutdown { + panic!("can not send messages after stream is shutdown") + } + + Box::pin(SendWrapper::new(Self::inner_send( + self.conn.clone(), + request, + ))) + .into() + } + + fn shutdown(&mut self) { + self.is_shutdown = true; + self.conn.close(DoqErrorCode::NoError.into(), b"shutdown"); + } + + fn is_shutdown(&self) -> bool { + self.is_shutdown + } +} + +impl Stream for CompioQuicClientStream { + type Item = Result<(), NetError>; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_shutdown { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } +} diff --git a/cyper-hickory/tests/resolve.rs b/cyper-hickory/tests/resolve.rs index c41f581..c458926 100644 --- a/cyper-hickory/tests/resolve.rs +++ b/cyper-hickory/tests/resolve.rs @@ -33,9 +33,13 @@ enum ServerType { Tcp, Udp, #[cfg(feature = "tls")] - Tls(compio::rustls::ServerConfig), + Tls(ServerConfig), #[cfg(feature = "https")] - Https(compio::rustls::ServerConfig, String, String), + Https(ServerConfig, String, String), + #[cfg(feature = "quic")] + Quic(ServerConfig), + #[cfg(feature = "h3")] + H3(ServerConfig), } const IP_RESPONSE: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)); @@ -116,6 +120,35 @@ async fn spawn_dns_server(ty: ServerType) -> (SocketAddr, CancellationToken, Joi ) .unwrap(); } + #[cfg(feature = "quic")] + ServerType::Quic(mut config) => { + config.alpn_protocols = vec![b"doq".to_vec()]; + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + tx.send((addr, token)).unwrap(); + server + .register_quic_listener_and_tls_config( + socket, + Duration::from_secs(5), + Arc::new(config), + ) + .unwrap(); + } + #[cfg(feature = "h3")] + ServerType::H3(mut config) => { + config.alpn_protocols = vec![b"h3".to_vec()]; + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + tx.send((addr, token)).unwrap(); + server + .register_h3_listener_with_tls_config( + socket, + Duration::from_secs(5), + Arc::new(config), + Some("dns.compio.rs".to_string()), + ) + .unwrap(); + } } server.block_until_done().await.unwrap(); }); @@ -258,3 +291,47 @@ async fn resolve_https() { token.cancel(); handle.join().unwrap_or_else(|e| resume_unwind(e)) } + +#[compio::test] +#[cfg(feature = "quic")] +async fn resolve_quic() { + let (server_config, client_config) = rcgen(); + let (addr, token, handle) = spawn_dns_server(ServerType::Quic(server_config)).await; + let group = ServerGroup { + ips: &[addr.ip()], + server_name: "dns.compio.rs", + path: "/dns-query", + }; + let mut config = ResolverConfig::quic(&group); + update_port(&mut config, addr.port()); + let resolver = Resolver::builder_with_config(config, CompioConnectionProvider::default()) + .with_tls_config(client_config) + .build() + .unwrap(); + + test_resolve(resolver).await; + token.cancel(); + handle.join().unwrap_or_else(|e| resume_unwind(e)) +} + +#[compio::test] +#[cfg(feature = "h3")] +async fn resolve_h3() { + let (server_config, client_config) = rcgen(); + let (addr, token, handle) = spawn_dns_server(ServerType::H3(server_config)).await; + let group = ServerGroup { + ips: &[addr.ip()], + server_name: "dns.compio.rs", + path: "/dns-query", + }; + let mut config = ResolverConfig::h3(&group); + update_port(&mut config, addr.port()); + let resolver = Resolver::builder_with_config(config, CompioConnectionProvider::default()) + .with_tls_config(client_config) + .build() + .unwrap(); + + test_resolve(resolver).await; + token.cancel(); + handle.join().unwrap_or_else(|e| resume_unwind(e)) +} From 1eb207e2a48830f52f3b95b3c32f4befd37f8209 Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 00:16:50 +0800 Subject: [PATCH 2/7] refactor(dns): reuse quic code --- cyper-hickory/src/h3.rs | 42 +++++++++----------------------- cyper-hickory/src/lib.rs | 43 +++++++++++++++++++++++++++++++++ cyper-hickory/src/quic.rs | 50 +++++++++------------------------------ 3 files changed, 65 insertions(+), 70 deletions(-) diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs index f340290..aaa046d 100644 --- a/cyper-hickory/src/h3.rs +++ b/cyper-hickory/src/h3.rs @@ -1,5 +1,5 @@ use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::SocketAddr, pin::Pin, str::FromStr, sync::Arc, @@ -7,7 +7,7 @@ use std::{ time::Duration, }; -use compio::{bytes::Bytes, quic::ClientBuilder, rustls::ClientConfig}; +use compio::{bytes::Bytes, rustls::ClientConfig}; use compio_log::debug; use futures_util::Stream; use hickory_net::{ @@ -28,38 +28,18 @@ pub async fn connect_h3( path: Arc, remote_addr: SocketAddr, bind_addr: Option, - mut config: ClientConfig, + config: ClientConfig, enable_grease: bool, timeout: Duration, ) -> Result, NetError> { - config.alpn_protocols = vec![H3_ALPN.to_vec()]; - let enable_early_data = config.enable_early_data; - - let bind = bind_addr.unwrap_or_else(|| { - if remote_addr.is_ipv4() { - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) - } else { - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) - } - }); - - let endpoint = ClientBuilder::new_with_rustls_client_config(config) - .bind(bind) - .await?; - - let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; - let conn = async { - if enable_early_data { - match connecting.into_0rtt() { - Ok(conn) => return Ok(conn), - Err(f) => connecting = f, - } - } - compio::time::timeout(timeout, connecting) - .await - .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? - .map_err(|e| NetError::from(format!("quic connection error: {e}"))) - } + let conn = crate::connect_quic( + server_name.clone(), + remote_addr, + bind_addr, + config, + timeout, + H3_ALPN, + ) .await?; let (mut driver, send_request) = compio::quic::h3::client::builder() diff --git a/cyper-hickory/src/lib.rs b/cyper-hickory/src/lib.rs index 9e9a249..d16cf57 100644 --- a/cyper-hickory/src/lib.rs +++ b/cyper-hickory/src/lib.rs @@ -409,3 +409,46 @@ async fn connect_tcp( #[cfg(any(feature = "https", feature = "h3"))] const MIME_APPLICATION_DNS: &str = "application/dns-message"; + +#[cfg(feature = "__quic")] +async fn connect_quic( + server_name: std::sync::Arc, + remote_addr: SocketAddr, + bind_addr: Option, + mut config: compio::rustls::ClientConfig, + timeout: Duration, + alpn: &[u8], +) -> Result { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use compio::quic::ClientBuilder; + + if config.alpn_protocols.is_empty() { + config.alpn_protocols = vec![alpn.to_vec()]; + } + let enable_early_data = config.enable_early_data; + + let bind = bind_addr.unwrap_or_else(|| { + if remote_addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }); + + let endpoint = ClientBuilder::new_with_rustls_client_config(config) + .bind(bind) + .await?; + + let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; + if enable_early_data { + match connecting.into_0rtt() { + Ok(conn) => return Ok(conn), + Err(f) => connecting = f, + } + } + compio::time::timeout(timeout, connecting) + .await + .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? + .map_err(|e| NetError::from(format!("quic connection error: {e}"))) +} diff --git a/cyper-hickory/src/quic.rs b/cyper-hickory/src/quic.rs index bf90d4c..58ebda2 100644 --- a/cyper-hickory/src/quic.rs +++ b/cyper-hickory/src/quic.rs @@ -1,18 +1,12 @@ -//! TODO: error types - use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::SocketAddr, pin::Pin, sync::Arc, task::{Context, Poll}, time::Duration, }; -use compio::{ - bytes::Bytes, - quic::{ClientBuilder, Connection}, - rustls::ClientConfig, -}; +use compio::{bytes::Bytes, quic::Connection, rustls::ClientConfig}; use compio_log::debug; use futures_util::Stream; use hickory_net::{ @@ -34,39 +28,17 @@ pub async fn connect_quic( server_name: Arc, remote_addr: SocketAddr, bind_addr: Option, - mut config: ClientConfig, + config: ClientConfig, timeout: Duration, ) -> Result, NetError> { - if config.alpn_protocols.is_empty() { - config.alpn_protocols = vec![DOQ_ALPN.to_vec()]; - } - let enable_early_data = config.enable_early_data; - - let bind = bind_addr.unwrap_or_else(|| { - if remote_addr.is_ipv4() { - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) - } else { - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) - } - }); - - let endpoint = ClientBuilder::new_with_rustls_client_config(config) - .bind(bind) - .await?; - - let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; - let conn = async { - if enable_early_data { - match connecting.into_0rtt() { - Ok(conn) => return Ok(conn), - Err(f) => connecting = f, - } - } - compio::time::timeout(timeout, connecting) - .await - .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? - .map_err(|e| NetError::from(format!("quic connection error: {e}"))) - } + let conn = crate::connect_quic( + server_name, + remote_addr, + bind_addr, + config, + timeout, + DOQ_ALPN, + ) .await?; let stream = CompioQuicClientStream::new(conn); From 1b6ee1074aaee5168858f96fae2677fcb1f63051 Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 00:19:29 +0800 Subject: [PATCH 3/7] fix(dns): features --- cyper-hickory/Cargo.toml | 5 +++-- cyper-hickory/src/h3.rs | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index e4000a5..a4db540 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -48,7 +48,7 @@ __ring = ["compio/ring", "compio/rustls"] __tls = ["compio/tls", "__ring"] tls = ["__tls", "hickory-resolver/tls-ring", "hickory-server/tls-ring"] https = [ - "__tls", + "tls", "hickory-resolver/https-ring", "hickory-server/https-ring", "dep:cyper-core", @@ -61,12 +61,13 @@ https = [ __quic = ["__ring", "compio/quic"] quic = [ "__quic", + "tls", "compio/bytes", "hickory-resolver/quic-ring", "hickory-server/quic-ring", ] h3 = [ - "__quic", + "quic", "compio/h3", "dep:http", "hickory-resolver/h3-ring", diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs index aaa046d..33e377e 100644 --- a/cyper-hickory/src/h3.rs +++ b/cyper-hickory/src/h3.rs @@ -7,7 +7,10 @@ use std::{ time::Duration, }; -use compio::{bytes::Bytes, rustls::ClientConfig}; +use compio::{ + bytes::{Buf, Bytes}, + rustls::ClientConfig, +}; use compio_log::debug; use futures_util::Stream; use hickory_net::{ @@ -16,7 +19,6 @@ use hickory_net::{ xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, }; use http::{Request, Uri, uri}; -use hyper::body::Buf; use send_wrapper::SendWrapper; use crate::{CompioRuntimeProvider, MIME_APPLICATION_DNS}; From ad8ad97ec660ba7ba6d2a7b90cf19d656a9721ff Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 00:26:00 +0800 Subject: [PATCH 4/7] feat(dns,h3): shutdown driver gracefully --- cyper-hickory/Cargo.toml | 2 ++ cyper-hickory/src/h3.rs | 48 +++++++++++++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index a4db540..a5b4267 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -17,6 +17,7 @@ hickory-net = { workspace = true } hickory-resolver = { workspace = true } async-trait = "0.1" +futures-channel = { workspace = true, optional = true } futures-util = { workspace = true } send_wrapper = { workspace = true, features = ["futures"] } @@ -70,6 +71,7 @@ h3 = [ "quic", "compio/h3", "dep:http", + "dep:futures-channel", "hickory-resolver/h3-ring", "hickory-server/h3-ring", ] diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs index 33e377e..7586b38 100644 --- a/cyper-hickory/src/h3.rs +++ b/cyper-hickory/src/h3.rs @@ -11,8 +11,9 @@ use compio::{ bytes::{Buf, Bytes}, rustls::ClientConfig, }; -use compio_log::debug; -use futures_util::Stream; +use compio_log::{debug, warn}; +use futures_channel::mpsc::Sender; +use futures_util::{FutureExt, SinkExt, Stream}; use hickory_net::{ NetError, proto::op::{DnsRequest, DnsResponse}, @@ -50,12 +51,23 @@ pub async fn connect_h3( .await .map_err(|e| NetError::from(format!("h3 client error: {e}")))?; + let (tx, mut rx) = futures_channel::mpsc::channel::<()>(1); + compio::runtime::spawn(async move { - driver.wait_idle().await; + futures_util::select! { + error = driver.wait_idle().fuse() => { + if !error.is_h3_no_error() { + warn!("h3 connection failed to close: {}", error); + } + } + _ = rx.recv().fuse() => { + debug!("h3 connection is shutting down: {}", remote_addr); + } + } }) .detach(); - let stream = H3RequestSender::new(send_request, server_name, path); + let stream = H3RequestSender::new(send_request, server_name, path, tx); let (exchange, bg) = DnsExchange::from_stream(stream); compio::runtime::spawn(bg).detach(); Ok(exchange) @@ -67,15 +79,22 @@ struct H3RequestSender { send_request: SendWrapper, server_name: Arc, path: Arc, + tx: Sender<()>, is_shutdown: bool, } impl H3RequestSender { - fn new(send_request: SendRequest, server_name: Arc, path: Arc) -> Self { + fn new( + send_request: SendRequest, + server_name: Arc, + path: Arc, + tx: Sender<()>, + ) -> Self { Self { send_request: SendWrapper::new(send_request), server_name, path, + tx, is_shutdown: false, } } @@ -217,6 +236,13 @@ impl DnsRequestSender for H3RequestSender { fn shutdown(&mut self) { self.is_shutdown = true; + compio::runtime::spawn({ + let mut tx = self.tx.clone(); + async move { + let _ = tx.send(()).await; + } + }) + .detach(); } fn is_shutdown(&self) -> bool { @@ -229,9 +255,15 @@ impl Stream for H3RequestSender { fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.is_shutdown { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(()))) + return Poll::Ready(None); } + + if self.tx.is_closed() { + return Poll::Ready(Some(Err(NetError::from( + "h3 connection is already shutdown", + )))); + } + + Poll::Ready(Some(Ok(()))) } } From 739126aa99a87be690c28526e250fe658c628d8a Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 00:45:24 +0800 Subject: [PATCH 5/7] refactor(dns): reuse http handling code --- cyper-hickory/Cargo.toml | 5 ++- cyper-hickory/src/h3.rs | 78 +++----------------------------- cyper-hickory/src/https.rs | 84 +++++----------------------------- cyper-hickory/src/lib.rs | 92 +++++++++++++++++++++++++++++++++++++- 4 files changed, 110 insertions(+), 149 deletions(-) diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index a5b4267..65e1f06 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -48,12 +48,13 @@ dnssec = ["hickory-resolver/dnssec-ring", "hickory-server/dnssec-ring"] __ring = ["compio/ring", "compio/rustls"] __tls = ["compio/tls", "__ring"] tls = ["__tls", "hickory-resolver/tls-ring", "hickory-server/tls-ring"] +__http = ["dep:http"] https = [ "tls", + "__http", "hickory-resolver/https-ring", "hickory-server/https-ring", "dep:cyper-core", - "dep:http", "dep:http-body-util", "dep:hyper", "dep:hyper-util", @@ -69,8 +70,8 @@ quic = [ ] h3 = [ "quic", + "__http", "compio/h3", - "dep:http", "dep:futures-channel", "hickory-resolver/h3-ring", "hickory-server/h3-ring", diff --git a/cyper-hickory/src/h3.rs b/cyper-hickory/src/h3.rs index 7586b38..e267ac3 100644 --- a/cyper-hickory/src/h3.rs +++ b/cyper-hickory/src/h3.rs @@ -1,7 +1,6 @@ use std::{ net::SocketAddr, pin::Pin, - str::FromStr, sync::Arc, task::{Context, Poll}, time::Duration, @@ -19,10 +18,9 @@ use hickory_net::{ proto::op::{DnsRequest, DnsResponse}, xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, }; -use http::{Request, Uri, uri}; use send_wrapper::SendWrapper; -use crate::{CompioRuntimeProvider, MIME_APPLICATION_DNS}; +use crate::CompioRuntimeProvider; const H3_ALPN: &[u8] = b"h3"; @@ -108,28 +106,7 @@ impl H3RequestSender { request.metadata.id = 0; let bytes = request.to_vec()?; - let mut parts = uri::Parts::default(); - parts.path_and_query = Some( - uri::PathAndQuery::from_str(&path) - .map_err(|e| NetError::from(format!("invalid DoH3 path: {e}")))?, - ); - parts.scheme = Some(uri::Scheme::HTTPS); - parts.authority = Some( - uri::Authority::from_str(&server_name) - .map_err(|e| NetError::from(format!("invalid authority: {e}")))?, - ); - - let url = - Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e}")))?; - - let request = Request::builder() - .method("POST") - .uri(url) - .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) - .header(http::header::ACCEPT, MIME_APPLICATION_DNS) - .header(http::header::CONTENT_LENGTH, bytes.len()) - .body(()) - .map_err(|e| NetError::from(format!("build h3 request error: {e}")))?; + let request = crate::build_request(&server_name, &path, bytes.len())?; let mut stream = send_request .send_request(request) @@ -150,18 +127,9 @@ impl H3RequestSender { .recv_response() .await .map_err(|e| NetError::from(format!("h3 recv response error: {e}")))?; + let (resp, ()) = resp.into_parts(); - debug!("got response: {:#?}", resp); - - let content_length = resp - .headers() - .get(http::header::CONTENT_LENGTH) - .map(|v| v.to_str()) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e}")))? - .map(usize::from_str) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e}")))?; + let content_length = crate::get_content_length(&resp.headers)?; let mut response_bytes = Vec::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096)); @@ -179,43 +147,7 @@ impl H3RequestSender { } } - if let Some(content_length) = content_length - && response_bytes.len() != content_length - { - return Err(NetError::from(format!( - "expected byte length: {}, got: {}", - content_length, - response_bytes.len() - ))); - } - - if !resp.status().is_success() { - let error_string = String::from_utf8_lossy(response_bytes.as_ref()); - - return Err(NetError::from(format!( - "http unsuccessful code: {}, message: {}", - resp.status(), - error_string - ))); - } - - let content_type = resp - .headers() - .get(http::header::CONTENT_TYPE) - .map(|h| { - h.to_str() - .map_err(|e| NetError::from(format!("ContentType header not a string: {e}"))) - }) - .unwrap_or(Ok(MIME_APPLICATION_DNS))?; - - if content_type != MIME_APPLICATION_DNS { - return Err(NetError::from(format!( - "ContentType unsupported (must be '{}'): '{}'", - MIME_APPLICATION_DNS, content_type - ))); - } - - DnsResponse::from_buffer(response_bytes).map_err(NetError::from) + crate::build_response(resp, content_length, response_bytes) } } diff --git a/cyper-hickory/src/https.rs b/cyper-hickory/src/https.rs index 1de8ab8..2e5a72b 100644 --- a/cyper-hickory/src/https.rs +++ b/cyper-hickory/src/https.rs @@ -2,7 +2,6 @@ use std::{ io, net::SocketAddr, pin::Pin, - str::FromStr, sync::Arc, task::{Context, Poll}, time::Duration, @@ -20,7 +19,7 @@ use hickory_net::{ proto::op::{DnsRequest, DnsResponse}, xfer::{DnsExchange, DnsRequestSender, DnsResponseStream}, }; -use http::{Request, Uri, uri}; +use http::{Request, Uri}; use http_body_util::{BodyExt, Full}; use hyper::{body::Bytes, rt::ReadBufCursor}; use hyper_util::client::legacy::{ @@ -30,7 +29,7 @@ use hyper_util::client::legacy::{ use send_wrapper::SendWrapper; use tower_service::Service; -use crate::{CompioRuntimeProvider, MIME_APPLICATION_DNS, connect_tcp}; +use crate::CompioRuntimeProvider; pub async fn connect_https( server_name: Arc, @@ -83,28 +82,10 @@ impl RequestSender { request.metadata.id = 0; let bytes = request.to_vec()?; - let mut parts = uri::Parts::default(); - parts.path_and_query = Some( - uri::PathAndQuery::from_str(&path) - .map_err(|e| NetError::from(format!("invalid DoH path: {e:?}")))?, - ); - parts.scheme = Some(uri::Scheme::HTTPS); - parts.authority = Some( - uri::Authority::from_str(&server_name) - .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, - ); - - let url = Uri::from_parts(parts) - .map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; - - let request = Request::builder() - .method("POST") - .uri(url) - .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) - .header(http::header::ACCEPT, MIME_APPLICATION_DNS) - .header(http::header::CONTENT_LENGTH, bytes.len()) - .body(Full::new(Bytes::from(bytes))) - .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; + let request = crate::build_request(&server_name, &path, bytes.len())?; + + let (parts, ()) = request.into_parts(); + let request = Request::from_parts(parts, Full::new(Bytes::from(bytes))); let response = client .request(request) @@ -112,59 +93,16 @@ impl RequestSender { .map_err(|e| NetError::from(format!("request error: {e:?}")))?; let (response, body) = response.into_parts(); - let content_length = response - .headers - .get(http::header::CONTENT_LENGTH) - .map(|v| v.to_str()) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? - .map(usize::from_str) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))?; + let content_length = crate::get_content_length(&response.headers)?; let response_bytes = body .collect() .await .map_err(|e| NetError::from(format!("read response body error: {e:?}")))? - .to_bytes(); - - if let Some(content_length) = content_length - && response_bytes.len() != content_length - { - return Err(NetError::from(format!( - "expected byte length: {}, got: {}", - content_length, - response_bytes.len() - ))); - } - - if !response.status.is_success() { - let error_string = String::from_utf8_lossy(response_bytes.as_ref()); - - return Err(NetError::from(format!( - "http unsuccessful code: {}, message: {}", - response.status, error_string - ))); - } else { - let content_type = response - .headers - .get(http::header::CONTENT_TYPE) - .map(|h| { - h.to_str().map_err(|err| { - NetError::from(format!("ContentType header not a string: {err}")) - }) - }) - .unwrap_or(Ok(MIME_APPLICATION_DNS))?; - - if content_type != MIME_APPLICATION_DNS { - return Err(NetError::from(format!( - "ContentType unsupported (must be '{}'): '{}'", - MIME_APPLICATION_DNS, content_type - ))); - } - } + .to_bytes() + .to_vec(); - DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) + crate::build_response(response, content_length, response_bytes) } } @@ -247,7 +185,7 @@ impl Service for Connector { let tls = self.tls.clone(); let timeout = self.timeout; Box::pin(SendWrapper::new(async move { - let stream = connect_tcp(remote_addr, bind_addr, Some(timeout)).await?; + let stream = crate::connect_tcp(remote_addr, bind_addr, Some(timeout)).await?; let stream = TlsConnector::from(tls) .connect(&server_name, stream) .await?; diff --git a/cyper-hickory/src/lib.rs b/cyper-hickory/src/lib.rs index d16cf57..f93c7fa 100644 --- a/cyper-hickory/src/lib.rs +++ b/cyper-hickory/src/lib.rs @@ -407,9 +407,99 @@ async fn connect_tcp( } } -#[cfg(any(feature = "https", feature = "h3"))] +#[cfg(feature = "__http")] const MIME_APPLICATION_DNS: &str = "application/dns-message"; +#[cfg(feature = "__http")] +fn build_request(server_name: &str, path: &str, len: usize) -> Result, NetError> { + use std::str::FromStr; + + use http::uri::{Authority, Parts, PathAndQuery, Scheme, Uri}; + + let mut parts = Parts::default(); + parts.path_and_query = Some( + PathAndQuery::from_str(path).map_err(|e| NetError::from(format!("invalid path: {e:?}")))?, + ); + parts.scheme = Some(Scheme::HTTPS); + parts.authority = Some( + Authority::from_str(server_name) + .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, + ); + + let url = + Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; + + let request = http::Request::builder() + .method("POST") + .uri(url) + .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) + .header(http::header::ACCEPT, MIME_APPLICATION_DNS) + .header(http::header::CONTENT_LENGTH, len) + .body(()) + .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; + + Ok(request) +} + +#[cfg(feature = "__http")] +fn get_content_length(headers: &http::HeaderMap) -> Result, NetError> { + use std::str::FromStr; + + headers + .get(http::header::CONTENT_LENGTH) + .map(|v| v.to_str()) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? + .map(usize::from_str) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}"))) +} + +#[cfg(feature = "__http")] +fn build_response( + response: http::response::Parts, + content_length: Option, + response_bytes: Vec, +) -> Result { + use hickory_net::proto::op::DnsResponse; + + if let Some(content_length) = content_length + && response_bytes.len() != content_length + { + return Err(NetError::from(format!( + "expected byte length: {}, got: {}", + content_length, + response_bytes.len() + ))); + } + + if !response.status.is_success() { + let error_string = String::from_utf8_lossy(response_bytes.as_ref()); + + return Err(NetError::from(format!( + "http unsuccessful code: {}, message: {}", + response.status, error_string + ))); + } + let content_type = response + .headers + .get(http::header::CONTENT_TYPE) + .map(|h| { + h.to_str() + .map_err(|err| NetError::from(format!("ContentType header not a string: {err}"))) + }) + .unwrap_or(Ok(MIME_APPLICATION_DNS))?; + + if content_type != MIME_APPLICATION_DNS { + return Err(NetError::from(format!( + "ContentType unsupported (must be '{}'): '{}'", + MIME_APPLICATION_DNS, content_type + ))); + } + + DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) +} + #[cfg(feature = "__quic")] async fn connect_quic( server_name: std::sync::Arc, From d3f2499034f2d573f4dc675b8b66a537ba72f71c Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 01:15:45 +0800 Subject: [PATCH 6/7] refactor(dns,quic): use quinn errors --- cyper-hickory/Cargo.toml | 1 + cyper-hickory/src/quic.rs | 101 ++++++++++++++++++++++++++++++++++---- 2 files changed, 93 insertions(+), 9 deletions(-) diff --git a/cyper-hickory/Cargo.toml b/cyper-hickory/Cargo.toml index 65e1f06..1d6d991 100644 --- a/cyper-hickory/Cargo.toml +++ b/cyper-hickory/Cargo.toml @@ -30,6 +30,7 @@ hyper-util = { workspace = true, optional = true, features = [ "http1", "http2", ] } +quinn = { version = "0.11", default-features = false } tower-service = { workspace = true, optional = true } [dev-dependencies] diff --git a/cyper-hickory/src/quic.rs b/cyper-hickory/src/quic.rs index 58ebda2..73b4d24 100644 --- a/cyper-hickory/src/quic.rs +++ b/cyper-hickory/src/quic.rs @@ -6,7 +6,11 @@ use std::{ time::Duration, }; -use compio::{bytes::Bytes, quic::Connection, rustls::ClientConfig}; +use compio::{ + bytes::Bytes, + quic::{Connection, ConnectionError, OpenStreamError, ReadError, ReadExactError, WriteError}, + rustls::ClientConfig, +}; use compio_log::debug; use futures_util::Stream; use hickory_net::{ @@ -64,9 +68,7 @@ impl CompioQuicClientStream { conn: SendWrapper, request: DnsRequest, ) -> Result { - let (send, recv) = conn - .open_bi() - .map_err(|e| NetError::from(format!("open_bi error: {e}")))?; + let (send, recv) = conn.open_bi().map_err(ToNetError::to_net_error)?; let mut send = send.into_compat(); let mut recv = recv.into_compat(); @@ -82,21 +84,20 @@ impl CompioQuicClientStream { send.write_all_chunks(&mut [len_bytes, bytes]) .await - .map_err(|e| NetError::from(format!("quic write error: {e}")))?; + .map_err(ToNetError::to_net_error)?; - send.finish() - .map_err(|e| NetError::from(format!("quic finish error: {e}")))?; + send.finish()?; let mut len_buf = [0u8; 2]; recv.read_exact(&mut len_buf[..]) .await - .map_err(|e| NetError::from(format!("quic read length error: {e}")))?; + .map_err(ToNetError::to_net_error)?; let response_len = u16::from_be_bytes(len_buf) as usize; let mut msg_buf = vec![0u8; response_len]; recv.read_exact(&mut msg_buf[..]) .await - .map_err(|e| NetError::from(format!("quic read message error: {e}")))?; + .map_err(ToNetError::to_net_error)?; let message = Message::from_vec(&msg_buf)?; if message.id != 0 { @@ -144,3 +145,85 @@ impl Stream for CompioQuicClientStream { } } } + +trait ToNetError { + type QuinnError; + + fn to_net_error(self) -> Self::QuinnError; +} + +impl ToNetError for OpenStreamError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + match self { + Self::ConnectionLost(err) => err.to_net_error().into(), + Self::StreamsExhausted => NetError::from(self.to_string()), + } + } +} + +impl ToNetError for ConnectionError { + type QuinnError = quinn::ConnectionError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ConnectionError as QuinnConnectionError; + + match self { + Self::VersionMismatch => QuinnConnectionError::VersionMismatch, + Self::TransportError(err) => QuinnConnectionError::TransportError(err), + Self::ConnectionClosed(err) => QuinnConnectionError::ConnectionClosed(err), + Self::ApplicationClosed(err) => QuinnConnectionError::ApplicationClosed(err), + Self::Reset => QuinnConnectionError::Reset, + Self::TimedOut => QuinnConnectionError::TimedOut, + Self::LocallyClosed => QuinnConnectionError::LocallyClosed, + Self::CidsExhausted => QuinnConnectionError::CidsExhausted, + } + } +} + +impl ToNetError for WriteError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::WriteError as QuinnWriteError; + + let err = match self { + Self::Stopped(code) => QuinnWriteError::Stopped(code), + Self::ConnectionLost(err) => QuinnWriteError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnWriteError::ClosedStream, + Self::ZeroRttRejected => QuinnWriteError::ZeroRttRejected, + _ => return NetError::from(self.to_string()), + }; + NetError::QuinnWriteError(err) + } +} + +impl ToNetError for ReadExactError { + type QuinnError = quinn::ReadExactError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadExactError as QuinnReadExactError; + + match self { + Self::FinishedEarly(len) => QuinnReadExactError::FinishedEarly(len), + Self::ReadError(err) => QuinnReadExactError::ReadError(err.to_net_error()), + } + } +} + +impl ToNetError for ReadError { + type QuinnError = quinn::ReadError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadError as QuinnReadError; + + match self { + Self::Reset(code) => QuinnReadError::Reset(code), + Self::ConnectionLost(err) => QuinnReadError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnReadError::ClosedStream, + Self::IllegalOrderedRead => QuinnReadError::IllegalOrderedRead, + Self::ZeroRttRejected => QuinnReadError::ZeroRttRejected, + } + } +} From db3a62d955beb1c7220771cde6a46bb1605f4675 Mon Sep 17 00:00:00 2001 From: Yuyi Wang Date: Wed, 13 May 2026 01:35:58 +0800 Subject: [PATCH 7/7] refactor(dns): move util code --- cyper-hickory/src/lib.rs | 169 +-------------------------------- cyper-hickory/src/quic.rs | 90 +----------------- cyper-hickory/src/util/http.rs | 89 +++++++++++++++++ cyper-hickory/src/util/mod.rs | 12 +++ cyper-hickory/src/util/quic.rs | 135 ++++++++++++++++++++++++++ cyper-hickory/src/util/tcp.rs | 31 ++++++ 6 files changed, 273 insertions(+), 253 deletions(-) create mode 100644 cyper-hickory/src/util/http.rs create mode 100644 cyper-hickory/src/util/mod.rs create mode 100644 cyper-hickory/src/util/quic.rs create mode 100644 cyper-hickory/src/util/tcp.rs diff --git a/cyper-hickory/src/lib.rs b/cyper-hickory/src/lib.rs index f93c7fa..f614be8 100644 --- a/cyper-hickory/src/lib.rs +++ b/cyper-hickory/src/lib.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use compio::{ BufResult, io::{compat::AsyncStream, util::Splittable}, - net::{TcpSocket, TcpStream, UdpSocket}, + net::{TcpStream, UdpSocket}, runtime::Runtime, }; use futures_util::{AsyncRead, AsyncWrite}; @@ -43,6 +43,9 @@ mod quic; #[cfg(feature = "h3")] mod h3; +mod util; +pub(crate) use util::*; + /// [`RuntimeProvider`] implementation for [`compio`]. It should not be used /// directly. Instead, use [`CompioConnectionProvider`] which wraps this /// provider and implements [`ConnectionProvider`] for hickory. @@ -378,167 +381,3 @@ impl ConnectionProvider for CompioConnectionProvider { &self.provider } } - -async fn connect_tcp( - server_addr: SocketAddr, - bind_addr: Option, - timeout: Option, -) -> io::Result { - let fut = async move { - if let Some(bind_addr) = bind_addr { - let socket = if bind_addr.is_ipv4() { - TcpSocket::new_v4().await? - } else { - TcpSocket::new_v6().await? - }; - socket.bind(bind_addr).await?; - socket.connect(server_addr).await - } else { - TcpStream::connect(server_addr).await - } - }; - if let Some(timeout) = timeout { - compio::time::timeout(timeout, fut) - .await - .map_err(|_| io::ErrorKind::TimedOut.into()) - .flatten() - } else { - fut.await - } -} - -#[cfg(feature = "__http")] -const MIME_APPLICATION_DNS: &str = "application/dns-message"; - -#[cfg(feature = "__http")] -fn build_request(server_name: &str, path: &str, len: usize) -> Result, NetError> { - use std::str::FromStr; - - use http::uri::{Authority, Parts, PathAndQuery, Scheme, Uri}; - - let mut parts = Parts::default(); - parts.path_and_query = Some( - PathAndQuery::from_str(path).map_err(|e| NetError::from(format!("invalid path: {e:?}")))?, - ); - parts.scheme = Some(Scheme::HTTPS); - parts.authority = Some( - Authority::from_str(server_name) - .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, - ); - - let url = - Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; - - let request = http::Request::builder() - .method("POST") - .uri(url) - .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) - .header(http::header::ACCEPT, MIME_APPLICATION_DNS) - .header(http::header::CONTENT_LENGTH, len) - .body(()) - .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; - - Ok(request) -} - -#[cfg(feature = "__http")] -fn get_content_length(headers: &http::HeaderMap) -> Result, NetError> { - use std::str::FromStr; - - headers - .get(http::header::CONTENT_LENGTH) - .map(|v| v.to_str()) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? - .map(usize::from_str) - .transpose() - .map_err(|e| NetError::from(format!("bad headers received: {e:?}"))) -} - -#[cfg(feature = "__http")] -fn build_response( - response: http::response::Parts, - content_length: Option, - response_bytes: Vec, -) -> Result { - use hickory_net::proto::op::DnsResponse; - - if let Some(content_length) = content_length - && response_bytes.len() != content_length - { - return Err(NetError::from(format!( - "expected byte length: {}, got: {}", - content_length, - response_bytes.len() - ))); - } - - if !response.status.is_success() { - let error_string = String::from_utf8_lossy(response_bytes.as_ref()); - - return Err(NetError::from(format!( - "http unsuccessful code: {}, message: {}", - response.status, error_string - ))); - } - let content_type = response - .headers - .get(http::header::CONTENT_TYPE) - .map(|h| { - h.to_str() - .map_err(|err| NetError::from(format!("ContentType header not a string: {err}"))) - }) - .unwrap_or(Ok(MIME_APPLICATION_DNS))?; - - if content_type != MIME_APPLICATION_DNS { - return Err(NetError::from(format!( - "ContentType unsupported (must be '{}'): '{}'", - MIME_APPLICATION_DNS, content_type - ))); - } - - DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) -} - -#[cfg(feature = "__quic")] -async fn connect_quic( - server_name: std::sync::Arc, - remote_addr: SocketAddr, - bind_addr: Option, - mut config: compio::rustls::ClientConfig, - timeout: Duration, - alpn: &[u8], -) -> Result { - use std::net::{Ipv4Addr, Ipv6Addr}; - - use compio::quic::ClientBuilder; - - if config.alpn_protocols.is_empty() { - config.alpn_protocols = vec![alpn.to_vec()]; - } - let enable_early_data = config.enable_early_data; - - let bind = bind_addr.unwrap_or_else(|| { - if remote_addr.is_ipv4() { - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) - } else { - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) - } - }); - - let endpoint = ClientBuilder::new_with_rustls_client_config(config) - .bind(bind) - .await?; - - let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; - if enable_early_data { - match connecting.into_0rtt() { - Ok(conn) => return Ok(conn), - Err(f) => connecting = f, - } - } - compio::time::timeout(timeout, connecting) - .await - .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? - .map_err(|e| NetError::from(format!("quic connection error: {e}"))) -} diff --git a/cyper-hickory/src/quic.rs b/cyper-hickory/src/quic.rs index 73b4d24..ddbf22f 100644 --- a/cyper-hickory/src/quic.rs +++ b/cyper-hickory/src/quic.rs @@ -6,11 +6,7 @@ use std::{ time::Duration, }; -use compio::{ - bytes::Bytes, - quic::{Connection, ConnectionError, OpenStreamError, ReadError, ReadExactError, WriteError}, - rustls::ClientConfig, -}; +use compio::{bytes::Bytes, quic::Connection, rustls::ClientConfig}; use compio_log::debug; use futures_util::Stream; use hickory_net::{ @@ -24,7 +20,7 @@ use hickory_net::{ }; use send_wrapper::SendWrapper; -use crate::CompioRuntimeProvider; +use crate::{CompioRuntimeProvider, ToNetError}; const DOQ_ALPN: &[u8] = b"doq"; @@ -145,85 +141,3 @@ impl Stream for CompioQuicClientStream { } } } - -trait ToNetError { - type QuinnError; - - fn to_net_error(self) -> Self::QuinnError; -} - -impl ToNetError for OpenStreamError { - type QuinnError = NetError; - - fn to_net_error(self) -> Self::QuinnError { - match self { - Self::ConnectionLost(err) => err.to_net_error().into(), - Self::StreamsExhausted => NetError::from(self.to_string()), - } - } -} - -impl ToNetError for ConnectionError { - type QuinnError = quinn::ConnectionError; - - fn to_net_error(self) -> Self::QuinnError { - use quinn::ConnectionError as QuinnConnectionError; - - match self { - Self::VersionMismatch => QuinnConnectionError::VersionMismatch, - Self::TransportError(err) => QuinnConnectionError::TransportError(err), - Self::ConnectionClosed(err) => QuinnConnectionError::ConnectionClosed(err), - Self::ApplicationClosed(err) => QuinnConnectionError::ApplicationClosed(err), - Self::Reset => QuinnConnectionError::Reset, - Self::TimedOut => QuinnConnectionError::TimedOut, - Self::LocallyClosed => QuinnConnectionError::LocallyClosed, - Self::CidsExhausted => QuinnConnectionError::CidsExhausted, - } - } -} - -impl ToNetError for WriteError { - type QuinnError = NetError; - - fn to_net_error(self) -> Self::QuinnError { - use quinn::WriteError as QuinnWriteError; - - let err = match self { - Self::Stopped(code) => QuinnWriteError::Stopped(code), - Self::ConnectionLost(err) => QuinnWriteError::ConnectionLost(err.to_net_error()), - Self::ClosedStream => QuinnWriteError::ClosedStream, - Self::ZeroRttRejected => QuinnWriteError::ZeroRttRejected, - _ => return NetError::from(self.to_string()), - }; - NetError::QuinnWriteError(err) - } -} - -impl ToNetError for ReadExactError { - type QuinnError = quinn::ReadExactError; - - fn to_net_error(self) -> Self::QuinnError { - use quinn::ReadExactError as QuinnReadExactError; - - match self { - Self::FinishedEarly(len) => QuinnReadExactError::FinishedEarly(len), - Self::ReadError(err) => QuinnReadExactError::ReadError(err.to_net_error()), - } - } -} - -impl ToNetError for ReadError { - type QuinnError = quinn::ReadError; - - fn to_net_error(self) -> Self::QuinnError { - use quinn::ReadError as QuinnReadError; - - match self { - Self::Reset(code) => QuinnReadError::Reset(code), - Self::ConnectionLost(err) => QuinnReadError::ConnectionLost(err.to_net_error()), - Self::ClosedStream => QuinnReadError::ClosedStream, - Self::IllegalOrderedRead => QuinnReadError::IllegalOrderedRead, - Self::ZeroRttRejected => QuinnReadError::ZeroRttRejected, - } - } -} diff --git a/cyper-hickory/src/util/http.rs b/cyper-hickory/src/util/http.rs new file mode 100644 index 0000000..ac4c7b4 --- /dev/null +++ b/cyper-hickory/src/util/http.rs @@ -0,0 +1,89 @@ +use std::str::FromStr; + +use hickory_net::{NetError, proto::op::DnsResponse}; +use http::uri::{Authority, Parts, PathAndQuery, Scheme, Uri}; + +pub const MIME_APPLICATION_DNS: &str = "application/dns-message"; + +pub fn build_request( + server_name: &str, + path: &str, + len: usize, +) -> Result, NetError> { + let mut parts = Parts::default(); + parts.path_and_query = Some( + PathAndQuery::from_str(path).map_err(|e| NetError::from(format!("invalid path: {e:?}")))?, + ); + parts.scheme = Some(Scheme::HTTPS); + parts.authority = Some( + Authority::from_str(server_name) + .map_err(|e| NetError::from(format!("invalid authority: {e:?}")))?, + ); + + let url = + Uri::from_parts(parts).map_err(|e| NetError::from(format!("uri parse error: {e:?}")))?; + + let request = http::Request::builder() + .method("POST") + .uri(url) + .header(http::header::CONTENT_TYPE, MIME_APPLICATION_DNS) + .header(http::header::ACCEPT, MIME_APPLICATION_DNS) + .header(http::header::CONTENT_LENGTH, len) + .body(()) + .map_err(|e| NetError::from(format!("build request error: {e:?}")))?; + + Ok(request) +} + +pub fn get_content_length(headers: &http::HeaderMap) -> Result, NetError> { + headers + .get(http::header::CONTENT_LENGTH) + .map(|v| v.to_str()) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}")))? + .map(usize::from_str) + .transpose() + .map_err(|e| NetError::from(format!("bad headers received: {e:?}"))) +} + +pub fn build_response( + response: http::response::Parts, + content_length: Option, + response_bytes: Vec, +) -> Result { + if let Some(content_length) = content_length + && response_bytes.len() != content_length + { + return Err(NetError::from(format!( + "expected byte length: {}, got: {}", + content_length, + response_bytes.len() + ))); + } + + if !response.status.is_success() { + let error_string = String::from_utf8_lossy(response_bytes.as_ref()); + + return Err(NetError::from(format!( + "http unsuccessful code: {}, message: {}", + response.status, error_string + ))); + } + let content_type = response + .headers + .get(http::header::CONTENT_TYPE) + .map(|h| { + h.to_str() + .map_err(|err| NetError::from(format!("ContentType header not a string: {err}"))) + }) + .unwrap_or(Ok(MIME_APPLICATION_DNS))?; + + if content_type != MIME_APPLICATION_DNS { + return Err(NetError::from(format!( + "ContentType unsupported (must be '{}'): '{}'", + MIME_APPLICATION_DNS, content_type + ))); + } + + DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from) +} diff --git a/cyper-hickory/src/util/mod.rs b/cyper-hickory/src/util/mod.rs new file mode 100644 index 0000000..7fcd8f6 --- /dev/null +++ b/cyper-hickory/src/util/mod.rs @@ -0,0 +1,12 @@ +mod tcp; +pub use tcp::*; + +#[cfg(feature = "__http")] +mod http; +#[cfg(feature = "__http")] +pub use http::*; + +#[cfg(feature = "__quic")] +mod quic; +#[cfg(feature = "__quic")] +pub use quic::*; diff --git a/cyper-hickory/src/util/quic.rs b/cyper-hickory/src/util/quic.rs new file mode 100644 index 0000000..73d265b --- /dev/null +++ b/cyper-hickory/src/util/quic.rs @@ -0,0 +1,135 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + time::Duration, +}; + +use compio::{ + quic::{ + ClientBuilder, Connection, ConnectionError, OpenStreamError, ReadError, ReadExactError, + WriteError, + }, + rustls::ClientConfig, +}; +use hickory_net::NetError; + +pub async fn connect_quic( + server_name: std::sync::Arc, + remote_addr: SocketAddr, + bind_addr: Option, + mut config: ClientConfig, + timeout: Duration, + alpn: &[u8], +) -> Result { + if config.alpn_protocols.is_empty() { + config.alpn_protocols = vec![alpn.to_vec()]; + } + let enable_early_data = config.enable_early_data; + + let bind = bind_addr.unwrap_or_else(|| { + if remote_addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + }); + + let endpoint = ClientBuilder::new_with_rustls_client_config(config) + .bind(bind) + .await?; + + let mut connecting = endpoint.connect(remote_addr, &server_name, None)?; + if enable_early_data { + match connecting.into_0rtt() { + Ok(conn) => return Ok(conn), + Err(f) => connecting = f, + } + } + + let conn = compio::time::timeout(timeout, connecting) + .await + .map_err(|_| std::io::Error::from(std::io::ErrorKind::TimedOut))? + .map_err(ToNetError::to_net_error)?; + Ok(conn) +} + +pub trait ToNetError { + type QuinnError; + + fn to_net_error(self) -> Self::QuinnError; +} + +impl ToNetError for OpenStreamError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + match self { + Self::ConnectionLost(err) => err.to_net_error().into(), + Self::StreamsExhausted => NetError::from(self.to_string()), + } + } +} + +impl ToNetError for ConnectionError { + type QuinnError = quinn::ConnectionError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ConnectionError as QuinnConnectionError; + + match self { + Self::VersionMismatch => QuinnConnectionError::VersionMismatch, + Self::TransportError(err) => QuinnConnectionError::TransportError(err), + Self::ConnectionClosed(err) => QuinnConnectionError::ConnectionClosed(err), + Self::ApplicationClosed(err) => QuinnConnectionError::ApplicationClosed(err), + Self::Reset => QuinnConnectionError::Reset, + Self::TimedOut => QuinnConnectionError::TimedOut, + Self::LocallyClosed => QuinnConnectionError::LocallyClosed, + Self::CidsExhausted => QuinnConnectionError::CidsExhausted, + } + } +} + +impl ToNetError for WriteError { + type QuinnError = NetError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::WriteError as QuinnWriteError; + + let err = match self { + Self::Stopped(code) => QuinnWriteError::Stopped(code), + Self::ConnectionLost(err) => QuinnWriteError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnWriteError::ClosedStream, + Self::ZeroRttRejected => QuinnWriteError::ZeroRttRejected, + _ => return NetError::from(self.to_string()), + }; + NetError::QuinnWriteError(err) + } +} + +impl ToNetError for ReadExactError { + type QuinnError = quinn::ReadExactError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadExactError as QuinnReadExactError; + + match self { + Self::FinishedEarly(len) => QuinnReadExactError::FinishedEarly(len), + Self::ReadError(err) => QuinnReadExactError::ReadError(err.to_net_error()), + } + } +} + +impl ToNetError for ReadError { + type QuinnError = quinn::ReadError; + + fn to_net_error(self) -> Self::QuinnError { + use quinn::ReadError as QuinnReadError; + + match self { + Self::Reset(code) => QuinnReadError::Reset(code), + Self::ConnectionLost(err) => QuinnReadError::ConnectionLost(err.to_net_error()), + Self::ClosedStream => QuinnReadError::ClosedStream, + Self::IllegalOrderedRead => QuinnReadError::IllegalOrderedRead, + Self::ZeroRttRejected => QuinnReadError::ZeroRttRejected, + } + } +} diff --git a/cyper-hickory/src/util/tcp.rs b/cyper-hickory/src/util/tcp.rs new file mode 100644 index 0000000..a9ca00c --- /dev/null +++ b/cyper-hickory/src/util/tcp.rs @@ -0,0 +1,31 @@ +use std::{io, net::SocketAddr, time::Duration}; + +use compio::net::{TcpSocket, TcpStream}; + +pub async fn connect_tcp( + server_addr: SocketAddr, + bind_addr: Option, + timeout: Option, +) -> io::Result { + let fut = async move { + if let Some(bind_addr) = bind_addr { + let socket = if bind_addr.is_ipv4() { + TcpSocket::new_v4().await? + } else { + TcpSocket::new_v6().await? + }; + socket.bind(bind_addr).await?; + socket.connect(server_addr).await + } else { + TcpStream::connect(server_addr).await + } + }; + if let Some(timeout) = timeout { + compio::time::timeout(timeout, fut) + .await + .map_err(|_| io::ErrorKind::TimedOut.into()) + .flatten() + } else { + fut.await + } +}