From d84cf03c97f494a915db3a5d0153b349e5cb906b Mon Sep 17 00:00:00 2001 From: jiashaokun <1114621279@qq.com> Date: Sat, 15 Nov 2025 14:32:30 +0800 Subject: [PATCH 1/2] router select dp group with the minimum number of tokens 01 --- python/sglang/srt/entrypoints/EngineBase.py | 1 + python/sglang/srt/entrypoints/engine.py | 4 ++ .../sglang/srt/entrypoints/openai/protocol.py | 4 ++ .../srt/entrypoints/openai/serving_chat.py | 3 + .../srt/managers/data_parallel_controller.py | 25 ++++++-- python/sglang/srt/managers/io_struct.py | 7 +++ python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 1 + sgl-router/benches/request_processing.rs | 1 + .../bindings/python/sglang_router/router.py | 2 + .../python/sglang_router/router_args.py | 13 ++++ sgl-router/bindings/python/src/lib.rs | 10 +++ sgl-router/src/app_context.rs | 5 +- sgl-router/src/config/builder.rs | 9 +++ sgl-router/src/config/types.rs | 4 ++ sgl-router/src/core/worker_manager.rs | 38 +++++++----- sgl-router/src/lib.rs | 2 +- sgl-router/src/main.rs | 8 +++ sgl-router/src/policies/cache_aware.rs | 26 ++++++-- sgl-router/src/policies/mod.rs | 62 ++++++++++++++++++- sgl-router/src/policies/power_of_two.rs | 16 ++++- sgl-router/src/policies/random.rs | 27 ++++++-- sgl-router/src/policies/registry.rs | 53 +++++++++++++++- sgl-router/src/policies/round_robin.rs | 17 ++++- sgl-router/src/protocols/generate.rs | 3 + sgl-router/src/protocols/worker_spec.rs | 2 + sgl-router/src/routers/http/pd_router.rs | 48 ++++++++++++++ 28 files changed, 359 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py index 5d3162afd51..412bf0d09bd 100644 --- a/python/sglang/srt/entrypoints/EngineBase.py +++ b/python/sglang/srt/entrypoints/EngineBase.py @@ -29,6 +29,7 @@ def generate( bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: 
Optional[Union[List[int], int]] = None, data_parallel_rank: Optional[int] = None, + decode_dp_rank: Optional[int] = None, rid: Optional[Union[List[str], str]] = None, ) -> Union[Dict, Iterator[Dict]]: """Generate outputs based on given inputs.""" diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 5a0b86f0018..fcd3dcf006b 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -184,6 +184,7 @@ def generate( bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, data_parallel_rank: Optional[int] = None, + decode_dp_rank: Optional[int] = None, rid: Optional[Union[List[str], str]] = None, ) -> Union[Dict, Iterator[Dict]]: """ @@ -219,6 +220,7 @@ def generate( bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, data_parallel_rank=data_parallel_rank, + decode_dp_rank=decode_dp_rank, rid=rid, ) generator = self.tokenizer_manager.generate_request(obj, None) @@ -266,6 +268,7 @@ async def async_generate( bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, data_parallel_rank: Optional[int] = None, + decode_dp_rank: Optional[int] = None, rid: Optional[Union[List[str], str]] = None, ) -> Union[Dict, AsyncIterator[Dict]]: """ @@ -303,6 +306,7 @@ async def async_generate( bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, data_parallel_rank=data_parallel_rank, + decode_dp_rank=decode_dp_rank, rid=rid, ) generator = self.tokenizer_manager.generate_request(obj, None) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index ecb8a48b7d2..11a4fa34eeb 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -530,6 +530,10 @@ class ChatCompletionRequest(BaseModel): bootstrap_port: Optional[Union[List[Optional[int]], int]] = None 
bootstrap_room: Optional[Union[List[int], int]] = None + # For data parallel rank routing + data_parallel_rank: Optional[int] = None + decode_dp_rank: Optional[int] = None + # OpenAI/SGLang default sampling parameters _DEFAULT_SAMPLING_PARAMS = { "temperature": 1.0, diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 773707c36b5..f979b774a3a 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -191,6 +191,9 @@ def _convert_to_internal_request( bootstrap_host=request.bootstrap_host, bootstrap_port=request.bootstrap_port, bootstrap_room=request.bootstrap_room, + # For data parallel rank routing + data_parallel_rank=request.data_parallel_rank, + decode_dp_rank=request.decode_dp_rank, return_hidden_states=request.return_hidden_states, rid=request.rid, extra_key=self._compute_extra_key(request), diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index cb897a643cd..27d34b35956 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -454,11 +454,26 @@ def launch_tensor_parallel_group( self.max_req_input_len = scheduler_info[0]["max_req_input_len"] def maybe_external_dp_rank_routing(self, req: Req): - if req.data_parallel_rank is not None: - logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") - self.workers[req.data_parallel_rank].send_pyobj(req) - return True - return False + if self.server_args.disaggregation_mode == "prefill": + if req.data_parallel_rank is not None: + logger.debug( + f"Prefill direct routing to DP rank {req.data_parallel_rank}" + ) + self.workers[req.data_parallel_rank].send_pyobj(req) + return True + return False + else: + if req.decode_dp_rank is not None: + logger.debug(f"Decode direct routing to DP rank {req.decode_dp_rank}") + 
self.workers[req.decode_dp_rank].send_pyobj(req) + return True + if req.data_parallel_rank is not None: + logger.debug( + f"Decode direct routing to DP rank {req.data_parallel_rank}, by data parallel rank" + ) + self.workers[req.data_parallel_rank].send_pyobj(req) + return True + return False def round_robin_scheduler(self, req: Req): if self.maybe_external_dp_rank_routing(req): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index cdecebca3c4..3eab4f1ebc6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -203,6 +203,8 @@ class GenerateReqInput(BaseReq): # For data parallel rank routing data_parallel_rank: Optional[int] = None + decode_dp_rank: Optional[int] = None + # For background responses (OpenAI responses API) background: bool = False @@ -620,6 +622,9 @@ def __getitem__(self, i): data_parallel_rank=( self.data_parallel_rank if self.data_parallel_rank is not None else None ), + decode_dp_rank=( + self.decode_dp_rank if self.decode_dp_rank is not None else None + ), conversation_id=self.conversation_id, priority=self.priority, extra_key=self.extra_key, @@ -678,6 +683,8 @@ class TokenizedGenerateReqInput(BaseReq): # For data parallel rank routing data_parallel_rank: Optional[int] = None + decode_dp_rank: Optional[int] = None + # Priority for the request priority: Optional[int] = None diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e1b98696128..723d33beaab 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -457,6 +457,7 @@ def __init__( bootstrap_room: Optional[int] = None, disagg_mode: Optional[DisaggregationMode] = None, data_parallel_rank: Optional[int] = None, + decode_dp_rank: Optional[int] = None, vocab_size: Optional[int] = None, priority: Optional[int] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None, @@ -666,6 
+667,7 @@ def __init__( # For data parallel rank routing self.data_parallel_rank: Optional[int] = data_parallel_rank + self.decode_dp_rank: Optional[int] = decode_dp_rank # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 879c66df09c..899af2ec8b9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1287,6 +1287,7 @@ def handle_generate_request( bootstrap_room=recv_req.bootstrap_room, disagg_mode=self.disaggregation_mode, data_parallel_rank=recv_req.data_parallel_rank, + decode_dp_rank=recv_req.decode_dp_rank, vocab_size=self.model_config.vocab_size, priority=recv_req.priority, metrics_collector=( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3375ab60e19..310debc961c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -815,6 +815,7 @@ def _create_tokenized_object( custom_logit_processor=obj.custom_logit_processor, return_hidden_states=obj.return_hidden_states, data_parallel_rank=obj.data_parallel_rank, + decode_dp_rank=obj.decode_dp_rank, priority=obj.priority, extra_key=obj.extra_key, ) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 54d7045120f..d5b116f9ee2 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -58,6 +58,7 @@ fn default_generate_request() -> GenerateRequest { bootstrap_room: None, bootstrap_pair_key: None, data_parallel_rank: None, + decode_dp_rank: None, background: false, conversation_id: None, priority: None, diff --git a/sgl-router/bindings/python/sglang_router/router.py b/sgl-router/bindings/python/sglang_router/router.py index 0ee130749b0..f58e556528f 100644 --- 
a/sgl-router/bindings/python/sglang_router/router.py +++ b/sgl-router/bindings/python/sglang_router/router.py @@ -76,6 +76,7 @@ class Router: port: Port number to bind the router server. Default: 3001 worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10 + worker_load_check_interval: Interval in seconds between worker load checks. Default: 10 cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5 @@ -88,6 +89,7 @@ class Router: max_payload_size: Maximum payload size in bytes. Default: 256MB max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 dp_aware: Enable data parallelism aware schedule. Default: False + dp_minimum_tokens_scheduler: Enable minimum tokens scheduler for data parallel group. Default: False enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled, the router can manage multiple models simultaneously with per-model load balancing policies.
Default: False diff --git a/sgl-router/bindings/python/sglang_router/router_args.py b/sgl-router/bindings/python/sglang_router/router_args.py index 04077b9de41..b561c36d5ce 100644 --- a/sgl-router/bindings/python/sglang_router/router_args.py +++ b/sgl-router/bindings/python/sglang_router/router_args.py @@ -28,6 +28,7 @@ class RouterArgs: decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode worker_startup_timeout_secs: int = 600 worker_startup_check_interval: int = 30 + worker_load_check_interval: int = 10 cache_threshold: float = 0.3 balance_abs_threshold: int = 64 balance_rel_threshold: float = 1.5 @@ -36,6 +37,7 @@ class RouterArgs: max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches bucket_adjust_interval_secs: int = 5 dp_aware: bool = False + dp_minimum_tokens_scheduler: bool = False enable_igw: bool = False # Enable IGW (Inter-Gateway) mode for multi-model support api_key: Optional[str] = None log_dir: Optional[str] = None @@ -217,6 +219,12 @@ def add_cli_args( default=RouterArgs.worker_startup_check_interval, help="Interval in seconds between checks for worker startup", ) + parser.add_argument( + f"--{prefix}worker-load-check-interval", + type=int, + default=RouterArgs.worker_load_check_interval, + help="Interval in seconds between worker load checks", + ) parser.add_argument( f"--{prefix}cache-threshold", type=float, @@ -264,6 +272,11 @@ def add_cli_args( action="store_true", help="Enable data parallelism aware schedule", ) + parser.add_argument( + f"--{prefix}dp-minimum-tokens-scheduler", + action="store_true", + help="Enable minimum tokens scheduler for data parallel group", + ) parser.add_argument( f"--{prefix}enable-igw", action="store_true", diff --git a/sgl-router/bindings/python/src/lib.rs b/sgl-router/bindings/python/src/lib.rs index e145ca1f9e8..ebc70c71396 100644 --- a/sgl-router/bindings/python/src/lib.rs +++ b/sgl-router/bindings/python/src/lib.rs @@ -157,6 +157,7 @@ struct Router {
policy: PolicyType, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, + worker_load_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -164,6 +165,7 @@ struct Router { max_tree_size: usize, max_payload_size: usize, dp_aware: bool, + dp_minimum_tokens_scheduler: bool, api_key: Option, log_dir: Option, log_level: Option, @@ -342,6 +344,7 @@ impl Router { .request_timeout_secs(self.request_timeout_secs) .worker_startup_timeout_secs(self.worker_startup_timeout_secs) .worker_startup_check_interval_secs(self.worker_startup_check_interval) + .worker_load_check_interval_secs(self.worker_load_check_interval) .max_concurrent_requests(self.max_concurrent_requests) .queue_size(self.queue_size) .queue_timeout_secs(self.queue_timeout_secs) @@ -397,6 +400,7 @@ impl Router { self.client_key_path.as_ref(), ) .add_ca_certificates(self.ca_cert_paths.clone()) + .dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler) .build() } } @@ -411,6 +415,7 @@ impl Router { port = 3001, worker_startup_timeout_secs = 600, worker_startup_check_interval = 30, + worker_load_check_interval = 10, cache_threshold = 0.3, balance_abs_threshold = 64, balance_rel_threshold = 1.5, @@ -418,6 +423,7 @@ impl Router { max_tree_size = 2usize.pow(26), max_payload_size = 512 * 1024 * 1024, dp_aware = false, + dp_minimum_tokens_scheduler = false, api_key = None, log_dir = None, log_level = None, @@ -486,6 +492,7 @@ impl Router { port: u16, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, + worker_load_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -493,6 +500,7 @@ impl Router { max_tree_size: usize, max_payload_size: usize, dp_aware: bool, + dp_minimum_tokens_scheduler: bool, api_key: Option, log_dir: Option, log_level: Option, @@ -574,6 +582,7 @@ impl Router { policy, worker_startup_timeout_secs, worker_startup_check_interval, + worker_load_check_interval, 
cache_threshold, balance_abs_threshold, balance_rel_threshold, @@ -581,6 +590,7 @@ impl Router { max_tree_size, max_payload_size, dp_aware, + dp_minimum_tokens_scheduler, api_key, log_dir, log_level, diff --git a/sgl-router/src/app_context.rs b/sgl-router/src/app_context.rs index 31e3161c09e..a81d39ece69 100644 --- a/sgl-router/src/app_context.rs +++ b/sgl-router/src/app_context.rs @@ -426,6 +426,9 @@ impl AppContextBuilder { /// Create policy registry fn with_policy_registry(mut self, config: &RouterConfig) -> Self { self.policy_registry = Some(Arc::new(PolicyRegistry::new(config.policy.clone()))); + if config.dp_minimum_tokens_scheduler { + self.policy_registry.as_ref().unwrap().enable_dp_minimum_tokens_scheduler(); + } self } @@ -457,7 +460,7 @@ impl AppContextBuilder { .expect("policy_registry must be set") .clone(), client.clone(), - config.worker_startup_check_interval_secs, + config.worker_load_check_interval_secs, ))); self } diff --git a/sgl-router/src/config/builder.rs b/sgl-router/src/config/builder.rs index f7e15aae59e..77e687f498f 100644 --- a/sgl-router/src/config/builder.rs +++ b/sgl-router/src/config/builder.rs @@ -181,6 +181,10 @@ impl RouterConfigBuilder { self } + pub fn worker_load_check_interval_secs(mut self, interval: u64) -> Self { + self.config.worker_load_check_interval_secs = interval; + self + } // ==================== Rate Limiting ==================== pub fn max_concurrent_requests(mut self, max: i32) -> Self { @@ -434,6 +438,11 @@ impl RouterConfigBuilder { self } + pub fn dp_minimum_tokens_scheduler(mut self, enable: bool) -> Self { + self.config.dp_minimum_tokens_scheduler = enable; + self + } + // ==================== Option Setters ==================== // Accept Option and only set if Some diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 90174f070d4..050b5400afd 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -19,7 +19,9 @@ pub struct RouterConfig { pub 
request_timeout_secs: u64, pub worker_startup_timeout_secs: u64, pub worker_startup_check_interval_secs: u64, + pub worker_load_check_interval_secs: u64, pub dp_aware: bool, + pub dp_minimum_tokens_scheduler: bool, pub api_key: Option, pub discovery: Option, pub metrics: Option, @@ -471,7 +473,9 @@ impl Default for RouterConfig { request_timeout_secs: 1800, // 30 minutes worker_startup_timeout_secs: 600, worker_startup_check_interval_secs: 30, + worker_load_check_interval_secs: 10, dp_aware: false, + dp_minimum_tokens_scheduler: false, api_key: None, discovery: None, metrics: None, diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index 5f94e55f1ef..5572dea5ebd 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -135,7 +135,7 @@ impl WorkerManager { url: &str, api_key: Option<&str>, client: &reqwest::Client, - ) -> Option { + ) -> Option> { let load_url = format!("{}/get_load", url); let mut request = client.get(&load_url); @@ -150,14 +150,15 @@ impl WorkerManager { // The /get_load endpoint returns an array of load info objects (one per DP rank) // Each object has: {dp_rank, num_reqs, num_waiting_reqs, num_tokens} if let Some(array) = json.as_array() { - let total_tokens: i64 = array - .iter() - .filter_map(|entry| { - entry.get("num_tokens").and_then(|v| v.as_i64()) - }) - .sum(); - debug!("Worker {} load (total tokens): {}", url, total_tokens); - Some(total_tokens as isize) + let mut rank_tokens = HashMap::new(); + for entry in array { + let dp_rank = entry.get("dp_rank").and_then(|v| v.as_i64()).map(|rank| rank as isize); + let num_tokens = entry.get("num_tokens").and_then(|v| v.as_i64()).map(|rank| rank as isize); + if let (Some(rank), Some(tokens)) = (dp_rank, num_tokens) { + rank_tokens.insert(rank, tokens); + } + } + Some(rank_tokens) } else { warn!( "Invalid load response from {}: expected array, got {:?}", @@ -208,18 +209,20 @@ impl WorkerManager { let client = 
client.clone(); tasks.push(async move { - let load = if is_http { + let dp_rank_loads = if is_http { Self::get_worker_load(&url, api_key.as_deref(), &client) .await - .unwrap_or(-1) + .unwrap_or(HashMap::new()) } else { - -1 + HashMap::new() }; + let load = dp_rank_loads.values().sum(); WorkerLoadInfo { worker: url, worker_type, load, + dp_rank_loads, } }); } @@ -407,19 +410,21 @@ impl LoadMonitor { loop { interval_timer.tick().await; - let power_of_two_policies = policy_registry.get_all_power_of_two_policies(); - if power_of_two_policies.is_empty() { + if power_of_two_policies.is_empty() && !policy_registry.is_dp_minimum_tokens_scheduler_enabled() { debug!("No PowerOfTwo policies found, skipping load fetch"); continue; } + let all_policies = policy_registry.get_all_policies(); let result = WorkerManager::get_all_worker_loads(&worker_registry, &client).await; let mut loads = HashMap::new(); + let mut dp_rank_loads = HashMap::new(); for load_info in result.loads { - loads.insert(load_info.worker, load_info.load); + loads.insert(load_info.worker.clone(), load_info.load); + dp_rank_loads.insert(load_info.worker, load_info.dp_rank_loads); } if !loads.is_empty() { @@ -431,6 +436,9 @@ impl LoadMonitor { for policy in &power_of_two_policies { policy.update_loads(&loads); } + for policy in &all_policies { + policy.update_dp_loads(&dp_rank_loads) + } let _ = tx.send(loads); } else { warn!("No loads fetched from workers"); diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 647d0b82be0..94d0b65284f 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -15,4 +15,4 @@ pub mod routers; pub mod server; pub mod service_discovery; pub mod tokenizer; -pub mod tool_parser; +pub mod tool_parser; \ No newline at end of file diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs index 34c9c460931..574f1980318 100644 --- a/sgl-router/src/main.rs +++ b/sgl-router/src/main.rs @@ -132,6 +132,9 @@ struct CliArgs { #[arg(long, default_value_t = 30)] 
worker_startup_check_interval: u64, + #[arg(long, default_value_t = 1)] + worker_load_check_interval: u64, + #[arg(long, default_value_t = 0.3)] cache_threshold: f32, @@ -153,6 +156,9 @@ struct CliArgs { #[arg(long, default_value_t = false)] dp_aware: bool, + #[arg(long, default_value_t = false)] + dp_minimum_tokens_scheduler: bool, + #[arg(long)] api_key: Option, @@ -564,6 +570,7 @@ impl CliArgs { .request_timeout_secs(self.request_timeout_secs) .worker_startup_timeout_secs(self.worker_startup_timeout_secs) .worker_startup_check_interval_secs(self.worker_startup_check_interval) + .worker_load_check_interval_secs(self.worker_load_check_interval) .max_concurrent_requests(self.max_concurrent_requests) .queue_size(self.queue_size) .queue_timeout_secs(self.queue_timeout_secs) @@ -613,6 +620,7 @@ impl CliArgs { .maybe_tool_call_parser(self.tool_call_parser.as_ref()) .maybe_mcp_config_path(self.mcp_config_path.as_ref()) .dp_aware(self.dp_aware) + .dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) .igw(self.enable_igw); diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index c43d34ed8ac..7fb3567d66c 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -59,13 +59,17 @@ during the next eviction cycle. 
*/ -use std::{sync::Arc, thread, time::Duration}; +use std::{sync::Arc, + thread, + time::Duration, + collections::HashMap +}; use dashmap::DashMap; use rand::Rng; use tracing::debug; -use super::{get_healthy_worker_indices, tree::Tree, CacheAwareConfig, LoadBalancingPolicy}; +use super::{get_healthy_worker_indices, tree::Tree, CacheAwareConfig, LoadBalancingPolicy, DPLoadManager}; use crate::{core::Worker, metrics::RouterMetrics}; /// Cache-aware routing policy @@ -78,6 +82,7 @@ pub struct CacheAwarePolicy { config: CacheAwareConfig, trees: Arc>>, eviction_handle: Option>, + dp_load_manager: DPLoadManager, } impl CacheAwarePolicy { @@ -116,14 +121,15 @@ impl CacheAwarePolicy { config, trees, eviction_handle, + dp_load_manager: DPLoadManager::new(), } } /// Initialize the tree with worker URLs (used only during initial setup) pub fn init_workers(&self, workers: &[Arc]) { // Group workers by model - let mut model_workers: std::collections::HashMap>> = - std::collections::HashMap::new(); + let mut model_workers: HashMap>> = + HashMap::new(); for worker in workers { // Use "default" for unknown/empty model_ids for backward compatibility let model_id = worker.model_id(); @@ -400,6 +406,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy { fn as_any(&self) -> &dyn std::any::Any { self } + + fn update_dp_loads(&self, loads: &HashMap>) { + return self.dp_load_manager.update_dp_loads(loads); + } + + fn get_lowest_dp_load(&self, worker: &dyn Worker) -> Option { + return self.dp_load_manager.get_lowest_dp_load(worker); + } + + fn load_increment(&self, worker: &dyn Worker, dp_rank: isize, tokens: isize) { + return self.dp_load_manager.load_increment(worker, dp_rank, tokens); + } } impl Default for CacheAwarePolicy { diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs index 7eca8609775..39bca3a6e1a 100644 --- a/sgl-router/src/policies/mod.rs +++ b/sgl-router/src/policies/mod.rs @@ -3,7 +3,9 @@ //! 
This module provides a unified abstraction for routing policies that work //! across both regular and prefill-decode (PD) routing modes. -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc, sync::RwLock}; +use std::collections::{HashMap}; +use tracing::{debug}; use crate::core::Worker; @@ -74,10 +76,22 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug { /// Update worker load information /// /// This is called periodically with current load information for load-aware policies. - fn update_loads(&self, _loads: &std::collections::HashMap) { + fn update_loads(&self, _loads: &HashMap) { // Default: no-op for policies that don't use load information } + fn update_dp_loads(&self, _loads: &HashMap>) { + // Default: no-op for policies that don't use load information + } + + fn get_lowest_dp_load(&self, _worker: &dyn Worker) -> Option { + None + } + + fn load_increment(&self, _worker: &dyn Worker, _dp_rank: isize, _tokens: isize) { + // Default + } + /// Reset any internal state /// /// This is useful for policies that maintain state (e.g., round-robin counters). 
@@ -128,6 +142,50 @@ impl Default for BucketConfig { } } +/// Tracks per-worker, per-DP-rank token loads shared by the load-balancing policies +#[derive(Debug, Default)] +pub struct DPLoadManager { + dp_cached_loads: RwLock>>, +} + +impl DPLoadManager { + pub fn new() -> Self { + Self { + dp_cached_loads: RwLock::new(HashMap::new()), + } + } + + pub fn update_dp_loads(&self, loads: &HashMap>) { + debug!("DPLoadManager update_dp_loads map:{:?}", loads); + if let Ok(mut cached) = self.dp_cached_loads.write() { + *cached = loads.clone(); + } + } + + pub fn get_lowest_dp_load(&self, worker: &dyn Worker) -> Option { + if let Ok(cached_loads) = self.dp_cached_loads.read() { + if let Some(loads) = cached_loads.get(worker.url()) { + return loads.iter() + .min_by_key(|&(_, load)| load) + .map(|(&rand_id, _)| rand_id); + } + } + None + } + + pub fn load_increment(&self, worker: &dyn Worker, dp_rank: isize, increment: isize) { + // Add an increment to the load of dp group, + // to prevent all requests from being scheduled to the same DP group during the interval between two load reports.
+ if let Ok(mut cached_loads) = self.dp_cached_loads.write() { + if let Some(loads) = cached_loads.get_mut(worker.url()) { + if let Some(dp_load) = loads.get_mut(&dp_rank) { + *dp_load += increment; + } + } + } + } +} + /// Helper function to filter healthy workers and return their indices pub(crate) fn get_healthy_worker_indices(workers: &[Arc]) -> Vec { workers diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs index b7edef82273..530087d20a0 100644 --- a/sgl-router/src/policies/power_of_two.rs +++ b/sgl-router/src/policies/power_of_two.rs @@ -8,7 +8,7 @@ use std::{ use rand::Rng; use tracing::info; -use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use super::{get_healthy_worker_indices, LoadBalancingPolicy, DPLoadManager}; use crate::{core::Worker, metrics::RouterMetrics}; /// Power-of-two choices policy @@ -19,12 +19,14 @@ use crate::{core::Worker, metrics::RouterMetrics}; pub struct PowerOfTwoPolicy { /// Cached load information from external monitoring cached_loads: RwLock>, + dp_load_manager: DPLoadManager, } impl PowerOfTwoPolicy { pub fn new() -> Self { Self { cached_loads: RwLock::new(HashMap::new()), + dp_load_manager: DPLoadManager::new(), } } @@ -111,6 +113,18 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { fn as_any(&self) -> &dyn std::any::Any { self } + + fn update_dp_loads(&self, loads: &HashMap>) { + return self.dp_load_manager.update_dp_loads(loads); + } + + fn get_lowest_dp_load(&self, worker: &dyn Worker) -> Option { + return self.dp_load_manager.get_lowest_dp_load(worker); + } + + fn load_increment(&self, worker: &dyn Worker, dp_rank: isize, tokens: isize) { + return self.dp_load_manager.load_increment(worker, dp_rank, tokens); + } } impl Default for PowerOfTwoPolicy { diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs index 5b92b2d738d..6f654b380d6 100644 --- a/sgl-router/src/policies/random.rs +++ b/sgl-router/src/policies/random.rs @@ -1,21 +1,28 @@ 
//! Random load balancing policy -use std::sync::Arc; +use std::{ + sync::Arc, + collections::HashMap, +}; use rand::Rng; -use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use super::{get_healthy_worker_indices, LoadBalancingPolicy, DPLoadManager}; use crate::{core::Worker, metrics::RouterMetrics}; /// Random selection policy /// /// Selects workers randomly with uniform distribution among healthy workers. #[derive(Debug, Default)] -pub struct RandomPolicy; +pub struct RandomPolicy { + dp_load_manager: DPLoadManager, +} impl RandomPolicy { pub fn new() -> Self { - Self + Self { + dp_load_manager: DPLoadManager::new(), + } } } @@ -47,6 +54,18 @@ impl LoadBalancingPolicy for RandomPolicy { fn as_any(&self) -> &dyn std::any::Any { self } + + fn update_dp_loads(&self, loads: &HashMap>) { + return self.dp_load_manager.update_dp_loads(loads); + } + + fn get_lowest_dp_load(&self, worker: &dyn Worker) -> Option { + return self.dp_load_manager.get_lowest_dp_load(worker); + } + + fn load_increment(&self, worker: &dyn Worker, dp_rank: isize, tokens: isize) { + return self.dp_load_manager.load_increment(worker, dp_rank, tokens); + } } #[cfg(test)] diff --git a/sgl-router/src/policies/registry.rs b/sgl-router/src/policies/registry.rs index 5fe5b24ae7d..f84cb0ffed7 100644 --- a/sgl-router/src/policies/registry.rs +++ b/sgl-router/src/policies/registry.rs @@ -1,6 +1,10 @@ use std::{ collections::HashMap, - sync::{Arc, RwLock}, + sync::{ + Arc, + RwLock, + atomic::{AtomicBool, Ordering}, + }, }; use tracing::{debug, info, warn}; @@ -34,6 +38,9 @@ pub struct PolicyRegistry { /// Decode policy for PD mode decode_policy: Arc>>>, + + /// Enable minimum tokens scheduler for dp group + dp_minimum_tokens_scheduler: Arc, } impl PolicyRegistry { @@ -47,9 +54,18 @@ impl PolicyRegistry { default_policy, prefill_policy: Arc::new(RwLock::new(None)), decode_policy: Arc::new(RwLock::new(None)), + dp_minimum_tokens_scheduler: Arc::new(AtomicBool::new(false)), } } + pub fn 
enable_dp_minimum_tokens_scheduler(&self) { + self.dp_minimum_tokens_scheduler.store(true, Ordering::Relaxed); + } + + pub fn is_dp_minimum_tokens_scheduler_enabled(&self) -> bool { + self.dp_minimum_tokens_scheduler.load(Ordering::Relaxed) + } + /// Called when a worker is added /// Returns the policy that should be used for this worker's model pub fn on_worker_added( @@ -314,6 +330,41 @@ impl PolicyRegistry { power_of_two_policies } + pub fn get_all_policies(&self) -> Vec> { + let mut all_policies = Vec::new(); + + all_policies.push(Arc::clone(&self.default_policy)); + + if let Some(ref policy) = *self.prefill_policy.read().unwrap() { + if !Arc::ptr_eq(policy, &self.default_policy) { + all_policies.push(Arc::clone(policy)); + } + } + + if let Some(ref policy) = *self.decode_policy.read().unwrap() { + if !Arc::ptr_eq(policy, &self.default_policy) + && !self + .prefill_policy + .read() + .unwrap() + .as_ref() + .is_some_and(|p| Arc::ptr_eq(p, policy)) + { + all_policies.push(Arc::clone(policy)); + } + } + + let model_policies = self.model_policies.read().unwrap(); + for policy in model_policies.values() { + let already_added = all_policies.iter().any(|p| Arc::ptr_eq(p, policy)); + if !already_added { + all_policies.push(Arc::clone(policy)); + } + } + + all_policies + } + /// Initialize cache-aware policy with workers if applicable /// This should be called after workers are registered for a model pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc]) { diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs index 5b0776253cf..6885d30ccc5 100644 --- a/sgl-router/src/policies/round_robin.rs +++ b/sgl-router/src/policies/round_robin.rs @@ -4,8 +4,9 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; +use std::collections::{HashMap}; -use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use super::{get_healthy_worker_indices, LoadBalancingPolicy, DPLoadManager}; use crate::{core::Worker, 
metrics::RouterMetrics}; /// Round-robin selection policy @@ -14,12 +15,14 @@ use crate::{core::Worker, metrics::RouterMetrics}; #[derive(Debug, Default)] pub struct RoundRobinPolicy { counter: AtomicUsize, + dp_load_manager: DPLoadManager, } impl RoundRobinPolicy { pub fn new() -> Self { Self { counter: AtomicUsize::new(0), + dp_load_manager: DPLoadManager::new(), } } } @@ -57,6 +60,18 @@ impl LoadBalancingPolicy for RoundRobinPolicy { fn as_any(&self) -> &dyn std::any::Any { self } + + fn update_dp_loads(&self, loads: &HashMap>) { + return self.dp_load_manager.update_dp_loads(loads); + } + + fn get_lowest_dp_load(&self, worker: &dyn Worker) -> Option { + return self.dp_load_manager.get_lowest_dp_load(worker); + } + + fn load_increment(&self, worker: &dyn Worker, dp_rank: isize, tokens: isize) { + return self.dp_load_manager.load_increment(worker, dp_rank, tokens); + } } #[cfg(test)] diff --git a/sgl-router/src/protocols/generate.rs b/sgl-router/src/protocols/generate.rs index d5819095a37..e9d64ec7612 100644 --- a/sgl-router/src/protocols/generate.rs +++ b/sgl-router/src/protocols/generate.rs @@ -132,6 +132,9 @@ pub struct GenerateRequest { #[serde(skip_serializing_if = "Option::is_none")] pub data_parallel_rank: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub decode_dp_rank:Option, + /// Background response #[serde(default)] pub background: bool, diff --git a/sgl-router/src/protocols/worker_spec.rs b/sgl-router/src/protocols/worker_spec.rs index 753b84db56b..ba0d8b8dc82 100644 --- a/sgl-router/src/protocols/worker_spec.rs +++ b/sgl-router/src/protocols/worker_spec.rs @@ -312,4 +312,6 @@ pub struct WorkerLoadInfo { pub worker_type: Option, /// Current load (-1 indicates failure to fetch) pub load: isize, + /// Current dp rand load + pub dp_rank_loads: HashMap, } diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index eb4dc817249..58ce594f961 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ 
b/sgl-router/src/routers/http/pd_router.rs @@ -292,6 +292,15 @@ impl PDRouter { Err(e) => return Self::handle_serialization_error(e), }; + if self.policy_registry.is_dp_minimum_tokens_scheduler_enabled() { + // data_parallel_rank + json_request = match self.select_data_parallel_rank(json_request, prefill.as_ref(), decode.as_ref(), context.request_text.as_deref()) + .await + { + Ok(v) => v, + Err(e) => return Self::handle_serialization_error(e), + }; + } let response = self .execute_dual_dispatch_internal( headers, @@ -548,6 +557,45 @@ impl PDRouter { prefill_policy.needs_request_text() || decode_policy.needs_request_text() } + async fn select_data_parallel_rank(&self, mut original: Value, prefill_worker: &dyn Worker, decode_worker: &dyn Worker, request_text: Option<&str>) -> Result { + let obj = original + .as_object_mut() + .ok_or_else(|| "Request must be a JSON object".to_string())?; + + let length = match request_text { + Some(s) => s.len(), + None => 0, + }; + let prefill_policy = self.policy_registry.get_prefill_policy(); + let lowest_prefill_dp_rank = prefill_policy.get_lowest_dp_load(prefill_worker); + let decode_policy = self.policy_registry.get_decode_policy(); + let lowest_decode_dp_rank = decode_policy.get_lowest_dp_load(decode_worker); + obj.insert( + "data_parallel_rank".to_string(), + match lowest_prefill_dp_rank { + Some(v) => Value::from(v), + None => Value::Null, + }, + ); + obj.insert( + "decode_dp_rank".to_string(), + match lowest_decode_dp_rank { + Some(v) => Value::from(v), + None => Value::Null, + }, + ); + let prompt_len = length.try_into().map_err(|e| format!("Failed to convert length tp isize:{}", e))?; + debug!("select_data_parallel_rank obj:{:?}, prompt_len:{}", obj, prompt_len); + if let Some(dp_rank) = lowest_prefill_dp_rank { + prefill_policy.load_increment(prefill_worker, dp_rank, prompt_len); + } + if let Some(dp_rank) = lowest_decode_dp_rank { + decode_policy.load_increment(decode_worker, dp_rank, prompt_len); + } + + Ok(original) + 
} + async fn select_pd_pair( &self, request_text: Option<&str>, From a4869b9ff1df59e872f5a478d7c479c048fa36c5 Mon Sep 17 00:00:00 2001 From: jiashaokun <1114621279@qq.com> Date: Mon, 17 Nov 2025 10:02:03 +0800 Subject: [PATCH 2/2] router select dp group with the minimum number of tokens 02 --- sgl-router/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 94d0b65284f..647d0b82be0 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -15,4 +15,4 @@ pub mod routers; pub mod server; pub mod service_discovery; pub mod tokenizer; -pub mod tool_parser; \ No newline at end of file +pub mod tool_parser;