Skip to content

Commit 4f6394d

Browse files
authored
Merge branch 'main' into bis/dep-681-add-agg-lora-tests
2 parents 0651568 + 1e5b20b commit 4f6394d

File tree

107 files changed

+9003
-1251
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

107 files changed

+9003
-1251
lines changed

components/src/dynamo/planner/utils/planner_core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
PrefillInterpolator,
2525
)
2626
from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper
27-
from dynamo.planner.utils.prometheus import PrometheusAPIClient
27+
from dynamo.planner.utils.prometheus import MetricSource, PrometheusAPIClient
2828
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
2929
from dynamo.runtime import DistributedRuntime
3030
from dynamo.runtime.logging import configure_dynamo_logging
@@ -150,9 +150,20 @@ def __init__(
150150
else:
151151
raise ValueError(f"Invalid environment: {args.environment}")
152152

153+
# Use backend metrics for vLLM (queries vllm:* metrics directly from workers)
154+
# Use frontend metrics for other backends (queries dynamo_frontend_* metrics)
155+
metric_source = (
156+
MetricSource.VLLM
157+
if args.backend.lower() == "vllm"
158+
else MetricSource.FRONTEND
159+
)
160+
logger.info(
161+
f"Initializing Prometheus client with metric_source='{metric_source}' for backend '{args.backend}'"
162+
)
153163
self.prometheus_api_client = PrometheusAPIClient(
154164
args.metric_pulling_prometheus_endpoint,
155165
args.namespace,
166+
metric_source=metric_source,
156167
)
157168

158169
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor](

components/src/dynamo/planner/utils/prometheus.py

Lines changed: 221 additions & 46 deletions
Large diffs are not rendered by default.

components/src/dynamo/sglang/main.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,8 @@ async def init(runtime: DistributedRuntime, config: Config):
103103
server_args, dynamo_args = config.server_args, config.dynamo_args
104104

105105
# Prevent SGLang from blocking on non-leader nodes
106-
# We can switch this to 0 and leverage our own metrics
107-
# after https://github.com/sgl-project/sglang/pull/13686
108-
# is merged in
109106
if server_args.node_rank >= 1:
110-
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "1"
107+
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
111108

112109
engine = sgl.Engine(server_args=server_args)
113110

@@ -222,11 +219,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
222219
server_args, dynamo_args = config.server_args, config.dynamo_args
223220

224221
# Prevent SGLang from blocking on non-leader nodes
225-
# We can switch this to 0 and leverage our own metrics
226-
# after https://github.com/sgl-project/sglang/pull/13686
227-
# is merged in
228222
if server_args.node_rank >= 1:
229-
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "1"
223+
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
230224

231225
engine = sgl.Engine(server_args=server_args)
232226

@@ -430,16 +424,24 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
430424

431425
await pd_worker_client.wait_for_instances()
432426

433-
tasks = [
434-
generate_endpoint.serve_endpoint(
435-
handler.generate,
436-
graceful_shutdown=True,
437-
metrics_labels=[("model", server_args.served_model_name)],
438-
)
439-
]
427+
ready_event = asyncio.Event()
440428

441429
try:
442-
await asyncio.gather(*tasks)
430+
await asyncio.gather(
431+
generate_endpoint.serve_endpoint(
432+
handler.generate,
433+
graceful_shutdown=True,
434+
metrics_labels=[("model", server_args.served_model_name)],
435+
),
436+
register_llm_with_readiness_gate(
437+
None, # encode worker doesn't have engine
438+
generate_endpoint,
439+
server_args,
440+
dynamo_args,
441+
input_type=ModelInput.Text,
442+
readiness_gate=ready_event,
443+
),
444+
)
443445
except Exception as e:
444446
logging.error(f"Failed to serve endpoints: {e}")
445447
raise
@@ -473,11 +475,24 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
473475

474476
await handler.async_init()
475477

478+
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
479+
ready_event = asyncio.Event()
480+
476481
try:
477-
await generate_endpoint.serve_endpoint(
478-
handler.generate,
479-
metrics_labels=[("model", server_args.served_model_name)],
480-
graceful_shutdown=True,
482+
await asyncio.gather(
483+
generate_endpoint.serve_endpoint(
484+
handler.generate,
485+
metrics_labels=[("model", server_args.served_model_name)],
486+
graceful_shutdown=True,
487+
health_check_payload=health_check_payload,
488+
),
489+
register_llm_with_readiness_gate(
490+
engine,
491+
generate_endpoint,
492+
server_args,
493+
dynamo_args,
494+
readiness_gate=ready_event,
495+
),
481496
)
482497
except Exception as e:
483498
logging.error(f"Failed to serve endpoints: {e}")
@@ -502,6 +517,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
502517
await handler.async_init()
503518

504519
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
520+
ready_event = asyncio.Event()
505521

506522
try:
507523
await asyncio.gather(
@@ -510,7 +526,14 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
510526
graceful_shutdown=True,
511527
metrics_labels=[("model", server_args.served_model_name)],
512528
health_check_payload=health_check_payload,
513-
)
529+
),
530+
register_llm_with_readiness_gate(
531+
engine,
532+
generate_endpoint,
533+
server_args,
534+
dynamo_args,
535+
readiness_gate=ready_event,
536+
),
514537
)
515538
except Exception as e:
516539
logging.error(f"Failed to serve endpoints: {e}")

components/src/dynamo/sglang/request_handlers/handler_base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import asyncio
5+
import base64
6+
import json
57
import logging
68
import random
79
import socket
@@ -10,6 +12,7 @@
1012
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
1113

1214
import sglang as sgl
15+
from sglang.srt.tracing import trace as sglang_trace
1316
from sglang.srt.utils import get_local_ip_auto
1417

1518
from dynamo._core import Client, Component, Context
@@ -49,6 +52,7 @@ def __init__(
4952
self.prefill_client = prefill_client
5053
self.serving_mode = config.serving_mode
5154
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
55+
self.enable_trace = config.server_args.enable_trace
5256

5357
@abstractmethod
5458
async def generate(self, request: Dict[str, Any], context: Context):
@@ -117,6 +121,39 @@ def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]:
117121

118122
return bootstrap_host, bootstrap_port
119123

124+
def _propagate_trace_context_to_sglang(
125+
self, context: Context, bootstrap_room: int = 0
126+
):
127+
"""Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain
128+
format derived by looking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py
129+
in the to_dict() method.
130+
131+
Args:
132+
context: Dynamo Context object containing trace information.
133+
bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated).
134+
"""
135+
trace_id = context.trace_id
136+
span_id = context.span_id
137+
if not trace_id or not span_id:
138+
return
139+
140+
# Build trace context for SGLang
141+
trace_context = {
142+
str(bootstrap_room): {
143+
"root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"},
144+
"prev_span": {
145+
"span_id": int(span_id, 16),
146+
"trace_id": int(trace_id, 16),
147+
},
148+
}
149+
}
150+
151+
# Encode and propagate
152+
base64_context = base64.b64encode(
153+
json.dumps(trace_context, ensure_ascii=False).encode("utf-8")
154+
).decode("utf-8")
155+
sglang_trace.trace_set_remote_propagate_context(base64_context)
156+
120157
async def _handle_cancellation(
121158
self, request_id_future: asyncio.Future, context: Context
122159
):

components/src/dynamo/sglang/request_handlers/llm/decode_handler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ async def generate(
112112
RuntimeError: If no bootstrap info received from prefill worker.
113113
"""
114114
logging.debug(f"New Request ID: {context.id()}")
115+
trace_id = context.trace_id
115116
sampling_params = self._build_sampling_params(request)
116117
input_param = self._get_input_param(request)
117118

@@ -154,13 +155,19 @@ async def generate(
154155
if not bootstrap_info:
155156
raise RuntimeError("No bootstrap info received from prefill worker")
156157

158+
if self.enable_trace:
159+
self._propagate_trace_context_to_sglang(
160+
context, bootstrap_info["bootstrap_room"]
161+
)
162+
157163
decode = await self.engine.async_generate(
158164
**input_param,
159165
sampling_params=sampling_params,
160166
stream=True,
161167
bootstrap_host=bootstrap_info["bootstrap_host"],
162168
bootstrap_port=bootstrap_info["bootstrap_port"],
163169
bootstrap_room=bootstrap_info["bootstrap_room"],
170+
rid=trace_id,
164171
)
165172

166173
if self.skip_tokenizer_init:
@@ -170,10 +177,14 @@ async def generate(
170177
async for out in self._process_text_stream(decode, context):
171178
yield out
172179
else:
180+
if self.enable_trace:
181+
self._propagate_trace_context_to_sglang(context)
182+
173183
agg = await self.engine.async_generate(
174184
**input_param,
175185
sampling_params=sampling_params,
176186
stream=True,
187+
rid=trace_id,
177188
)
178189
if self.skip_tokenizer_init:
179190
async for out in self._process_token_stream(agg, context):

components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def generate(
6464
Bootstrap info dict with host, port, and room for decode worker connection.
6565
"""
6666
logging.debug(f"New Request ID: {context.id()}")
67+
trace_id = context.trace_id
6768
bootstrap_room = self._generate_bootstrap_room()
6869

6970
bootstrap_info = {
@@ -76,13 +77,18 @@ async def generate(
7677

7778
input_param = self._get_input_param(request["request"])
7879

80+
# Propagate trace context to SGLang
81+
if self.enable_trace:
82+
self._propagate_trace_context_to_sglang(context, bootstrap_room)
83+
7984
results = await self.engine.async_generate(
8085
**input_param,
8186
sampling_params=request["sampling_params"],
8287
stream=True,
8388
bootstrap_host=self.bootstrap_host,
8489
bootstrap_port=self.bootstrap_port,
8590
bootstrap_room=bootstrap_room,
91+
rid=trace_id,
8692
)
8793

8894
task = asyncio.create_task(self._consume_results(results, context))

components/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,12 @@ async def generate_locally(
369369

370370
# 2. Per-request errors - send to client, don't shutdown
371371
except RequestError as e:
372-
logging.warning(f"Request {request_id} error: {e}")
373-
yield {"finish_reason": "error", "token_ids": []}
372+
error_msg = str(e)
373+
logging.warning(f"Request {request_id} error: {error_msg}")
374+
yield {
375+
"finish_reason": {"error": error_msg},
376+
"token_ids": [],
377+
}
374378

375379
# 3. ALL OTHER ERRORS - graceful shutdown
376380
except Exception as e:
@@ -384,7 +388,7 @@ async def generate_locally(
384388
# Try to send error to client before shutdown
385389
try:
386390
yield {
387-
"finish_reason": "error",
391+
"finish_reason": {"error": error_msg},
388392
"token_ids": [],
389393
}
390394
except Exception:

components/src/dynamo/vllm/handlers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm.inputs import TokensPrompt
1313
from vllm.lora.request import LoRARequest
1414
from vllm.outputs import RequestOutput
15-
from vllm.sampling_params import SamplingParams
15+
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
1616
from vllm.v1.engine.exceptions import EngineDeadError
1717

1818
from dynamo.llm import (
@@ -82,8 +82,22 @@ def build_sampling_params(
8282
sampling_params = SamplingParams(**default_sampling_params)
8383
sampling_params.detokenize = False
8484

85-
# Apply sampling_options
85+
# Handle guided_decoding - convert to StructuredOutputsParams
86+
guided_decoding = request["sampling_options"].get("guided_decoding")
87+
if guided_decoding is not None and isinstance(guided_decoding, dict):
88+
sampling_params.structured_outputs = StructuredOutputsParams(
89+
json=guided_decoding.get("json"),
90+
regex=guided_decoding.get("regex"),
91+
choice=guided_decoding.get("choice"),
92+
grammar=guided_decoding.get("grammar"),
93+
whitespace_pattern=guided_decoding.get("whitespace_pattern"),
94+
)
95+
96+
# Apply remaining sampling_options
8697
for key, value in request["sampling_options"].items():
98+
# Skip guided_decoding - already handled above
99+
if key == "guided_decoding":
100+
continue
87101
if value is not None and hasattr(sampling_params, key):
88102
setattr(sampling_params, key, value)
89103

components/src/dynamo/vllm/health_check.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,14 @@ def __init__(self, engine_client=None):
6767
self.default_payload = {
6868
"token_ids": [bos_token_id],
6969
"sampling_options": {
70-
"max_tokens": 1,
7170
"temperature": 0.0,
7271
},
7372
"stop_conditions": {
73+
"max_tokens": 1,
7474
"stop": None,
7575
"stop_token_ids": None,
7676
"include_stop_str_in_output": False,
7777
"ignore_eos": False,
78-
"min_tokens": 0,
7978
},
8079
}
8180
super().__init__()

0 commit comments

Comments
 (0)