Skip to content

Commit 9fa8125

Browse files
authored
chore: trtllm use unified frontend (#4097)
Signed-off-by: PeaBrane <[email protected]>
1 parent 427ca9a commit 9fa8125

File tree

28 files changed

+158
-646
lines changed

28 files changed

+158
-646
lines changed

benchmarks/router/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ This directory contains scripts for benchmarking the Dynamo router with prefix c
1717
- `matplotlib` for plotting results
1818
- `data-generator` package (install with `pip install -e ./benchmarks` from repo root)
1919

20+
> [!NOTE]
21+
> If running outside a container, set `DYNAMO_HOME` to the root path of your Dynamo repository:
22+
> ```bash
23+
> export DYNAMO_HOME=/path/to/dynamo
24+
> ```
25+
> When running in a container, this defaults to `/workspace`.
26+
2027
### Setting up etcd and NATS
2128
2229
This benchmark requires etcd and NATS. To quickly set them up, run:

benchmarks/router/run_engines.sh

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ else
225225

226226
if [ "$USE_TRTLLM" = true ]; then
227227
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
228-
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
228+
# Run TensorRT-LLM engine
229229
TRTLLM_ARGS=()
230230
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
231231
TRTLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
@@ -234,7 +234,7 @@ else
234234
fi
235235
TRTLLM_ARGS+=("${EXTRA_ARGS[@]}")
236236

237-
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
237+
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python3 -m dynamo.trtllm \
238238
"${TRTLLM_ARGS[@]}"
239239
else
240240
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
@@ -252,12 +252,18 @@ else
252252
fi
253253
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
254254

255-
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \
255+
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python3 -m dynamo.vllm \
256256
"${VLLM_ARGS[@]}"
257257
fi
258258
} &
259259
PIDS+=($!)
260260
echo "Started $MODE worker $i (PID: $!)"
261+
262+
# Add delay between TensorRT-LLM worker launches to avoid MPI initialization conflicts
263+
if [ "$USE_TRTLLM" = true ] && [ "$i" -lt "$NUM_WORKERS" ]; then
264+
echo "Waiting 2 seconds before launching next TensorRT-LLM worker..."
265+
sleep 2
266+
fi
261267
done
262268
fi
263269

components/src/dynamo/planner/defaults.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,14 @@ class SGLangComponentName:
119119

120120

121121
class TrtllmComponentName:
122-
# Note: Planner only supports DECODE_FIRST strategy in TRT-LLM:
123-
# - Decode worker is the first worker (tensorrt_llm)
124-
# - Prefill worker is the next worker (tensorrt_llm_next)
122+
# Unified frontend architecture (consistent with vLLM/SGLang):
123+
# - Prefill workers use "prefill" component
124+
# - Decode workers use "tensorrt_llm" component
125125
prefill_worker_k8s_name = "TRTLLMPrefillWorker"
126-
prefill_worker_component_name = (
127-
"tensorrt_llm_next" # Prefill is "next" with DECODE_FIRST
128-
)
126+
prefill_worker_component_name = "prefill"
129127
prefill_worker_endpoint = "generate"
130128
decode_worker_k8s_name = "TRTLLMDecodeWorker"
131-
decode_worker_component_name = "tensorrt_llm" # Decode is "first" with DECODE_FIRST
129+
decode_worker_component_name = "tensorrt_llm"
132130
decode_worker_endpoint = "generate"
133131

134132

components/src/dynamo/trtllm/main.py

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
4646
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
4747
from dynamo.trtllm.publisher import get_publisher
48+
from dynamo.trtllm.request_handlers.handler_base import DisaggregationMode
4849
from dynamo.trtllm.request_handlers.handlers import (
4950
RequestHandlerConfig,
5051
RequestHandlerFactory,
@@ -53,7 +54,6 @@
5354
Config,
5455
cmd_line_args,
5556
deep_update,
56-
is_first_worker,
5757
parse_endpoint,
5858
)
5959

@@ -126,37 +126,6 @@ async def init(runtime: DistributedRuntime, config: Config):
126126
"""
127127
logging.info(f"Initializing the worker with config: {config}")
128128

129-
next_client = None
130-
if config.next_endpoint:
131-
logging.info(
132-
f"Initializing next worker client for endpoint: {config.next_endpoint}"
133-
)
134-
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
135-
config.next_endpoint
136-
)
137-
next_client = (
138-
await runtime.namespace(parsed_namespace)
139-
.component(parsed_component_name)
140-
.endpoint(parsed_endpoint_name)
141-
.client()
142-
)
143-
144-
# Set up prefill router client for decode workers
145-
next_router_client = None
146-
if config.disaggregation_mode.value == "decode":
147-
try:
148-
logging.info("Initializing prefill router client")
149-
next_router_client = (
150-
await runtime.namespace(config.namespace)
151-
.component("router") # Standalone router for prefill workers
152-
.endpoint("generate")
153-
.client()
154-
)
155-
logging.info("Prefill router client initialized successfully")
156-
except Exception as e:
157-
logging.warning(f"Failed to initialize prefill router client: {e}")
158-
logging.info("Will use direct prefill worker client only")
159-
160129
encode_client = None
161130
if config.encode_endpoint:
162131
logging.info(
@@ -273,7 +242,13 @@ async def init(runtime: DistributedRuntime, config: Config):
273242
default_sampling_params._setup(tokenizer)
274243
default_sampling_params.stop = None
275244
model_input = ModelInput.Tokens
276-
model_type = ModelType.Chat | ModelType.Completions
245+
246+
# Set model type based on disaggregation mode for unified frontend support
247+
if config.disaggregation_mode == DisaggregationMode.PREFILL:
248+
model_type = ModelType.Prefill
249+
else:
250+
model_type = ModelType.Chat | ModelType.Completions
251+
277252
multimodal_processor = None
278253

279254
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
@@ -376,24 +351,17 @@ async def init(runtime: DistributedRuntime, config: Config):
376351
default_sampling_params=default_sampling_params,
377352
publisher=None,
378353
disaggregation_mode=config.disaggregation_mode,
379-
disaggregation_strategy=config.disaggregation_strategy,
380-
next_client=next_client,
381-
next_router_client=next_router_client,
382354
encode_client=encode_client,
383355
multimodal_processor=multimodal_processor,
384356
connector=connector,
385357
runtime=runtime, # Pass runtime for graceful shutdown
386358
metrics_collector=metrics_collector,
387359
)
388360

389-
if next_client:
390-
logging.info(
391-
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
392-
)
393-
await next_client.wait_for_instances()
394-
395-
if is_first_worker(config):
396-
# Register the model with runtime config
361+
# Register the model with runtime config
362+
# Encode workers do NOT register - they're internal workers only
363+
# Prefill and decode workers register - frontend detects their role via ModelType
364+
if config.disaggregation_mode != DisaggregationMode.ENCODE:
397365
await register_llm(
398366
model_input,
399367
model_type,

components/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,6 @@ class DisaggregationMode(Enum):
5252
ENCODE = "encode"
5353

5454

55-
class DisaggregationStrategy(Enum):
56-
PREFILL_FIRST = "prefill_first"
57-
DECODE_FIRST = "decode_first"
58-
59-
6055
@dataclass
6156
class RequestHandlerConfig:
6257
"""
@@ -68,9 +63,6 @@ class RequestHandlerConfig:
6863
default_sampling_params: SamplingParams
6964
publisher: Publisher
7065
disaggregation_mode: DisaggregationMode
71-
disaggregation_strategy: DisaggregationStrategy
72-
next_client: object
73-
next_router_client: Optional[object] = None
7466
encode_client: Optional[object] = None
7567
multimodal_processor: Optional[
7668
MultimodalRequestProcessor
@@ -94,9 +86,6 @@ def __init__(self, config: RequestHandlerConfig):
9486
self.publisher = config.publisher
9587
self.metrics_collector = config.metrics_collector
9688
self.disaggregation_mode = config.disaggregation_mode
97-
self.disaggregation_strategy = config.disaggregation_strategy
98-
self.next_client = config.next_client
99-
self.next_router_client = config.next_router_client
10089
self.encode_client = config.encode_client
10190
self.multimodal_processor = config.multimodal_processor
10291
self.first_generation = True

0 commit comments

Comments
 (0)