diff --git a/Cargo.lock b/Cargo.lock index 818378ba28..606780fe9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-client-ip" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8ba1af5b620232acf37f2eb6d22151ea465491e0b4c25f552d1990f64ec5a67" +dependencies = [ + "axum", + "client-ip", + "serde", +] + [[package]] name = "axum-core" version = "0.5.6" @@ -636,6 +647,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "client-ip" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39d2056bf065c8b4bce5a8898d40e175211ff4410add2a84d695845d3937c729" +dependencies = [ + "http", +] + [[package]] name = "cmake" version = "0.1.57" @@ -4363,6 +4383,7 @@ dependencies = [ "anyhow", "arc-swap", "axum", + "axum-client-ip", "base64 0.22.1", "bitflags", "bytes", @@ -4535,6 +4556,7 @@ dependencies = [ "libc", "metrics", "metrics-exporter-prometheus", + "parking_lot", "rand 0.9.2", "rayon", "scopeguard", diff --git a/Cargo.toml b/Cargo.toml index bcd1fac120..40ad1339b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ anyhow = "1.0.79" arc-swap = "1.6.0" async-trait = "0.1" axum = "0.8" +axum-client-ip = "1.3.1" backon = "1.5.1" base64 = "0.22.0" bitflags = "2.6" diff --git a/core/src/blockchain_rpc/mod.rs b/core/src/blockchain_rpc/mod.rs index 6531d6640f..1eb10e586f 100644 --- a/core/src/blockchain_rpc/mod.rs +++ b/core/src/blockchain_rpc/mod.rs @@ -9,6 +9,7 @@ pub use self::client::{ #[cfg(feature = "s3")] pub use self::providers::S3RpcDataProvider; pub use self::providers::{IntoRpcDataProvider, StorageRpcDataProvider}; +pub use self::rate_limits::{BlockchainRpcRateLimitsConfig, BlockchainRpcTrafficLimits}; #[cfg(feature = "s3")] pub use self::service::S3ProxyConfig; pub use 
self::service::{ @@ -18,6 +19,7 @@ pub use self::service::{ mod broadcast_listener; mod client; mod providers; +mod rate_limits; mod service; pub const BAD_REQUEST_ERROR_CODE: u32 = 1; diff --git a/core/src/blockchain_rpc/rate_limits.rs b/core/src/blockchain_rpc/rate_limits.rs new file mode 100644 index 0000000000..e522144993 --- /dev/null +++ b/core/src/blockchain_rpc/rate_limits.rs @@ -0,0 +1,129 @@ +use std::net::IpAddr; +use std::num::NonZeroU32; + +use serde::{Deserialize, Serialize}; +use tycho_network::{ + OverlayIngressPolicyDecision, PublicOverlayRateLimitPolicy, PublicOverlayRateLimiter, + ServiceRequest, try_handle_prefix, +}; +use tycho_util::FastHashSet; +use tycho_util::rate_limit::{RateLimitConfig, RateLimitPolicy, TrafficLimit}; + +use crate::proto::blockchain::rpc; +use crate::proto::overlay; + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(default)] +pub struct BlockchainRpcRateLimitsConfig { + pub limiter: RateLimitConfig, + pub whitelist: Vec, + pub traffic: BlockchainRpcTrafficLimits, +} + +impl From for PublicOverlayRateLimiter { + fn from(config: BlockchainRpcRateLimitsConfig) -> Self { + PublicOverlayRateLimiter::new(config.limiter, BlockchainRpcRateLimitPolicy { + traffic: config.traffic, + whitelist: config.whitelist.into_iter().collect(), + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct BlockchainRpcTrafficLimits { + pub light_queries: TrafficLimit, + pub heavy_queries: TrafficLimit, + pub broadcasts: TrafficLimit, +} + +impl Default for BlockchainRpcTrafficLimits { + fn default() -> Self { + Self { + light_queries: TrafficLimit::new( + NonZeroU32::new(20).unwrap(), + NonZeroU32::new(20).unwrap(), + ), + heavy_queries: TrafficLimit::new( + NonZeroU32::new(10).unwrap(), + NonZeroU32::new(10).unwrap(), + ), + broadcasts: TrafficLimit::new( + NonZeroU32::new(20).unwrap(), + NonZeroU32::new(20).unwrap(), + ), + } + } +} + +impl BlockchainRpcTrafficLimits { + fn policy( + &self, 
+ class: BlockchainRpcTrafficClass, + ) -> RateLimitPolicy { + let limit = match class { + BlockchainRpcTrafficClass::LightQuery => self.light_queries, + BlockchainRpcTrafficClass::HeavyQuery => self.heavy_queries, + BlockchainRpcTrafficClass::Broadcast => self.broadcasts, + }; + + RateLimitPolicy { class, limit } + } +} + +struct BlockchainRpcRateLimitPolicy { + traffic: BlockchainRpcTrafficLimits, + whitelist: FastHashSet, +} + +impl BlockchainRpcRateLimitPolicy { + fn classify(constructor: u32) -> BlockchainRpcTrafficClass { + match constructor { + overlay::Ping::TL_ID + | rpc::GetArchiveInfo::TL_ID + | rpc::GetPersistentShardStateInfo::TL_ID + | rpc::GetPersistentQueueStateInfo::TL_ID + | rpc::GetArchiveChunk::TL_ID + | rpc::GetBlockDataChunk::TL_ID => BlockchainRpcTrafficClass::LightQuery, + _ => BlockchainRpcTrafficClass::HeavyQuery, + } + } +} + +impl PublicOverlayRateLimitPolicy for BlockchainRpcRateLimitPolicy { + type Class = BlockchainRpcTrafficClass; + + fn classify_query(&self, req: &ServiceRequest) -> OverlayIngressPolicyDecision { + if self.whitelist.contains(&req.metadata.remote_address.ip()) { + return OverlayIngressPolicyDecision::Bypass; + } + + let constructor = match try_handle_prefix(req) { + Ok((constructor, _)) => constructor, + Err(e) => { + tracing::debug!("failed to deserialize query: {e}"); + return OverlayIngressPolicyDecision::Drop; + } + }; + + let class = BlockchainRpcRateLimitPolicy::classify(constructor); + OverlayIngressPolicyDecision::Allow(self.traffic.policy(class)) + } + + fn classify_message(&self, req: &ServiceRequest) -> OverlayIngressPolicyDecision { + if self.whitelist.contains(&req.metadata.remote_address.ip()) { + OverlayIngressPolicyDecision::Bypass + } else { + OverlayIngressPolicyDecision::Allow( + self.traffic.policy(BlockchainRpcTrafficClass::Broadcast), + ) + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum BlockchainRpcTrafficClass { + LightQuery, + HeavyQuery, + Broadcast, +} diff --git 
a/core/src/blockchain_rpc/service.rs b/core/src/blockchain_rpc/service.rs index 6d39a4a41d..3dfd23628c 100644 --- a/core/src/blockchain_rpc/service.rs +++ b/core/src/blockchain_rpc/service.rs @@ -14,6 +14,7 @@ use tycho_util::metrics::HistogramGuard; use crate::blockchain_rpc::broadcast_listener::{BroadcastListener, NoopBroadcastListener}; use crate::blockchain_rpc::providers::{IntoRpcDataProvider, RpcDataProvider}; +use crate::blockchain_rpc::rate_limits::BlockchainRpcRateLimitsConfig; use crate::blockchain_rpc::{BAD_REQUEST_ERROR_CODE, INTERNAL_ERROR_CODE, NOT_FOUND_ERROR_CODE}; use crate::proto::blockchain::*; use crate::proto::overlay; @@ -64,6 +65,11 @@ pub struct BlockchainRpcServiceConfig { /// Default: yes. pub serve_persistent_states: bool, + /// Rate limits for inbound blockchain-rpc traffic. + /// + /// Default: disabled. + pub rate_limits: Option, + /// S3 proxy configuration. /// /// Default: enabled. @@ -76,6 +82,7 @@ impl Default for BlockchainRpcServiceConfig { Self { max_key_blocks_list_len: 8, serve_persistent_states: true, + rate_limits: None, #[cfg(feature = "s3")] s3_proxy: Some(S3ProxyConfig::default()), } diff --git a/core/src/node/mod.rs b/core/src/node/mod.rs index 1acfbf7d02..7dcd89746c 100644 --- a/core/src/node/mod.rs +++ b/core/src/node/mod.rs @@ -613,7 +613,15 @@ impl ConfiguredNetwork { let public_overlay = PublicOverlay::builder(zerostate.compute_public_overlay_id()) .named("blockchain_rpc") .with_peer_resolver(self.peer_resolver.clone()) + .with_rate_limiter( + base_config + .blockchain_rpc_service + .rate_limits + .clone() + .map(|rate_limit| rate_limit.into()), + ) .build(blockchain_rpc_service.clone()); + self.overlay_service.add_public_overlay(&public_overlay); let blockchain_rpc_client = BlockchainRpcClient::builder() diff --git a/network/src/lib.rs b/network/src/lib.rs index 8f9de8747c..01a563ab20 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -20,11 +20,12 @@ pub use types::{ pub use self::overlay::{ 
ChooseMultiplePrivateOverlayEntries, ChooseMultiplePublicOverlayEntries, OverlayConfig, - OverlayId, OverlayService, OverlayServiceBackgroundTasks, OverlayServiceBuilder, - PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, PrivateOverlayEntriesEvent, - PrivateOverlayEntriesReadGuard, PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData, - PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, PublicOverlayEntriesReadGuard, - PublicOverlayEntryData, UnknownPeersQueue, + OverlayId, OverlayIngressPolicyDecision, OverlayService, OverlayServiceBackgroundTasks, + OverlayServiceBuilder, PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, + PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard, PrivateOverlayEntriesWriteGuard, + PrivateOverlayEntryData, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, + PublicOverlayEntriesReadGuard, PublicOverlayEntryData, PublicOverlayRateLimitPolicy, + PublicOverlayRateLimiter, UnknownPeersQueue, }; pub use self::util::{ NetworkExt, Routable, Router, RouterBuilder, UnknownPeerError, check_peer_signature, diff --git a/network/src/overlay/mod.rs b/network/src/overlay/mod.rs index e8da8d60b4..57abf7f471 100644 --- a/network/src/overlay/mod.rs +++ b/network/src/overlay/mod.rs @@ -19,6 +19,9 @@ pub use self::public_overlay::{ ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, PublicOverlayEntriesReadGuard, PublicOverlayEntryData, UnknownPeersQueue, }; +pub use self::rate_limits::{ + OverlayIngressPolicyDecision, PublicOverlayRateLimitPolicy, PublicOverlayRateLimiter, +}; use crate::dht::DhtService; use crate::network::Network; use crate::proto::overlay::{PublicEntriesResponse, PublicEntry, PublicEntryResponse, rpc}; @@ -32,6 +35,7 @@ mod metrics; mod overlay_id; mod private_overlay; mod public_overlay; +mod rate_limits; mod tasks_stream; pub struct OverlayServiceBackgroundTasks { diff --git a/network/src/overlay/public_overlay.rs 
b/network/src/overlay/public_overlay.rs index d87ae5385b..be17963104 100644 --- a/network/src/overlay/public_overlay.rs +++ b/network/src/overlay/public_overlay.rs @@ -17,6 +17,7 @@ use crate::dht::{PeerResolver, PeerResolverHandle}; use crate::network::Network; use crate::overlay::OverlayId; use crate::overlay::metrics::Metrics; +use crate::overlay::rate_limits::PublicOverlayRateLimiter; use crate::proto::overlay::{PublicEntry, PublicEntryToSign, rpc}; use crate::types::{BoxService, PeerId, Request, Response, Service, ServiceExt, ServiceRequest}; use crate::util::NetworkExt; @@ -27,6 +28,7 @@ pub struct PublicOverlayBuilder { entry_ttl: Duration, banned_peer_ids: FastDashSet, peer_resolver: Option, + rate_limiter: Option, name: Option<&'static str>, } @@ -68,6 +70,11 @@ impl PublicOverlayBuilder { self } + pub fn with_rate_limiter(mut self, rate_limiter: Option) -> Self { + self.rate_limiter = rate_limiter; + self + } + /// Name of the overlay used in metrics. pub fn named(mut self, name: &'static str) -> Self { self.name = Some(name); @@ -97,6 +104,7 @@ impl PublicOverlayBuilder { min_capacity: self.min_capacity, entry_ttl_sec, peer_resolver: self.peer_resolver, + rate_limiter: self.rate_limiter, entries: RwLock::new(entries), entries_added: Notify::new(), entries_changed: Notify::new(), @@ -130,6 +138,7 @@ impl PublicOverlay { entry_ttl: Duration::from_secs(3600), banned_peer_ids: Default::default(), peer_resolver: None, + rate_limiter: None, name: None, } } @@ -218,7 +227,7 @@ impl PublicOverlay { pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop> { self.inner.metrics.record_rx(req.body.len()); - if self.check_peer_id(&req.metadata.peer_id) { + if self.check_peer_id(&req.metadata.peer_id) && self.allow_query(&req) { BoxFutureOrNoop::future(self.inner.service.on_query(req)) } else { BoxFutureOrNoop::Noop @@ -227,13 +236,27 @@ impl PublicOverlay { pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> { 
self.inner.metrics.record_rx(req.body.len()); - if self.check_peer_id(&req.metadata.peer_id) { + if self.check_peer_id(&req.metadata.peer_id) && self.allow_message(&req) { BoxFutureOrNoop::future(self.inner.service.on_message(req)) } else { BoxFutureOrNoop::Noop } } + fn allow_query(&self, req: &ServiceRequest) -> bool { + self.inner + .rate_limiter + .as_ref() + .is_none_or(|rate_limiter| rate_limiter.allow_query(req)) + } + + fn allow_message(&self, req: &ServiceRequest) -> bool { + self.inner + .rate_limiter + .as_ref() + .is_none_or(|rate_limiter| rate_limiter.allow_message(req)) + } + fn check_peer_id(&self, peer_id: &PeerId) -> bool { // TODO: Merge `banned_peer_ids` with `entires`? if self.inner.banned_peer_ids.contains(peer_id) { @@ -416,6 +439,7 @@ struct Inner { min_capacity: usize, entry_ttl_sec: u32, peer_resolver: Option, + rate_limiter: Option, entries: RwLock, entry_count: AtomicUsize, entries_added: Notify, diff --git a/network/src/overlay/rate_limits.rs b/network/src/overlay/rate_limits.rs new file mode 100644 index 0000000000..9e1780fa68 --- /dev/null +++ b/network/src/overlay/rate_limits.rs @@ -0,0 +1,96 @@ +use std::hash::Hash; +use std::net::IpAddr; +use std::sync::Arc; + +use tycho_util::rate_limit::{RateLimitConfig, RateLimitPolicy, RateLimitVerdict, RateLimiter}; + +use crate::types::ServiceRequest; + +pub enum OverlayIngressPolicyDecision { + Allow(RateLimitPolicy), + Bypass, + Drop, +} + +pub trait PublicOverlayRateLimitPolicy: Send + Sync + 'static { + type Class: Copy + Eq + Hash + Send + Sync + 'static; + + fn classify_query(&self, req: &ServiceRequest) -> OverlayIngressPolicyDecision; + + fn classify_message(&self, req: &ServiceRequest) -> OverlayIngressPolicyDecision; +} + +trait PublicOverlayRateLimitHandler: Send + Sync + 'static { + fn allow_query(&self, req: &ServiceRequest) -> bool; + + fn allow_message(&self, req: &ServiceRequest) -> bool; +} + +#[derive(Clone)] +pub struct PublicOverlayRateLimiter { + inner: Arc, +} + +struct 
PolicyRateLimiter

+where + P: PublicOverlayRateLimitPolicy, +{ + limiter: RateLimiter, + policy: P, +} + +impl PublicOverlayRateLimiter { + pub fn new

(config: RateLimitConfig, policy: P) -> Self + where + P: PublicOverlayRateLimitPolicy, + { + Self { + inner: Arc::new(PolicyRateLimiter { + limiter: RateLimiter::new(config), + policy, + }), + } + } + + pub(crate) fn allow_query(&self, req: &ServiceRequest) -> bool { + self.inner.allow_query(req) + } + + pub(crate) fn allow_message(&self, req: &ServiceRequest) -> bool { + self.inner.allow_message(req) + } +} + +impl

PublicOverlayRateLimitHandler for PolicyRateLimiter

+where + P: PublicOverlayRateLimitPolicy, +{ + fn allow_query(&self, req: &ServiceRequest) -> bool { + self.check( + req.metadata.remote_address.ip(), + self.policy.classify_query(req), + ) + } + + fn allow_message(&self, req: &ServiceRequest) -> bool { + self.check( + req.metadata.remote_address.ip(), + self.policy.classify_message(req), + ) + } +} + +impl

PolicyRateLimiter

+where + P: PublicOverlayRateLimitPolicy, +{ + fn check(&self, ip: IpAddr, decision: OverlayIngressPolicyDecision) -> bool { + match decision { + OverlayIngressPolicyDecision::Allow(policy) => { + matches!(self.limiter.check(&ip, policy), RateLimitVerdict::Allow) + } + OverlayIngressPolicyDecision::Bypass => true, + OverlayIngressPolicyDecision::Drop => false, + } + } +} diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index 2642c177b1..73c36171c3 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -14,6 +14,7 @@ ahash = { workspace = true } anyhow = { workspace = true } arc-swap = { workspace = true } axum = { workspace = true } +axum-client-ip = { workspace = true } base64 = { workspace = true } bitflags = { workspace = true } bytes = { workspace = true } diff --git a/rpc/src/config.rs b/rpc/src/config.rs index d9bda8678e..4f694a0620 100644 --- a/rpc/src/config.rs +++ b/rpc/src/config.rs @@ -1,12 +1,16 @@ use std::net::{Ipv4Addr, SocketAddr}; +use std::num::NonZeroU32; use std::path::{Path, PathBuf}; use std::time::Duration; +use axum_client_ip::ClientIpSource; use serde::{Deserialize, Serialize}; use tycho_types::models::StdAddr; use tycho_util::config::PartialConfig; use tycho_util::serde_helpers; +use crate::endpoint::RpcRateLimitsConfig; + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, PartialConfig)] #[serde(default)] pub struct RpcConfig { @@ -43,6 +47,14 @@ pub struct RpcConfig { /// Subscriptions limits and buffering. pub subscriptions: SubscriptionsConfig, + /// Source for resolving the real client IP + pub real_ip_source: ClientIpSource, + + /// Rate limits for inbound RPC requests. + /// + /// Default: disabled. 
+ pub rate_limits: Option, + #[important] pub storage: RpcStorageConfig, } @@ -146,6 +158,8 @@ impl Default for RpcConfig { max_parallel_block_downloads: 10, run_get_method: RunGetMethodConfig::default(), subscriptions: SubscriptionsConfig::default(), + real_ip_source: ClientIpSource::ConnectInfo, + rate_limits: None, storage: RpcStorageConfig::Full { gc: Some(Default::default()), force_reindex: false, @@ -160,6 +174,7 @@ impl Default for RpcConfig { pub struct SubscriptionsConfig { pub max_clients: u32, pub max_addrs: u32, + pub max_streams_per_addr: NonZeroU32, /// Pending updates buffered per client; clamped to at least 1. pub queue_depth: usize, } @@ -169,6 +184,7 @@ impl Default for SubscriptionsConfig { Self { max_clients: 1_000_000, max_addrs: 1_000_000, + max_streams_per_addr: NonZeroU32::new(5).unwrap(), queue_depth: 5, } } diff --git a/rpc/src/endpoint/jrpc/stream.rs b/rpc/src/endpoint/jrpc/stream.rs index d96b7d3641..304fdb20e1 100644 --- a/rpc/src/endpoint/jrpc/stream.rs +++ b/rpc/src/endpoint/jrpc/stream.rs @@ -2,11 +2,13 @@ use std::convert::Infallible; use std::sync::Arc; use anyhow::anyhow; +use axum::Extension; use axum::extract::{FromRef, Query, State}; use axum::http::StatusCode; use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::get; +use axum_client_ip::ClientIp; use futures_util::{StreamExt, stream}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -16,6 +18,7 @@ use tycho_util::serde_helpers; use uuid::Uuid; use super::RpcStateError; +use crate::endpoint::rate_limits::{ActiveStreamGuard, ActiveStreamLimiter}; use crate::state::{AccountUpdate, McTick, RegisterError, RpcState, RpcSubscriptions}; #[derive(Debug, Deserialize)] @@ -136,6 +139,8 @@ struct StreamState { subs: Arc, last_dropped: u64, mc_rx: watch::Receiver, + // Keeps stream slot occupied until the SSE connection is dropped. 
+ _stream_guard: ActiveStreamGuard, } impl Drop for StreamState { @@ -150,6 +155,7 @@ impl StreamState { rx: mpsc::Receiver, subs: Arc, mc_rx: watch::Receiver, + stream_guard: ActiveStreamGuard, ) -> Self { let last_dropped = subs.dropped(uuid).unwrap_or(0); @@ -159,6 +165,7 @@ impl StreamState { subs, last_dropped, mc_rx, + _stream_guard: stream_guard, } } @@ -191,6 +198,8 @@ where pub async fn stream_route( State(state): State, Query(params): Query, + ClientIp(ip): ClientIp, + Extension(active_streams): Extension, ) -> Response where RpcState: FromRef, @@ -199,13 +208,18 @@ where let subs = Arc::::from_ref(&state); let mc_rx = state.subscribe_mc_tick(); - stream_route_inner(subs, mc_rx, params).await + let Some(stream_guard) = active_streams.try_acquire(ip) else { + return StatusCode::TOO_MANY_REQUESTS.into_response(); + }; + + stream_route_inner(subs, mc_rx, params, stream_guard).await } async fn stream_route_inner( subs: Arc, mc_rx: watch::Receiver, params: StreamParams, + stream_guard: ActiveStreamGuard, ) -> Response { use axum::Json; @@ -259,7 +273,7 @@ async fn stream_route_inner( }); let main_stream = stream::unfold( - StreamState::new(uuid, rx, subs, mc_rx), + StreamState::new(uuid, rx, subs, mc_rx, stream_guard), |mut st| async move { tokio::select! 
{ maybe_update = st.rx.recv() => { @@ -307,6 +321,8 @@ async fn stream_route_inner( #[cfg(test)] mod tests { use std::fmt::Debug; + use std::net::{IpAddr, Ipv4Addr}; + use std::num::NonZeroU32; use std::str::FromStr; use std::time::Duration; @@ -397,8 +413,18 @@ mod tests { utime: 0, }); - let response = - stream_route_inner(subs.clone(), mc_rx, StreamParams { binary: false }).await; + let active_streams = ActiveStreamLimiter::new(NonZeroU32::new(1).unwrap(), vec![]); + let stream_guard = active_streams + .try_acquire(IpAddr::V4(Ipv4Addr::LOCALHOST)) + .expect("stream slot"); + + let response = stream_route_inner( + subs.clone(), + mc_rx, + StreamParams { binary: false }, + stream_guard, + ) + .await; let mut body = response.into_body(); let mut buf = Vec::new(); diff --git a/rpc/src/endpoint/mod.rs b/rpc/src/endpoint/mod.rs index 7bd95b83d3..e2c074fc08 100644 --- a/rpc/src/endpoint/mod.rs +++ b/rpc/src/endpoint/mod.rs @@ -1,20 +1,23 @@ use std::time::Duration; use anyhow::Result; -use axum::RequestExt; use axum::extract::{DefaultBodyLimit, FromRef, Request, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; +use axum::{Extension, RequestExt, middleware}; +use axum_client_ip::ClientIpSource; use tokio::net::TcpListener; pub use self::jrpc::JrpcEndpointCache; pub use self::proto::ProtoEndpointCache; +pub use self::rate_limits::RpcRateLimitsConfig; use crate::state::RpcState; use crate::util::mime::{APPLICATION_JSON, APPLICATION_PROTOBUF, get_mime_type}; pub mod jrpc; pub mod proto; +pub mod rate_limits; pub struct RpcEndpointBuilder { common: RpcEndpointBuilderCommon, @@ -55,9 +58,21 @@ impl RpcEndpointBuilder<()> { pub async fn bind(self, state: RpcState) -> Result { let listener = state.bind_socket().await?; + + let rl_config = &state.config().rate_limits; + let ip_source = state.config().real_ip_source.clone(); + let rate_limiter = rl_config.clone().map(Into::into); + let active_streams = 
rate_limits::ActiveStreamLimiter::new( + state.config().subscriptions.max_streams_per_addr, + rl_config.clone().map(|c| c.whitelist).unwrap_or_default(), + ); + Ok(RpcEndpoint::from_parts( listener, self.common.build(), + ip_source, + rate_limiter, + active_streams, state, )) } @@ -85,10 +100,24 @@ where S: Send + Sync + Clone + 'static, { pub async fn bind(self, state: S) -> Result { - let listener = RpcState::from_ref(&state).bind_socket().await?; + let rpc_state = RpcState::from_ref(&state); + + let listener = rpc_state.bind_socket().await?; + + let rl_config = &rpc_state.config().rate_limits; + let ip_source = rpc_state.config().real_ip_source.clone(); + let rate_limiter = rl_config.clone().map(Into::into); + let active_streams = rate_limits::ActiveStreamLimiter::new( + rpc_state.config().subscriptions.max_streams_per_addr, + rl_config.clone().map(|c| c.whitelist).unwrap_or_default(), + ); + Ok(RpcEndpoint::from_parts( listener, self.common.build::().merge(self.custom_routes), + ip_source, + rate_limiter, + active_streams, state, )) } @@ -149,7 +178,14 @@ impl RpcEndpoint { RpcEndpointBuilder::empty() } - pub fn from_parts(listener: TcpListener, router: axum::Router, state: S) -> Self + pub fn from_parts( + listener: TcpListener, + router: axum::Router, + ip_source: ClientIpSource, + rate_limiter: Option, + active_streams: rate_limits::ActiveStreamLimiter, + state: S, + ) -> Self where S: Clone + Send + Sync + 'static, { @@ -170,14 +206,31 @@ impl RpcEndpoint { let service = service.layer(tower_http::compression::CompressionLayer::new().gzip(true)); // Prepare routes - let router = router.layer(service).with_state(state); + let router = match rate_limiter { + Some(rate_limiter) => router.layer(middleware::from_fn_with_state( + rate_limiter, + rate_limits::rate_limit, + )), + None => router, + }; + + let router = router + .layer(Extension(ip_source)) + .layer(Extension(active_streams)) + .layer(service) + .with_state(state); // Done Self { listener, router } } 
pub async fn serve(self) -> std::io::Result<()> { - axum::serve(self.listener, self.router).await + axum::serve( + self.listener, + self.router + .into_make_service_with_connect_info::(), + ) + .await } } diff --git a/rpc/src/endpoint/rate_limits.rs b/rpc/src/endpoint/rate_limits.rs new file mode 100644 index 0000000000..164b280bb6 --- /dev/null +++ b/rpc/src/endpoint/rate_limits.rs @@ -0,0 +1,185 @@ +use std::net::IpAddr; +use std::num::NonZeroU32; +use std::sync::Arc; + +use axum::extract::Request; +use axum::http::{Method, StatusCode, header}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use axum_client_ip::ClientIp; +use serde::{Deserialize, Serialize}; +use tycho_util::rate_limit::{ + RateLimitConfig, RateLimitPolicy, RateLimitVerdict, RateLimiter, TrafficLimit, +}; +use tycho_util::{FastDashMap, FastHashSet}; + +use crate::util::ip::normalize_ip; + +#[derive(Debug, Default, Eq, PartialEq, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct RpcRateLimitsConfig { + pub limiter: RateLimitConfig, + pub traffic: RpcTrafficLimits, + pub whitelist: Vec, +} + +impl From for RpcRateLimiter { + fn from(config: RpcRateLimitsConfig) -> Self { + RpcRateLimiter { + limiter: RateLimiter::new(config.limiter), + traffic: config.traffic, + whitelist: config.whitelist.into_iter().map(normalize_ip).collect(), + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] +#[serde(default)] +pub struct RpcTrafficLimits { + pub requests: TrafficLimit, + pub streams: TrafficLimit, +} + +impl Default for RpcTrafficLimits { + fn default() -> Self { + Self { + requests: TrafficLimit::new(NonZeroU32::new(10).unwrap(), NonZeroU32::new(10).unwrap()), + streams: TrafficLimit::new(NonZeroU32::new(5).unwrap(), NonZeroU32::new(5).unwrap()), + } + } +} + +impl RpcTrafficLimits { + fn policy(&self, class: RpcTrafficClass) -> RateLimitPolicy { + let limit = match class { + RpcTrafficClass::Request => self.requests, + 
RpcTrafficClass::Stream => self.streams, + }; + + RateLimitPolicy { class, limit } + } +} + +#[derive(Clone)] +pub struct RpcRateLimiter { + limiter: RateLimiter, + traffic: RpcTrafficLimits, + whitelist: FastHashSet, +} + +impl RpcRateLimiter { + fn check(&self, ip: IpAddr, class: RpcTrafficClass) -> RateLimitVerdict { + let ip = normalize_ip(ip); + if self.whitelist.contains(&ip) { + return RateLimitVerdict::Allow; + } + + self.limiter.check(&ip, self.traffic.policy(class)) + } + + fn classify_request(req: &Request) -> Option { + match (req.method(), req.uri().path()) { + (&Method::GET, "/stream") => Some(RpcTrafficClass::Stream), + (&Method::POST, _) => Some(RpcTrafficClass::Request), + _ => None, + } + } +} + +#[derive(Clone)] +pub struct ActiveStreamLimiter { + active: Arc>, + max_streams_per_addr: NonZeroU32, + whitelist: Arc>, +} + +impl ActiveStreamLimiter { + pub fn new(max_streams_per_addr: NonZeroU32, whitelist: Vec) -> Self { + Self { + active: Arc::new(FastDashMap::default()), + max_streams_per_addr, + whitelist: Arc::new(whitelist.into_iter().map(normalize_ip).collect()), + } + } + + pub fn try_acquire(&self, ip: IpAddr) -> Option { + let ip = normalize_ip(ip); + if self.whitelist.contains(&ip) { + return Some(ActiveStreamGuard::whitelisted()); + } + + let mut entry = self.active.entry(ip).or_default(); + + if *entry >= self.max_streams_per_addr.get() { + return None; + } + + *entry += 1; + + Some(ActiveStreamGuard::new(ip, self.clone())) + } + + fn release(&self, ip: IpAddr) { + self.active.remove_if_mut(&ip, |_, count| { + *count = count.saturating_sub(1); + *count == 0 + }); + } +} + +pub enum ActiveStreamGuard { + Counted { + ip: IpAddr, + active: ActiveStreamLimiter, + }, + Whitelisted, +} + +impl ActiveStreamGuard { + fn new(ip: IpAddr, active: ActiveStreamLimiter) -> Self { + Self::Counted { ip, active } + } + + fn whitelisted() -> Self { + Self::Whitelisted + } +} + +impl Drop for ActiveStreamGuard { + fn drop(&mut self) { + if let 
Self::Counted { ip, active } = self { + active.release(*ip); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RpcTrafficClass { + Request, + Stream, +} + +pub async fn rate_limit( + axum::extract::State(limiter): axum::extract::State, + ClientIp(ip): ClientIp, + req: Request, + next: Next, +) -> Response { + let Some(class) = RpcRateLimiter::classify_request(&req) else { + return next.run(req).await; + }; + + match limiter.check(ip, class) { + RateLimitVerdict::Allow => next.run(req).await, + RateLimitVerdict::Reject { retry_after } => { + // Round up to whole seconds. + let retry_after = retry_after.as_millis().div_ceil(1_000).max(1); + + (StatusCode::TOO_MANY_REQUESTS, [( + header::RETRY_AFTER, + retry_after.to_string(), + )]) + .into_response() + } + } +} diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index cfc4ef0f1e..954066b922 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -24,6 +24,7 @@ mod state; pub mod util { pub mod error_codes; + pub mod ip; pub mod jrpc_extractor; pub mod mime; pub mod serde_helpers; diff --git a/rpc/src/util/ip.rs b/rpc/src/util/ip.rs new file mode 100644 index 0000000000..512826ce19 --- /dev/null +++ b/rpc/src/util/ip.rs @@ -0,0 +1,33 @@ +use std::net::IpAddr; + +pub fn normalize_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V4(_) => ip, + IpAddr::V6(ip) => { + const IPV6_PREFIX_MASK: u128 = u128::MAX << 64; + IpAddr::V6((u128::from(ip) & IPV6_PREFIX_MASK).into()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalize_ipv4() { + let ip = "192.0.2.1".parse().unwrap(); + + assert_eq!(normalize_ip(ip), ip); + } + + #[test] + fn normalize_ipv6() { + let prefix: IpAddr = "2001:db8:abcd:1234::".parse().unwrap(); + + assert_eq!( + normalize_ip("2001:db8:abcd:1234:ffff:ffff:ffff:ffff".parse().unwrap()), + prefix + ); + } +} diff --git a/util/Cargo.toml b/util/Cargo.toml index ad1a739aa6..9a4ec7f767 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -26,6 +26,7 @@ humantime = { 
workspace = true } libc = { workspace = true, optional = true } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true, optional = true } +parking_lot = { workspace = true } rand = { workspace = true } rayon = { workspace = true } scopeguard = { workspace = true } diff --git a/util/src/lib.rs b/util/src/lib.rs index ec2c6aca40..9da2bc1327 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -10,6 +10,7 @@ pub mod compression; pub mod config; pub mod io; pub mod progress_bar; +pub mod rate_limit; pub mod serde_helpers; pub mod time; pub mod transactional; diff --git a/util/src/rate_limit.rs b/util/src/rate_limit.rs new file mode 100644 index 0000000000..e7a569dd4a --- /dev/null +++ b/util/src/rate_limit.rs @@ -0,0 +1,450 @@ +use std::hash::Hash; +use std::num::NonZeroU32; +use std::sync::Arc; +use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::time::Duration; + +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; + +use crate::{FastDashMap, FastHashMap, serde_helpers, time}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct TrafficLimit { + pub rate_per_sec: NonZeroU32, + pub burst: NonZeroU32, +} + +impl TrafficLimit { + // Millisecond GCRA cannot represent intervals smaller than 1ms. 
+ pub const MAX_RATE_PER_SEC: u32 = 1_000; + + pub const fn new(rate_per_sec: NonZeroU32, burst: NonZeroU32) -> Self { + Self { + rate_per_sec, + burst, + } + } + + fn normalize(&mut self) { + if self.rate_per_sec.get() > Self::MAX_RATE_PER_SEC { + self.rate_per_sec = NonZeroU32::new(Self::MAX_RATE_PER_SEC).unwrap(); + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(default)] +pub struct RateLimitConfig { + pub rejects_before_cooldown: u8, + #[serde(with = "serde_helpers::humantime")] + pub cooldown: Duration, + #[serde(with = "serde_helpers::humantime")] + pub prune_interval: Duration, + #[serde(with = "serde_helpers::humantime")] + pub state_ttl: Duration, +} + +impl RateLimitConfig { + pub const MIN_REJECTS_BEFORE_COOLDOWN: u8 = 1; + pub const MAX_REJECTS_BEFORE_COOLDOWN: u8 = u8::MAX - 1; + pub const MIN_STATE_TTL: Duration = Duration::from_secs(1); + pub const MIN_PRUNE_INTERVAL: Duration = Duration::from_secs(1); + + fn normalize(&mut self) { + self.rejects_before_cooldown = self.rejects_before_cooldown.clamp( + Self::MIN_REJECTS_BEFORE_COOLDOWN, + Self::MAX_REJECTS_BEFORE_COOLDOWN, + ); + + if self.prune_interval.is_zero() { + self.prune_interval = Self::MIN_PRUNE_INTERVAL; + } + + if self.state_ttl.is_zero() { + self.state_ttl = Self::MIN_STATE_TTL; + } + } +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + rejects_before_cooldown: 5, + cooldown: Duration::from_secs(30), + prune_interval: Duration::from_secs(30), + state_ttl: Duration::from_secs(300), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum RateLimitVerdict { + Allow, + Reject { retry_after: Duration }, +} + +#[derive(Debug, Clone, Copy)] +pub struct RateLimitPolicy { + pub class: C, + pub limit: TrafficLimit, +} + +#[derive(Clone)] +pub struct RateLimiter { + inner: Arc>, +} + +struct RateLimiterInner { + config: RateLimitConfig, + states: FastDashMap>>, + last_prune_ms: AtomicU64, +} + +impl RateLimiter +where + K: 
Clone + Eq + Hash + Send + Sync + 'static,
+    C: Copy + Eq + Hash + Send + Sync + 'static,
+{
+    /// Creates a limiter with no known peers.
+    ///
+    /// `config` is normalized first, and the prune timer starts at the
+    /// current time.
+    pub fn new(mut config: RateLimitConfig) -> Self {
+        config.normalize();
+
+        Self {
+            inner: Arc::new(RateLimiterInner {
+                config,
+                states: FastDashMap::default(),
+                last_prune_ms: AtomicU64::new(time::now_millis()),
+            }),
+        }
+    }
+
+    /// Accounts one request of peer `key` against the bucket selected by
+    /// `policy.class` and tells whether it may proceed.
+    pub fn check(&self, key: &K, policy: RateLimitPolicy<C>) -> RateLimitVerdict {
+        self.inner.check(key, policy)
+    }
+}
+
+impl<K, C> RateLimiterInner<K, C>
+where
+    K: Clone + Eq + Hash + Send + Sync + 'static,
+    C: Copy + Eq + Hash + Send + Sync + 'static,
+{
+    /// Samples the clock once, runs the periodic sweep if it is due, and
+    /// delegates the decision to the peer's own state.
+    fn check(&self, key: &K, policy: RateLimitPolicy<C>) -> RateLimitVerdict {
+        let now = time::now_millis();
+
+        self.maybe_prune(now);
+
+        let peer = self.peer(key, now);
+        peer.check(&self.config, policy, now)
+    }
+
+    /// Returns the state of peer `key`, creating or recycling it as needed.
+    ///
+    /// A state is recycled only when it is no longer live (see
+    /// `PeerLimiter::is_live`). In particular, a state that is still in
+    /// cooldown is never replaced; otherwise a penalised peer could shed its
+    /// cooldown by staying silent for `state_ttl` whenever the cooldown is
+    /// configured to be longer than the TTL.
+    fn peer(&self, key: &K, now: u64) -> Arc<PeerLimiter<C>> {
+        let state_ttl = self.config.state_ttl.as_millis_u64();
+
+        // Common case: the peer is known and live; no key clone is needed.
+        if let Some(peer) = self.states.get(key)
+            && peer.is_live(now, state_ttl)
+        {
+            return peer.clone();
+        }
+
+        let mut entry = self
+            .states
+            .entry(key.clone())
+            .or_insert_with(|| Arc::new(PeerLimiter::new(now)));
+
+        // Re-check through the entry: another thread may have replaced the
+        // state between the lookup above and this point.
+        if !entry.is_live(now, state_ttl) {
+            *entry = Arc::new(PeerLimiter::new(now));
+        }
+
+        entry.clone()
+    }
+
+    /// Drops every peer state that is no longer live.
+    fn prune_expired(&self, now: u64) {
+        let state_ttl = self.config.state_ttl.as_millis_u64();
+
+        self.states.retain(|_, peer| peer.is_live(now, state_ttl));
+    }
+
+    /// Runs `prune_expired` at most once per `prune_interval`.
+    ///
+    /// When several threads notice that a sweep is due, only the one that
+    /// wins the compare-exchange on `last_prune_ms` performs it.
+    fn maybe_prune(&self, now: u64) {
+        let last = self.last_prune_ms.load(Ordering::Relaxed);
+
+        if now.saturating_sub(last) < self.config.prune_interval.as_millis_u64() {
+            return;
+        }
+
+        if self
+            .last_prune_ms
+            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
+            .is_ok()
+        {
+            self.prune_expired(now);
+        }
+    }
+}
+
+/// Limiter state of a single peer.
+struct PeerLimiter<C> {
+    /// One GCRA bucket per traffic class, created on first use.
+    buckets: RwLock<FastHashMap<C, Arc<AtomicBucket>>>,
+    /// Rejections since the last allowed request or cooldown.
+    rejects: AtomicU8,
+    /// Millisecond timestamp until which every request is rejected
+    /// (`0` when the peer has never been in cooldown).
+    cooldown_until_ms: AtomicU64,
+    /// Millisecond timestamp of the most recent request.
+    last_seen_ms: AtomicU64,
+}
+
+impl<C> PeerLimiter<C>
+where
+    C: Copy + Eq + Hash + Send + Sync + 'static,
+{
+    /// Whether this state must be kept: it is either still in cooldown or
+    /// was used less than `state_ttl` milliseconds ago.
+    fn is_live(&self, now: u64, state_ttl: u64) -> bool {
+        self.in_cooldown(now) || !self.is_expired(now, state_ttl)
+    }
+
+    /// Creates a state with no buckets, no rejections and no cooldown,
+    /// last seen at `now_ms`.
+    fn new(now_ms: u64) ->
Self {
+        Self {
+            buckets: RwLock::new(FastHashMap::default()),
+            rejects: AtomicU8::new(0),
+            cooldown_until_ms: AtomicU64::new(0),
+            last_seen_ms: AtomicU64::new(now_ms),
+        }
+    }
+
+    /// Accounts one request at time `now` (milliseconds).
+    ///
+    /// * While the peer is in cooldown the request is rejected right away,
+    ///   without touching any bucket, and `retry_after` is the remaining
+    ///   cooldown.
+    /// * Otherwise the bucket of `policy.class` decides. An allowed request
+    ///   resets the rejection counter; a rejected one is counted and may
+    ///   start a cooldown.
+    fn check(
+        &self,
+        config: &RateLimitConfig,
+        policy: RateLimitPolicy<C>,
+        now: u64,
+    ) -> RateLimitVerdict {
+        // Any request, allowed or not, keeps the state from expiring.
+        self.last_seen_ms.store(now, Ordering::Relaxed);
+
+        let cooldown_until = self.cooldown_until_ms.load(Ordering::Acquire);
+        if now < cooldown_until {
+            return RateLimitVerdict::Reject {
+                retry_after: Duration::from_millis(cooldown_until - now),
+            };
+        }
+
+        let bucket = self.bucket(policy.class, policy.limit);
+        let BucketResult::TryLater { after } = bucket.try_claim(now) else {
+            self.rejects.store(0, Ordering::Relaxed);
+            return RateLimitVerdict::Allow;
+        };
+
+        self.register_rejection(config, now);
+
+        // Prefer cooldown wait time: if this rejection (or a concurrent one)
+        // started a cooldown, retrying when the bucket refills is pointless.
+        let cooldown_until = self.cooldown_until_ms.load(Ordering::Acquire);
+        let retry_after = if now < cooldown_until {
+            cooldown_until - now
+        } else {
+            after.saturating_sub(now)
+        };
+
+        RateLimitVerdict::Reject {
+            retry_after: Duration::from_millis(retry_after),
+        }
+    }
+
+    /// Returns the bucket of `class`, creating it from `config` on first use.
+    ///
+    /// `config` is ignored when the bucket already exists, so the limit seen
+    /// first for a class stays in effect for the lifetime of this state.
+    fn bucket(&self, class: C, config: TrafficLimit) -> Arc<AtomicBucket> {
+        // Look up under the read lock first; the write lock is only taken
+        // the first time a class is seen.
+        if let Some(bucket) = self.buckets.read().get(&class) {
+            return bucket.clone();
+        }
+
+        self.buckets
+            .write()
+            .entry(class)
+            .or_insert_with(|| Arc::new(AtomicBucket::new(config)))
+            .clone()
+    }
+
+    /// Counts one rejection and starts a cooldown once
+    /// `config.rejects_before_cooldown` of them have accumulated.
+    ///
+    /// The counter is incremented and, on reaching the threshold, reset in a
+    /// single atomic update. With a separate `fetch_add` and `store(0)`,
+    /// concurrent rejections could push the `u8` past the threshold and wrap
+    /// it around before the reset landed.
+    fn register_rejection(&self, config: &RateLimitConfig, now: u64) {
+        let threshold = config.rejects_before_cooldown;
+
+        // The closure always returns `Some`, so both arms carry the value the
+        // counter held right before this rejection was recorded.
+        let (Ok(prev_rejects) | Err(prev_rejects)) =
+            self.rejects
+                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |rejects| {
+                    let next = rejects.saturating_add(1);
+                    Some(if next >= threshold { 0 } else { next })
+                });
+
+        if prev_rejects.saturating_add(1) >= threshold {
+            // `fetch_max` never shortens a cooldown set by another thread.
+            self.cooldown_until_ms.fetch_max(
+                now.saturating_add(config.cooldown.as_millis_u64()),
+                Ordering::AcqRel,
+            );
+        }
+    }
+
+    /// Whether the last request was at least `state_ttl` milliseconds ago.
+    fn is_expired(&self, now: u64, state_ttl: u64) -> bool {
+        let last_seen = self.last_seen_ms.load(Ordering::Relaxed);
+        now.saturating_sub(last_seen) >= state_ttl
+    }
+
+    /// Whether requests are currently rejected unconditionally.
+    fn in_cooldown(&self, now: u64) -> bool {
+        now < self.cooldown_until_ms.load(Ordering::Acquire)
+    }
+}
+
+/// Outcome of [`AtomicBucket::try_claim`].
+#[derive(Debug, PartialEq, Eq)]
+enum BucketResult {
+    /// The request was admitted and accounted for.
+    Ok,
+    /// The request was not admitted; `after` is the earliest millisecond
+    /// timestamp at which a claim can succeed.
+    TryLater { after: u64 },
+}
+
+/// GCRA bucket
+///
+/// All times are milliseconds on the same clock as the `now` passed to
+/// [`Self::try_claim`].
+struct AtomicBucket {
+    /// Theoretical arrival time
+    tat_ms: AtomicU64,
+    /// Spacing between requests for configured rate
+    interval_ms: u64,
+    /// Burst allowance.
+    tolerance_ms: u64,
+}
+
+impl AtomicBucket {
+    const MILLIS_PER_SEC: u64 = 1_000;
+
+    /// Builds a full bucket for `config` (after normalizing it).
+    ///
+    /// The spacing is `ceil(1000 / rate_per_sec)` milliseconds, so a rate
+    /// that does not divide 1000 is enforced slightly below its nominal
+    /// value. The tolerance admits `burst` requests back to back.
+    fn new(mut config: TrafficLimit) -> Self {
+        config.normalize();
+
+        let rate_per_sec = config.rate_per_sec.get() as u64;
+        let interval_ms = Self::MILLIS_PER_SEC.div_ceil(rate_per_sec).max(1);
+        let tolerance_ms = interval_ms.saturating_mul(config.burst.get().saturating_sub(1) as u64);
+
+        Self {
+            tat_ms: AtomicU64::new(0),
+            interval_ms,
+            tolerance_ms,
+        }
+    }
+
+    /// Tries to admit one request at time `now`.
+    ///
+    /// The request is admitted when `now` is not earlier than the
+    /// theoretical arrival time minus the burst tolerance; admission moves
+    /// the arrival time one interval past `max(tat, now)`. The update is a
+    /// compare-exchange loop, so concurrent callers never share a slot.
+    fn try_claim(&self, now: u64) -> BucketResult {
+        let mut tat = self.tat_ms.load(Ordering::Relaxed);
+
+        loop {
+            let allowed_at = tat.saturating_sub(self.tolerance_ms);
+            if now < allowed_at {
+                return BucketResult::TryLater { after: allowed_at };
+            }
+
+            let new_tat = tat.max(now).saturating_add(self.interval_ms);
+            match self.tat_ms.compare_exchange_weak(
+                tat,
+                new_tat,
+                Ordering::AcqRel,
+                Ordering::Relaxed,
+            ) {
+                Ok(_) => return BucketResult::Ok,
+                // Lost the race (or failed spuriously): re-evaluate against
+                // the value that is stored now.
+                Err(next) => tat = next,
+            }
+        }
+    }
+}
+
+/// Millisecond conversion that saturates instead of truncating.
+trait DurationExt {
+    /// Whole milliseconds in the duration, capped at `u64::MAX`.
+    fn as_millis_u64(&self) -> u64;
+}
+
+impl DurationExt for Duration {
+    fn as_millis_u64(&self) -> u64 {
+        self.as_millis().try_into().unwrap_or(u64::MAX)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Two independent traffic classes.
+    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
+    enum Class {
+        A,
+        B,
+    }
+
+    /// Shorthand for a [`TrafficLimit`]; both arguments must be non-zero.
+    fn bucket_config(rate_per_sec: u32, burst: u32) -> TrafficLimit {
+        TrafficLimit::new(
+            NonZeroU32::new(rate_per_sec).unwrap(),
+            NonZeroU32::new(burst).unwrap(),
+        )
+    }
+
+    /// Strictest possible policy for `class`: 1 request per second, no burst.
+    fn policy(class: Class) -> RateLimitPolicy<Class> {
+        RateLimitPolicy {
+            class,
+            limit: bucket_config(1, 1),
+        }
+    }
+
+    /// Limiter that enters cooldown on the second rejection in a row.
+    /// The tests use plain integers as peer keys.
+    fn rate_limiter() -> RateLimiter<u32, Class> {
+
RateLimiter::new(RateLimitConfig {
+            rejects_before_cooldown: 2,
+            ..RateLimitConfig::default()
+        })
+    }
+
+    #[test]
+    fn gcra_bucket_burst_and_refills() {
+        // 10 req/s with a burst of 2: one slot is regained every 100ms.
+        let bucket = AtomicBucket::new(bucket_config(10, 2));
+        let start = time::now_millis();
+        let refill_at = start + Duration::from_millis(100).as_millis_u64();
+
+        // The whole burst is admitted at the same instant.
+        for _ in 0..2 {
+            assert_eq!(bucket.try_claim(start), BucketResult::Ok);
+        }
+
+        // The next claim has to wait for exactly one refill interval.
+        let exhausted = bucket.try_claim(start);
+        assert_eq!(exhausted, BucketResult::TryLater { after: refill_at });
+
+        // Once that interval has passed, a claim succeeds again.
+        assert_eq!(bucket.try_claim(refill_at), BucketResult::Ok);
+    }
+
+    #[test]
+    fn rate_limiter_cooldown() {
+        let limiter = rate_limiter();
+        let peer = 1;
+
+        // The single burst slot is free.
+        let first = limiter.check(&peer, policy(Class::A));
+        assert_eq!(first, RateLimitVerdict::Allow);
+
+        // First rejection: still below the cooldown threshold, so the
+        // caller only has to wait for the bucket to refill.
+        let second = limiter.check(&peer, policy(Class::A));
+        assert!(matches!(
+            second,
+            RateLimitVerdict::Reject { retry_after }
+                if retry_after <= Duration::from_secs(1)
+        ));
+
+        // Second rejection reaches the threshold: the default 30s cooldown
+        // is reported instead.
+        let third = limiter.check(&peer, policy(Class::A));
+        assert!(matches!(
+            third,
+            RateLimitVerdict::Reject { retry_after }
+                if retry_after > Duration::from_secs(29)
+        ));
+    }
+
+    #[test]
+    fn rate_limiter_keep_buckets_per_class() {
+        let limiter = rate_limiter();
+        let peer = 1;
+
+        // Each class has its own burst slot...
+        for class in [Class::A, Class::B] {
+            let verdict = limiter.check(&peer, policy(class));
+            assert_eq!(verdict, RateLimitVerdict::Allow);
+        }
+
+        // ...and once both are spent, both classes are rejected.
+        for class in [Class::A, Class::B] {
+            let verdict = limiter.check(&peer, policy(class));
+            assert!(matches!(verdict, RateLimitVerdict::Reject { .. }));
+        }
+    }
+}