Skip to content

Commit cdeac19

Browse files
committed
Sgl healthcheck integration + precommit
Signed-off-by: jthomson04 <[email protected]>
1 parent 4a4501a commit cdeac19

File tree

10 files changed

+52
-17
lines changed

10 files changed

+52
-17
lines changed

components/src/dynamo/common/utils/input_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def get_input_param(self, request: dict, use_tokenizer: bool):
88
"""
99

1010
if use_tokenizer:
11+
print(f"Request: {request}")
1112
if self.tokenizer is None:
1213
raise ValueError("Tokenizer is not available")
1314

components/src/dynamo/sglang/health_check.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ class SglangHealthCheckPayload(HealthCheckPayload):
5353
Provides SGLang defaults and inherits environment override support from base class.
5454
"""
5555

56-
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
56+
def __init__(
57+
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
58+
) -> None:
5759
"""Initialize SGLang health check payload with model-specific BOS token.
5860
5961
Args:
@@ -62,7 +64,6 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
6264
bos_token_id = _get_bos_token_id_from_engine(engine)
6365

6466
self.default_payload = {
65-
"token_ids": [bos_token_id],
6667
"stop_conditions": {
6768
"max_tokens": 1, # Generate only 1 token
6869
"ignore_eos": False,
@@ -75,6 +76,12 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
7576
"eos_token_ids": [],
7677
"annotations": [],
7778
}
79+
80+
if use_text_input:
81+
self.default_payload["prompt"] = "Test"
82+
else:
83+
self.default_payload["token_ids"] = [bos_token_id]
84+
7885
super().__init__()
7986

8087

@@ -84,7 +91,9 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
8491
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
8592
"""
8693

87-
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
94+
def __init__(
95+
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
96+
) -> None:
8897
"""Initialize SGLang prefill health check payload with proper wrapped structure.
8998
9099
Args:
@@ -93,9 +102,7 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
93102
bos_token_id = _get_bos_token_id_from_engine(engine)
94103

95104
self.default_payload = {
96-
"request": {
97-
"token_ids": [bos_token_id],
98-
},
105+
"request": {},
99106
"sampling_params": {
100107
"max_new_tokens": 1, # Generate only 1 token
101108
"temperature": 0.0,
@@ -104,4 +111,10 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
104111
"ignore_eos": False,
105112
},
106113
}
114+
115+
if use_text_input:
116+
self.default_payload["request"]["prompt"] = "Test"
117+
else:
118+
self.default_payload["request"]["token_ids"] = [bos_token_id]
119+
107120
super().__init__()

components/src/dynamo/sglang/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ async def stop_profile_handler(body: dict) -> dict:
171171
handler = DecodeWorkerHandler(
172172
component, engine, config, publisher, prefill_client, prefill_router_client
173173
)
174-
175-
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
174+
print(f"Config: {config}")
175+
health_check_payload = SglangHealthCheckPayload(
176+
engine, use_text_input=dynamo_args.use_sglang_tokenizer
177+
).to_dict()
176178

177179
logging.info(
178180
f"Registering model with endpoint types: {dynamo_args.dyn_endpoint_types}"
@@ -325,7 +327,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
325327
ready_event = asyncio.Event()
326328

327329
handler = EmbeddingWorkerHandler(component, engine, config, publisher)
328-
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
330+
health_check_payload = SglangHealthCheckPayload(
331+
engine, use_text_input=dynamo_args.use_sglang_tokenizer
332+
).to_dict()
329333

330334
try:
331335
# Start endpoint immediately and register model concurrently

components/src/dynamo/sglang/request_handlers/handler_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,14 @@ def cleanup(self) -> None:
7575
pass
7676

7777
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
78-
return self.input_param_manager.get_input_param(
78+
request_input = self.input_param_manager.get_input_param(
7979
request, use_tokenizer=not self.skip_tokenizer_init
8080
)
8181

82+
return {
83+
"prompt" if isinstance(request_input, str) else "input_ids": request_input
84+
}
85+
8286
@staticmethod
8387
def _generate_bootstrap_room() -> int:
8488
"""Generate a unique bootstrap room ID for disaggregated serving.

components/src/dynamo/trtllm/health_check.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _get_bos_token_id_from_tokenizer(tokenizer) -> int:
4747
logger.debug("Using default BOS token ID (1) for health check")
4848
return 1
4949

50+
5051
def _make_default_payload(tokenizer, use_text_input: bool) -> dict:
5152
default_payload = {
5253
"stop_conditions": {
@@ -77,6 +78,7 @@ def _make_default_payload(tokenizer, use_text_input: bool) -> dict:
7778

7879
return default_payload
7980

81+
8082
class TrtllmHealthCheckPayload(HealthCheckPayload):
8183
"""
8284
TRT-LLM-specific health check payload.

components/src/dynamo/trtllm/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,9 @@ async def init(runtime: DistributedRuntime, config: Config):
435435
)
436436

437437
# Get health check payload (checks env var and falls back to TensorRT-LLM default)
438-
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer, use_text_input=config.use_trtllm_tokenizer).to_dict()
438+
health_check_payload = TrtllmHealthCheckPayload(
439+
tokenizer=tokenizer, use_text_input=config.use_trtllm_tokenizer
440+
).to_dict()
439441

440442
if config.publish_events_and_metrics:
441443
# Initialize and pass in the publisher to the request handler to

components/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ async def generate_locally(
323323
stop_token_ids = stop_conditions.get("stop_token_ids_hidden")
324324
if stop_token_ids:
325325
existing = sampling_params.stop_token_ids or []
326-
sampling_params.stop_token_ids = list(set(existing).union(stop_token_ids))
326+
sampling_params.stop_token_ids = list(
327+
set(existing).union(stop_token_ids)
328+
)
327329

328330
# TODO: Instead of True, we should use streaming from the request.
329331
# However, currently dynamo run does not send streaming in the request.

components/src/dynamo/vllm/handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import asynccontextmanager
1111
from typing import Any, AsyncGenerator, Dict, Final
1212

13-
from vllm.inputs import TokensPrompt, TextPrompt
13+
from vllm.inputs import TextPrompt, TokensPrompt
1414
from vllm.lora.request import LoRARequest
1515
from vllm.outputs import RequestOutput
1616
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

components/src/dynamo/vllm/health_check.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
"""
99

1010
import logging
11-
from typing import Optional
11+
from typing import TYPE_CHECKING, Optional
1212

1313
from dynamo.health_check import HealthCheckPayload
1414

1515
logger = logging.getLogger(__name__)
1616

17+
if TYPE_CHECKING:
18+
from vllm.v1.engine.async_llm import AsyncLLM
19+
1720

1821
def _get_bos_token_id_from_engine(engine_client) -> int:
1922
"""
@@ -45,8 +48,10 @@ def _get_bos_token_id_from_engine(engine_client) -> int:
4548
logger.debug("Using default BOS token ID (1) for health check")
4649
return 1
4750

48-
def _make_default_payload(engine_client: Optional["AsyncLLM"], use_text_input: bool) -> dict:
4951

52+
def _make_default_payload(
53+
engine_client: Optional["AsyncLLM"], use_text_input: bool
54+
) -> dict:
5055
sampling_options = {
5156
"temperature": 0.0,
5257
}
@@ -72,7 +77,7 @@ def _make_default_payload(engine_client: Optional["AsyncLLM"], use_text_input: b
7277
"sampling_options": sampling_options,
7378
"stop_conditions": stop_conditions,
7479
}
75-
80+
7681

7782
class VllmHealthCheckPayload(HealthCheckPayload):
7883
"""

components/src/dynamo/vllm/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
432432
migration_limit=0, # Prefill doesn't support migration
433433
)
434434

435-
health_check_payload = VllmPrefillHealthCheckPayload(engine_client, use_text_input=config.use_vllm_tokenizer).to_dict()
435+
health_check_payload = VllmPrefillHealthCheckPayload(
436+
engine_client, use_text_input=config.use_vllm_tokenizer
437+
).to_dict()
436438

437439
try:
438440
logger.debug("Starting serve_endpoint for prefill worker")

0 commit comments

Comments
 (0)