Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ repository = "https://github.com/compio-rs/cyper"

[workspace.dependencies]
compio = { version = "0.19.0-rc.1", default-features = false }
compio-log = "0.1.0"

cyper-core = { path = "./cyper-core", default-features = false, version = "0.9.0-rc.2" }
cyper-axum = { path = "./cyper-axum", default-features = false, version = "0.9.0-rc.1" }
cyper-hickory = { path = "./cyper-hickory", default-features = false, version = "0.1.0-rc.1" }
Expand Down
2 changes: 1 addition & 1 deletion cyper-axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repository = { workspace = true }

[dependencies]
compio = { workspace = true, features = ["net", "time"] }
compio-log = "0.1.0"
compio-log = { workspace = true }
cyper-core = { workspace = true }
socket2 = { workspace = true }

Expand Down
28 changes: 25 additions & 3 deletions cyper-hickory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ repository.workspace = true

[dependencies]
compio = { workspace = true, features = ["net", "time", "io-compat"] }
compio-log = { workspace = true }

hickory-net = { workspace = true }
hickory-resolver = { workspace = true }

async-trait = "0.1"
futures-channel = { workspace = true, optional = true }
futures-util = { workspace = true }
send_wrapper = { workspace = true, features = ["futures"] }

Expand All @@ -28,6 +30,7 @@ hyper-util = { workspace = true, optional = true, features = [
"http1",
"http2",
] }
quinn = { version = "0.11", default-features = false }
tower-service = { workspace = true, optional = true }

[dev-dependencies]
Expand All @@ -46,17 +49,36 @@ dnssec = ["hickory-resolver/dnssec-ring", "hickory-server/dnssec-ring"]
__ring = ["compio/ring", "compio/rustls"]
__tls = ["compio/tls", "__ring"]
tls = ["__tls", "hickory-resolver/tls-ring", "hickory-server/tls-ring"]
__http = ["dep:http"]
https = [
"__tls",
"tls",
"__http",
"hickory-resolver/https-ring",
"hickory-server/https-ring",
"dep:cyper-core",
"dep:http",
"dep:http-body-util",
"dep:hyper",
"dep:hyper-util",
"dep:tower-service",
]
all = ["dnssec", "tls", "https"]
__quic = ["__ring", "compio/quic"]
quic = [
"__quic",
"tls",
"compio/bytes",
"hickory-resolver/quic-ring",
"hickory-server/quic-ring",
]
h3 = [
"quic",
"__http",
"compio/h3",
"dep:futures-channel",
"hickory-resolver/h3-ring",
"hickory-server/h3-ring",
]
all = ["dnssec", "tls", "https", "quic", "h3"]

enable_log = ["compio-log/enable_log"]

nightly = ["compio/nightly", "cyper-core?/nightly"]
201 changes: 201 additions & 0 deletions cyper-hickory/src/h3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use std::{
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use compio::{
bytes::{Buf, Bytes},
rustls::ClientConfig,
};
use compio_log::{debug, warn};
use futures_channel::mpsc::Sender;
use futures_util::{FutureExt, SinkExt, Stream};
use hickory_net::{
NetError,
proto::op::{DnsRequest, DnsResponse},
xfer::{DnsExchange, DnsRequestSender, DnsResponseStream},
};
use send_wrapper::SendWrapper;

use crate::CompioRuntimeProvider;

const H3_ALPN: &[u8] = b"h3";

pub async fn connect_h3(
server_name: Arc<str>,
path: Arc<str>,
remote_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
config: ClientConfig,
enable_grease: bool,
timeout: Duration,
) -> Result<DnsExchange<CompioRuntimeProvider>, NetError> {
let conn = crate::connect_quic(
server_name.clone(),
remote_addr,
bind_addr,
config,
timeout,
H3_ALPN,
)
.await?;

let (mut driver, send_request) = compio::quic::h3::client::builder()
.send_grease(enable_grease)
.build(conn)
.await
.map_err(|e| NetError::from(format!("h3 client error: {e}")))?;

let (tx, mut rx) = futures_channel::mpsc::channel::<()>(1);

compio::runtime::spawn(async move {
futures_util::select! {
error = driver.wait_idle().fuse() => {
if !error.is_h3_no_error() {
warn!("h3 connection failed to close: {}", error);
}
}
_ = rx.recv().fuse() => {
debug!("h3 connection is shutting down: {}", remote_addr);
}
}
})
.detach();

let stream = H3RequestSender::new(send_request, server_name, path, tx);
let (exchange, bg) = DnsExchange::from_stream(stream);
compio::runtime::spawn(bg).detach();
Ok(exchange)
}

type SendRequest = compio::quic::h3::client::SendRequest<compio::quic::h3::OpenStreams, Bytes>;

struct H3RequestSender {
send_request: SendWrapper<SendRequest>,
server_name: Arc<str>,
path: Arc<str>,
tx: Sender<()>,
is_shutdown: bool,
}

impl H3RequestSender {
fn new(
send_request: SendRequest,
server_name: Arc<str>,
path: Arc<str>,
tx: Sender<()>,
) -> Self {
Self {
send_request: SendWrapper::new(send_request),
server_name,
path,
tx,
is_shutdown: false,
}
}

async fn inner_send(
mut send_request: SendWrapper<SendRequest>,
server_name: Arc<str>,
path: Arc<str>,
mut request: DnsRequest,
) -> Result<DnsResponse, NetError> {
request.metadata.id = 0;
let bytes = request.to_vec()?;

let request = crate::build_request(&server_name, &path, bytes.len())?;

let mut stream = send_request
.send_request(request)
.await
.map_err(|e| NetError::from(format!("h3 send request error: {e}")))?;

stream
.send_data(Bytes::from(bytes))
.await
.map_err(|e| NetError::from(format!("h3 send data error: {e}")))?;

stream
.finish()
.await
.map_err(|e| NetError::from(format!("h3 finish error: {e}")))?;

let resp = stream
.recv_response()
.await
.map_err(|e| NetError::from(format!("h3 recv response error: {e}")))?;
let (resp, ()) = resp.into_parts();

let content_length = crate::get_content_length(&resp.headers)?;

let mut response_bytes =
Vec::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
while let Some(chunk) = stream
.recv_data()
.await
.map_err(|e| NetError::from(format!("h3 recv data error: {e}")))?
{
response_bytes.extend_from_slice(chunk.chunk());

if let Some(content_length) = content_length
&& response_bytes.len() >= content_length
{
break;
}
}

crate::build_response(resp, content_length, response_bytes)
}
}

impl DnsRequestSender for H3RequestSender {
fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}

Box::pin(SendWrapper::new(Self::inner_send(
self.send_request.clone(),
self.server_name.clone(),
self.path.clone(),
request,
)))
.into()
}

fn shutdown(&mut self) {
self.is_shutdown = true;
compio::runtime::spawn({
let mut tx = self.tx.clone();
async move {
let _ = tx.send(()).await;
}
})
.detach();
}

fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}

impl Stream for H3RequestSender {
type Item = Result<(), NetError>;

fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}

if self.tx.is_closed() {
return Poll::Ready(Some(Err(NetError::from(
"h3 connection is already shutdown",
))));
}

Poll::Ready(Some(Ok(())))
}
}
Loading
Loading