Skip to content

Commit bee9b19

Browse files
Incorporating a round of feedback
Signed-off-by: Zhongxuan Wang <[email protected]>
1 parent 956a435 commit bee9b19

File tree

13 files changed

+92
-67
lines changed

13 files changed

+92
-67
lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ def build_sampling_params(
7575

7676

7777
def _request_contains_timing_metrics(request: Dict[str, Any]) -> bool:
78-
"""Check if timing_metrics is requested in extra_fields."""
79-
extra_fields: Optional[List[str]] = request.get("extra_fields")
80-
if extra_fields is None:
78+
"""Check if timing_metrics is requested in observability_fields."""
79+
observability_fields: Optional[List[str]] = request.get("observability_fields")
80+
if observability_fields is None:
8181
return False
82-
return "timing_metrics" in extra_fields
82+
return "timing_metrics" in observability_fields
8383

8484

8585
class BaseWorkerHandler(ABC):
@@ -259,10 +259,10 @@ async def generate_tokens(
259259
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
260260
if output.finish_reason:
261261
out["finish_reason"] = output.finish_reason
262-
out[
263-
"completion_usage"
264-
] = BaseWorkerHandler._build_completion_usage(
265-
request_output=res,
262+
out["completion_usage"] = (
263+
BaseWorkerHandler._build_completion_usage(
264+
request_output=res,
265+
)
266266
)
267267
if output.stop_reason:
268268
out["stop_reason"] = output.stop_reason
@@ -309,9 +309,18 @@ async def generate(self, request, context):
309309
include_timing = _request_contains_timing_metrics(request)
310310

311311
# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
312-
# NOTE: If frontend, prefill workers, and decode workers are running on different machines,
313-
# there may be slight clock drifts between them. As a result, timing values recorded on
314-
# different machines may not be perfectly synchronized and could show minor inconsistencies.
312+
#
313+
# TIMING METRICS:
314+
# - Reliable durations: Use same-machine timestamps (e.g., decode_end - decode_start).
315+
# We use time.perf_counter() for intra-worker duration calculations to ensure monotonic,
316+
# high-resolution timing that's immune to system clock adjustments.
317+
# - Cross-machine calculations (e.g., prefill_start - request_received) assume perfect NTP
318+
# synchronization and should be used with UTMOST CAUTION due to clock drift. Even with NTP,
319+
# clocks can drift by milliseconds each day, leading to negative durations or misleading latency values.
320+
# These cross-machine metrics are useful for rough end-to-end analysis but should not be
321+
# relied upon for precise performance measurements.
322+
# - TODO: Measure actual overhead (network, queueing, etc.) - expected to be low but needs
323+
# benchmarking
315324
timing_metrics: Dict[str, float] = {}
316325
if include_timing:
317326
# Use request_received_seconds from the request (set by frontend) if available
@@ -371,6 +380,7 @@ async def generate(self, request, context):
371380
# Record decode start time
372381
if include_timing:
373382
decode_start_seconds = time.time()
383+
decode_start_perf_counter = time.perf_counter()
374384
# If this is aggregated mode (no prefill_result), prefill_start == decode_start
375385
if prefill_result is None:
376386
timing_metrics["prefill_start_seconds"] = decode_start_seconds
@@ -396,7 +406,9 @@ async def generate(self, request, context):
396406
# On finish, record decode_end_seconds and inject timing_metrics
397407
# Note: request_finish_seconds is set in the Rust HTTP layer when the response actually leaves the server
398408
if tok.get("finish_reason") is not None and include_timing:
399-
timing_metrics["decode_end_seconds"] = time.time()
409+
timing_metrics["decode_end_seconds"] = decode_start_seconds + (
410+
time.perf_counter() - decode_start_perf_counter
411+
)
400412

401413
# Inject timing_metrics into disaggregated_params
402414
if (
@@ -442,9 +454,7 @@ async def generate(self, request, context):
442454
include_timing = _request_contains_timing_metrics(request)
443455

444456
# Initialize timing metrics using request_received_seconds from frontend (passed via PreprocessedRequest)
445-
# NOTE: If frontend, prefill workers, and decode workers are running on different machines,
446-
# there may be slight clock drifts between them. As a result, timing values recorded on
447-
# different machines may not be perfectly synchronized and could show minor inconsistencies.
457+
# See DecodeWorkerHandler.generate() for timing metrics documentation
448458
timing_metrics: Dict[str, float] = {}
449459
if include_timing:
450460
# Use request_received_seconds from the request (set by frontend) if available
@@ -453,7 +463,9 @@ async def generate(self, request, context):
453463
timing_metrics["request_received_seconds"] = frontend_received
454464

455465
# Record prefill_start as when we start processing in the prefill worker
456-
timing_metrics["prefill_start_seconds"] = time.time()
466+
prefill_start_seconds = time.time()
467+
prefill_start_perf_counter = time.perf_counter()
468+
timing_metrics["prefill_start_seconds"] = prefill_start_seconds
457469

458470
# Extract and decode multimodal data if present
459471
multi_modal_data = await self._extract_multimodal_data(request)
@@ -511,12 +523,15 @@ async def generate(self, request, context):
511523
disaggregated_params: Optional[Dict[str, Any]] = {}
512524

513525
if res.kv_transfer_params:
514-
disaggregated_params[
515-
"kv_transfer_params"
516-
] = res.kv_transfer_params
526+
disaggregated_params["kv_transfer_params"] = (
527+
res.kv_transfer_params
528+
)
517529

518530
if include_timing and timing_metrics:
519-
timing_metrics["prefill_end_seconds"] = time.time()
531+
timing_metrics["prefill_end_seconds"] = (
532+
prefill_start_seconds
533+
+ (time.perf_counter() - prefill_start_perf_counter)
534+
)
520535
disaggregated_params["timing_metrics"] = timing_metrics
521536

522537
output: Dict[str, Any] = {

components/src/dynamo/vllm/multimodal_handlers/worker_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,9 @@ async def generate(self, request: vLLMMultimodalRequest, context):
227227
# Update the prompt token id in the decode request to the one
228228
# in response, which has image templated filled in. So that
229229
# the decode worker will fetch correct amount of KV blocks.
230-
decode_request.engine_prompt[
231-
"prompt_token_ids"
232-
] = prefill_response.prompt_token_ids
230+
decode_request.engine_prompt["prompt_token_ids"] = (
231+
prefill_response.prompt_token_ids
232+
)
233233
logger.debug(
234234
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
235235
)

components/src/dynamo/vllm/multimodal_utils/chat_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ async def stream_response(
178178
if request.stream:
179179
# Handle streaming response
180180
num_output_text_so_far = 0
181-
async for raw_response in self.openai_serving.chat_completion_stream_generator(
181+
async for (
182+
raw_response
183+
) in self.openai_serving.chat_completion_stream_generator(
182184
request,
183185
result_generator,
184186
request_id,
@@ -212,7 +214,9 @@ async def stream_response(
212214
# Collect all chunks into a single response
213215
full_response = None
214216
num_output_text_so_far = 0
215-
async for raw_response in self.openai_serving.chat_completion_stream_generator(
217+
async for (
218+
raw_response
219+
) in self.openai_serving.chat_completion_stream_generator(
216220
request,
217221
result_generator,
218222
request_id,

components/src/dynamo/vllm/tests/test_vllm_extra_fields.py renamed to components/src/dynamo/vllm/tests/test_vllm_observability_fields.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""Unit tests for extra_fields handling in vLLM handlers."""
4+
"""Unit tests for observability_fields handling in vLLM handlers."""
55

66
import asyncio
77
import warnings
@@ -33,18 +33,18 @@
3333
class TestShouldIncludeTimingMetrics:
3434
"""Tests for _request_contains_timing_metrics helper function."""
3535

36-
def test_returns_true_with_multiple_extra_fields(self):
36+
def test_returns_true_with_multiple_observability_fields(self):
3737
"""Timing metrics should be included when explicitly requested."""
38-
request = {"extra_fields": ["worker_id", "timing_metrics", "other_field"]}
38+
request = {"observability_fields": ["worker_id", "timing_metrics", "other_field"]}
3939
assert _request_contains_timing_metrics(request) is True
4040

41-
def test_returns_false_when_extra_fields_is_none(self):
42-
"""Timing metrics should not be included when extra_fields is None."""
43-
request = {"extra_fields": None}
41+
def test_returns_false_when_observability_fields_is_none(self):
42+
"""Timing metrics should not be included when observability_fields is None."""
43+
request = {"observability_fields": None}
4444
assert _request_contains_timing_metrics(request) is False
4545

46-
def test_returns_false_when_extra_fields_missing(self):
47-
"""Timing metrics should not be included when extra_fields key is absent."""
46+
def test_returns_false_when_observability_fields_missing(self):
47+
"""Timing metrics should not be included when observability_fields key is absent."""
4848
request: dict[str, list[str]] = {}
4949
assert _request_contains_timing_metrics(request) is False
5050

@@ -145,7 +145,7 @@ async def mock_generate(*args, **kwargs):
145145
"token_ids": [1, 2, 3],
146146
"sampling_options": {},
147147
"stop_conditions": {},
148-
"extra_fields": ["timing_metrics"],
148+
"observability_fields": ["timing_metrics"],
149149
"request_received_seconds": 1000.0,
150150
"prefill_result": {
151151
"disaggregated_params": {
@@ -192,7 +192,7 @@ async def mock_generate(*args, **kwargs):
192192
"token_ids": [1, 2, 3],
193193
"sampling_options": {},
194194
"stop_conditions": {},
195-
"extra_fields": ["timing_metrics"],
195+
"observability_fields": ["timing_metrics"],
196196
"request_received_seconds": 1000.0,
197197
}
198198

examples/backends/vllm/launch/disagg_router.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ export PYTHONHASHSEED=0
99

1010
# Common configuration
1111
MODEL="Qwen/Qwen3-0.6B"
12-
BLOCK_SIZE=64
12+
BLOCK_SIZE=16
13+
NUM_GPU_BLOCKS=20000
1314

1415
# Start frontend with KV routing
1516
# The frontend will automatically detect prefill workers and activate an internal prefill router
1617
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
1718
python -m dynamo.frontend \
1819
--router-mode kv \
20+
--enforce-disagg \
1921
--router-reset-states &
2022

2123
# two decode workers
@@ -24,13 +26,15 @@ CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
2426
--model $MODEL \
2527
--block-size $BLOCK_SIZE \
2628
--enforce-eager \
29+
--num-gpu-blocks-override $NUM_GPU_BLOCKS \
2730
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080","enable_kv_cache_events":true}'&
2831

2932
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 \
3033
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
3134
--model $MODEL \
3235
--block-size $BLOCK_SIZE \
3336
--enforce-eager \
37+
--num-gpu-blocks-override $NUM_GPU_BLOCKS \
3438
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081","enable_kv_cache_events":true}' &
3539

3640
# two prefill workers
@@ -42,6 +46,7 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \
4246
--block-size $BLOCK_SIZE \
4347
--enforce-eager \
4448
--is-prefill-worker \
49+
--num-gpu-blocks-override $NUM_GPU_BLOCKS \
4550
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082","enable_kv_cache_events":true}'&
4651

4752
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
@@ -50,4 +55,5 @@ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.vllm \
5055
--block-size $BLOCK_SIZE \
5156
--enforce-eager \
5257
--is-prefill-worker \
58+
--num-gpu-blocks-override $NUM_GPU_BLOCKS \
5359
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20083","enable_kv_cache_events":true}'

lib/llm/src/http/service/openai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";
5858

5959
/// Injects `request_completed_seconds` into the nvext timing_metrics field.
6060
/// This captures the exact moment when the response is about to leave the server.
61-
/// Only injects if timing_metrics already exists (i.e., the user requested it via extra_fields).
61+
/// Only injects if timing_metrics already exists (i.e., the user requested it via observability_fields).
6262
fn inject_request_completed_seconds(nvext: &mut Option<serde_json::Value>) {
6363
let ts = SystemTime::now()
6464
.duration_since(UNIX_EPOCH)

lib/llm/src/kv_router.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
690690
// Always inject worker_id in first item's disaggregated_params
691691
// This is needed for:
692692
// 1. PrefillRouter to know which prefill worker was chosen
693-
// 2. Client response when extra_fields contains "worker_id"
693+
// 2. Client response when observability_fields contains "worker_id"
694694
if first_item {
695695
first_item = false;
696696

lib/llm/src/preprocessor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,10 @@ impl OpenAIPreprocessor {
237237
builder.annotations(request.annotations().unwrap_or_default());
238238
builder.mdc_sum(Some(self.mdcsum.clone()));
239239
builder.estimated_prefix_hit_num_blocks(None);
240-
// Extract backend_instance_id, extra_fields, and request_received_seconds from nvext if present
240+
// Extract backend_instance_id, observability_fields, and request_received_seconds from nvext if present
241241
if let Some(nvext) = request.nvext() {
242242
builder.backend_instance_id(nvext.backend_instance_id);
243-
builder.extra_fields(nvext.extra_fields.clone());
243+
builder.observability_fields(nvext.observability_fields.clone());
244244
builder.request_received_seconds(nvext.request_received_seconds);
245245
}
246246

lib/llm/src/protocols/common/preprocessor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ pub struct PreprocessedRequest {
9797
#[serde(default, skip_serializing_if = "Option::is_none")]
9898
pub extra_args: Option<serde_json::Value>,
9999

100-
/// Extra fields requested to be included in the response's nvext
100+
/// Observability fields requested to be included in the response's nvext
101101
#[builder(default)]
102102
#[serde(default, skip_serializing_if = "Option::is_none")]
103-
pub extra_fields: Option<Vec<String>>,
103+
pub observability_fields: Option<Vec<String>>,
104104

105105
/// Timestamp when the request was received by the frontend (seconds since epoch)
106106
/// Used for timing metrics to track end-to-end latency

lib/llm/src/protocols/openai/chat_completions/delta.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl NvCreateChatCompletionRequest {
5050
.unwrap_or(false),
5151
enable_logprobs: self.inner.logprobs.unwrap_or(false)
5252
|| self.inner.top_logprobs.unwrap_or(0) > 0,
53-
extra_fields: self.nvext.as_ref().and_then(|nv| nv.extra_fields.clone()),
53+
observability_fields: self.nvext.as_ref().and_then(|nv| nv.observability_fields.clone()),
5454
runtime_config: ModelRuntimeConfig::default(),
5555
};
5656

@@ -66,7 +66,7 @@ pub struct DeltaGeneratorOptions {
6666
/// Determines whether log probabilities should be included in the response.
6767
pub enable_logprobs: bool,
6868
/// Observability fields to include in response nvext (e.g., "worker_id", "timing_metrics")
69-
pub extra_fields: Option<Vec<String>>,
69+
pub observability_fields: Option<Vec<String>>,
7070

7171
pub runtime_config: ModelRuntimeConfig,
7272
}
@@ -292,10 +292,10 @@ impl DeltaGenerator {
292292
self.options.enable_usage
293293
}
294294

295-
/// Check if an extra field is requested
296-
fn is_extra_field_requested(&self, field: &str) -> bool {
295+
/// Check if an observability field is requested
296+
fn is_observability_field_requested(&self, field: &str) -> bool {
297297
self.options
298-
.extra_fields
298+
.observability_fields
299299
.as_ref()
300300
.map(|fields| fields.iter().any(|f| f == field))
301301
.unwrap_or(false)
@@ -375,19 +375,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
375375
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
376376

377377
// Extract worker_id and timing_metrics from disaggregated_params and inject into nvext
378-
// Only include fields that were explicitly requested via extra_fields
378+
// Only include fields that were explicitly requested via observability_fields
379379
if let Some(ref disaggregated_params) = delta.disaggregated_params {
380380
let mut nvext_obj = serde_json::Map::new();
381381

382382
// Extract worker_id if present and requested
383-
if self.is_extra_field_requested("worker_id")
383+
if self.is_observability_field_requested("worker_id")
384384
&& let Some(worker_id_json) = disaggregated_params.get("worker_id")
385385
{
386386
nvext_obj.insert("worker_id".to_string(), worker_id_json.clone());
387387
}
388388

389389
// Extract timing_metrics if present and requested
390-
if self.is_extra_field_requested("timing_metrics")
390+
if self.is_observability_field_requested("timing_metrics")
391391
&& let Some(timing_metrics_json) = disaggregated_params.get("timing_metrics")
392392
{
393393
nvext_obj.insert("timing_metrics".to_string(), timing_metrics_json.clone());
@@ -483,7 +483,7 @@ mod tests {
483483
use crate::protocols::openai::DeltaGeneratorExt;
484484

485485
let options = DeltaGeneratorOptions {
486-
extra_fields: Some(vec!["worker_id".to_string(), "timing_metrics".to_string()]),
486+
observability_fields: Some(vec!["worker_id".to_string(), "timing_metrics".to_string()]),
487487
..Default::default()
488488
};
489489
let mut generator = DeltaGenerator::new(

0 commit comments

Comments
 (0)