@@ -11,7 +11,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike as AnyTokenizer

from dynamo.runtime import Client

45 changes: 35 additions & 10 deletions components/src/dynamo/vllm/multimodal_utils/chat_processor.py
@@ -28,9 +28,22 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike as AnyTokenizer


class StubEngineClient:
"""
Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion.
Provides the minimal attributes required by OpenAIServingModels.
"""

def __init__(self, model_config: ModelConfig):
self.model_config = model_config
self.input_processor = None
self.io_processor = None


@runtime_checkable
@@ -120,12 +133,19 @@ class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
engine_client=stub_engine,
models=serving_models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
@@ -186,7 +206,6 @@ async def stream_response(
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
@@ -220,7 +239,6 @@ async def stream_response(
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
break
@@ -267,10 +285,17 @@ class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
engine_client=stub_engine,
models=serving_models,
request_logger=None,
)
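
For reference, the pattern these changes introduce is collected below in one self-contained sketch: a stub engine client plus `OpenAIServingModels` lets vLLM's `OpenAIServingChat` be constructed purely for request preprocessing (tokenization and chat templating), with no live engine behind it; the completions path mirrors it. The helper name `build_preprocessing_chat_serving` is hypothetical, and the constructor signatures assume the vLLM version targeted by this PR, so treat this as illustrative rather than canonical.

```python
# Sketch only: mirrors the diff above so the pieces can be read in one place.
# Assumes the vLLM version targeted by this PR; OpenAIServingChat and
# OpenAIServingModels signatures change between releases.
from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels


class StubEngineClient:
    """Minimal stand-in for an EngineClient when only preprocessing is needed."""

    def __init__(self, model_config: ModelConfig):
        self.model_config = model_config
        self.input_processor = None
        self.io_processor = None


def build_preprocessing_chat_serving(model_config: ModelConfig) -> OpenAIServingChat:
    """Hypothetical helper: build OpenAIServingChat with no live engine behind it."""
    stub_engine = StubEngineClient(model_config)
    serving_models = OpenAIServingModels(
        engine_client=stub_engine,
        base_model_paths=[
            BaseModelPath(name=model_config.model, model_path=model_config.model)
        ],
    )
    return OpenAIServingChat(
        engine_client=stub_engine,
        model_config=model_config,
        models=serving_models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
```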

6 changes: 3 additions & 3 deletions components/src/dynamo/vllm/multimodal_utils/protocol.py
@@ -26,7 +26,7 @@
from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.metrics.stats import RequestStateStats

import dynamo.nixl_connect as connect

@@ -156,7 +156,7 @@ class MyRequestOutput(BaseModel):
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85

This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -167,7 +167,7 @@ class MyRequestOutput(BaseModel):
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
metrics: Optional[RequestStateStats] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
133 changes: 132 additions & 1 deletion docs/backends/vllm/multimodal.md
@@ -213,7 +213,7 @@

#### Workflow

In this workflow, we have [MultimodalPDWorkerHandler](../../../components/src/dynamo/vllm/multimodal_handlers/worker_handler.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](README.md) example.
In this workflow, we have [VllmPDWorker](components/worker.py), which encodes the image and then prefills and decodes the prompt, just like the [LLM aggregated serving](README.md) example.

This figure illustrates the workflow:
```mermaid
@@ -512,3 +512,134 @@
"usage": null
}
```
## Multimodal Aggregated Audio Serving

This example demonstrates deploying an aggregated multimodal model that can process audio inputs.

### Components

- workers: For audio serving, we use the [AudioEncodeWorker](components/audio_encode_worker.py) to decode audio into audio embeddings, which are then sent to the [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the AudioEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.

### Workflow

In this workflow, we have two workers, [AudioEncodeWorker](components/audio_encode_worker.py) and [VllmPDWorker](components/worker.py).
The AudioEncodeWorker is responsible for decoding the audio into embeddings.
The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](README.md) example.
Separating the audio processing from the prefill and decode stages makes the deployment more flexible and allows the
AudioEncodeWorker to be scaled independently of the prefill and decode workers if needed.

This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --audio_url--> audio_encode_worker
audio_encode_worker --> processor
audio_encode_worker --embeddings--> pd_worker
pd_worker --> audio_encode_worker
```

```bash
pip install "vllm[audio]" accelerate  # multimodal audio model dependencies
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_agg.sh
```

### Client

In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-Audio-7B-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is recited in the audio?"
},
{
"type": "audio_url",
"audio_url": {
"url": "https://raw.githubusercontent.com/yuekaizhang/Triton-ASR-Client/main/datasets/mini_en/wav/1221-135766-0002.wav"
}
}
]
}
],
"max_tokens": 6000,
"temperature": 0.8,
"stream": false
}' | jq
```

You should see a response describing the audio's content, similar to the following:
```json
{
"id": "e2d8d67c37634b309400974eaa058ce8",
"choices": [
{
"index": 0,
"message": {
"content": "The original content of this audio is:'yet these thoughts affected Hester Pynne less with hope than apprehension.'",
"refusal": null,
"tool_calls": null,
"role": "assistant",
"function_call": null,
"audio": null
},
"finish_reason": "stop",
"logprobs": null
}
],
"created": 1756368148,
"model": "Qwen/Qwen2-Audio-7B-Instruct",
"service_tier": null,
"system_fingerprint": null,
"object": "chat.completion",
"usage": null
}
```
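
The same request can also be sent from Python. Below is a short sketch using the OpenAI Python client against the same frontend on `localhost:8000`; the `openai` package, the placeholder API key, and streaming the reply with `stream=True` are additions for illustration and are not part of the launch scripts.

```python
# Sketch: send the audio chat request with the OpenAI Python client and
# stream the reply. Assumes `pip install openai` and the frontend from
# launch/audio_agg.sh listening on localhost:8000; the API key is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

stream = client.chat.completions.create(
    model="Qwen/Qwen2-Audio-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is recited in the audio?"},
                {
                    "type": "audio_url",
                    "audio_url": {
                        "url": "https://raw.githubusercontent.com/yuekaizhang/Triton-ASR-Client/main/datasets/mini_en/wav/1221-135766-0002.wav"
                    },
                },
            ],
        }
    ],
    max_tokens=6000,
    temperature=0.8,
    stream=True,
)

# Print tokens as they arrive.
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```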

## Multimodal Disaggregated Audio Serving

This example demonstrates deploying a disaggregated multimodal model that can process audio inputs.

### Components

- workers: For disaggregated audio serving, we have three workers, [AudioEncodeWorker](components/audio_encode_worker.py) for decoding audio into embeddings,
[VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the AudioEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.

### Workflow

In this workflow, we have three workers, [AudioEncodeWorker](components/audio_encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the Qwen/Qwen2-Audio-7B-Instruct model, audio embeddings are only required during the prefill stage. As such, the AudioEncodeWorker is connected directly to the prefill worker.
The AudioEncodeWorker is responsible for decoding the audio into embeddings and passing them to the prefill worker via RDMA.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](README.md) example.

This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --audio_url--> audio_encode_worker
audio_encode_worker --> processor
audio_encode_worker --embeddings--> prefill_worker
prefill_worker --> audio_encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```

```bash
pip install "vllm[audio]" accelerate  # multimodal audio model dependencies
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_disagg.sh
```
2 changes: 1 addition & 1 deletion examples/backends/vllm/launch/agg_multimodal_epd.sh
@@ -77,7 +77,7 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_

# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS &
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS &

# Wait for all background processes to complete
wait
15 changes: 6 additions & 9 deletions examples/backends/vllm/launch/disagg_multimodal_epd.sh
@@ -80,23 +80,20 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_

# Configure GPU memory optimization for specific models
EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi

# Start encode worker
echo "Starting encode worker on GPU 1..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' &
echo "Starting encode worker on GPU 0..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' &

# Start prefill worker
echo "Starting prefill worker on GPU 2..."
echo "Starting prefill worker on GPU 1..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &

# Start decode worker
echo "Starting decode worker on GPU 3..."
echo "Starting decode worker on GPU 2..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &

echo "=================================================="
echo "All components started. Waiting for initialization..."