diff --git a/src/server.rs b/src/server.rs index 929d96eb3..feea30f8a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,6 +6,7 @@ use std::future::Future; use std::net::SocketAddr; #[cfg(feature = "tls")] use std::path::Path; +use std::time::Duration; use futures_util::{future, FutureExt, TryFuture, TryStream, TryStreamExt}; use hyper::server::conn::AddrIncoming; @@ -29,6 +30,7 @@ where Server { pipeline: false, filter, + tcp_keepalive_config: TcpKeepaliveConfig::default(), } } @@ -37,6 +39,22 @@ where pub struct Server { pipeline: bool, filter: F, + tcp_keepalive_config: TcpKeepaliveConfig, +} + +#[derive(Default, Debug, Clone, Copy)] +struct TcpKeepaliveConfig { + time: Option, + interval: Option, + retries: Option, +} + +impl TcpKeepaliveConfig { + fn configure(&self, incoming: &mut AddrIncoming) { + incoming.set_keepalive(self.time); + incoming.set_keepalive_interval(self.interval); + incoming.set_keepalive_retries(self.retries); + } } /// A Warp Server ready to filter requests over TLS. @@ -64,9 +82,10 @@ macro_rules! into_service { } macro_rules! addr_incoming { - ($addr:expr) => {{ + ($this:ident, $addr:expr) => {{ let mut incoming = AddrIncoming::bind($addr)?; incoming.set_nodelay(true); + $this.configure(&mut incoming); let addr = incoming.local_addr(); (addr, incoming) }}; @@ -75,7 +94,8 @@ macro_rules! addr_incoming { macro_rules! bind_inner { ($this:ident, $addr:expr) => {{ let service = into_service!($this.filter); - let (addr, incoming) = addr_incoming!($addr); + let config = &$this.tcp_keepalive_config; + let (addr, incoming) = addr_incoming!(config, $addr); let srv = HyperServer::builder(incoming) .http1_pipeline_flush($this.pipeline) .serve(service); @@ -84,7 +104,8 @@ macro_rules! bind_inner { (tls: $this:ident, $addr:expr) => {{ let service = into_service!($this.server.filter); - let (addr, incoming) = addr_incoming!($addr); + let config = &$this.server.tcp_keepalive_config; + let (addr, incoming) = addr_incoming!(config, $addr); let tls = $this.tls.build()?; let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming)) .http1_pipeline_flush($this.server.pipeline) @@ -400,6 +421,27 @@ where self } + /// Set the duration to remain idle before sending TCP keepalive probes. + /// + /// If `None` is specified, keepalive is disabled. + pub fn set_tcp_keepalive(mut self, time: Option) -> Self { + self.tcp_keepalive_config.time = time; + self + } + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + pub fn set_tcp_keepalive_interval(mut self, interval: Option) -> Self { + self.tcp_keepalive_config.interval = interval; + self + } + + /// Set the number of retransmissions to be carried out before declaring that remote end is not available. + pub fn set_tcp_keepalive_retries(mut self, retries: Option) -> Self { + self.tcp_keepalive_config.retries = retries; + self + } + /// Configure a server to use TLS. /// /// *This function requires the `"tls"` feature.*