2 changes: 1 addition & 1 deletion .github/actions/pytest/action.yml
@@ -130,4 +130,4 @@ runs:
path: |
test-results/pytest_test_report_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}.xml
test-results/test_metadata_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}.json
retention-days: 7
retention-days: 7
20 changes: 19 additions & 1 deletion components/src/dynamo/vllm/args.py
@@ -192,6 +192,24 @@ def parse_args() -> Config:
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)

# Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor.
# With TP=1, vLLM defaults to UniProcExecutor which runs scheduler and worker in the same
# process. This causes a hot loop in _process_engine_step that doesn't release the GIL,
# blocking NIXL's add_remote_agent from completing. Using "mp" backend forces separate
# processes, avoiding the GIL contention.
# Note: Only apply for NIXL - other connectors (kvbm, lmcache) work fine with UniProcExecutor
# and forcing mp can expose race conditions in vLLM's scheduler.
# See: https://github.com/vllm-project/vllm/issues/29369
connector_list = [c.lower() for c in args.connector] if args.connector else []
uses_nixl = "nixl" in connector_list
tp_size = getattr(engine_args, "tensor_parallel_size", None) or 1
if uses_nixl and tp_size == 1 and engine_args.distributed_executor_backend is None:
logger.info(
"Setting --distributed-executor-backend=mp for TP=1 to avoid "
"UniProcExecutor GIL contention with NIXL connector"
)
engine_args.distributed_executor_backend = "mp"

if engine_args.enable_prefix_caching is None:
logger.debug(
"--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)"
@@ -443,7 +461,7 @@ def overwrite_args(config):
# skip tokenizer initialisation. Setting this to **False** avoids
# a NoneType error when the processor accesses the tokenizer.
"skip_tokenizer_init": False,
"disable_log_requests": True,
"enable_log_requests": False,
"disable_log_stats": False,
}

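For reference, the selection rule added in parse_args() above can be read in isolation as a small pure function. The sketch below is illustrative only: choose_executor_backend() is a hypothetical helper, not part of this PR or of vLLM, and it simply mirrors the condition in the diff (NIXL connector present, TP=1, no executor backend chosen explicitly).

```python
# Hypothetical standalone sketch of the workaround's decision rule; not part of
# Dynamo or vLLM, just a readable restatement of the condition added above.
from typing import Optional


def choose_executor_backend(
    connectors: Optional[list[str]],
    tensor_parallel_size: int,
    current_backend: Optional[str],
) -> Optional[str]:
    """Return "mp" when the NIXL connector is used with TP=1 and no backend
    was chosen explicitly; otherwise leave the existing setting untouched."""
    connector_list = [c.lower() for c in connectors] if connectors else []
    if "nixl" in connector_list and tensor_parallel_size == 1 and current_backend is None:
        return "mp"
    return current_backend


assert choose_executor_backend(["nixl"], 1, None) == "mp"    # workaround applies
assert choose_executor_backend(["kvbm"], 1, None) is None    # other connectors untouched
assert choose_executor_backend(["nixl"], 2, None) is None    # TP>1 left to vLLM's default
assert choose_executor_backend(["nixl"], 1, "ray") == "ray"  # explicit user choice wins
```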
2 changes: 1 addition & 1 deletion components/src/dynamo/vllm/main.py
@@ -328,7 +328,7 @@ def setup_vllm_engine(config, stat_logger=None):
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=factory,
disable_log_requests=engine_args.disable_log_requests,
enable_log_requests=engine_args.enable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
if ENABLE_LMCACHE:
@@ -11,7 +11,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike as AnyTokenizer

from dynamo.runtime import Client

45 changes: 35 additions & 10 deletions components/src/dynamo/vllm/multimodal_utils/chat_processor.py
@@ -28,9 +28,22 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike as AnyTokenizer


class StubEngineClient:
"""
Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion.
Provides the minimal attributes required by OpenAIServingModels.
"""

def __init__(self, model_config: ModelConfig):
self.model_config = model_config
self.input_processor = None
self.io_processor = None


@runtime_checkable
@@ -120,12 +133,19 @@ class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
engine_client=stub_engine,
models=serving_models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
@@ -186,7 +206,6 @@ async def stream_response(
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
@@ -220,7 +239,6 @@ async def stream_response(
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
break
@@ -267,10 +285,17 @@ class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
# Create stub engine client and models for preprocessing-only usage
stub_engine = StubEngineClient(model_config)
serving_models = OpenAIServingModels(
engine_client=stub_engine,
base_model_paths=[
BaseModelPath(name=model_config.model, model_path=model_config.model)
],
)
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
engine_client=stub_engine,
models=serving_models,
request_logger=None,
)

6 changes: 3 additions & 3 deletions components/src/dynamo/vllm/multimodal_utils/protocol.py
@@ -26,7 +26,7 @@
from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.metrics.stats import RequestStateStats

import dynamo.nixl_connect as connect

@@ -156,7 +156,7 @@ class MyRequestOutput(BaseModel):
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85

This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -167,7 +167,7 @@ class MyRequestOutput(BaseModel):
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
metrics: Optional[RequestStateStats] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
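For reference, the pydantic wrapping pattern that MyRequestOutput relies on can be shown with a stripped-down model. This is a sketch only, assuming nothing beyond pydantic v2: StatsStandIn and MiniRequestOutput are hypothetical stand-ins, not vLLM's RequestStateStats or Dynamo's MyRequestOutput.

```python
# Minimal sketch of carrying a non-pydantic engine object inside a pydantic model;
# StatsStandIn is a hypothetical stand-in for a stats type such as RequestStateStats.
from typing import Optional

from pydantic import BaseModel, ConfigDict


class StatsStandIn:
    """Plain (non-pydantic) stats object, standing in for an engine-side type."""

    def __init__(self, num_prompt_tokens: int, num_generation_tokens: int):
        self.num_prompt_tokens = num_prompt_tokens
        self.num_generation_tokens = num_generation_tokens


class MiniRequestOutput(BaseModel):
    # Without arbitrary_types_allowed, pydantic rejects field types it cannot
    # build a schema for, such as StatsStandIn above.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    request_id: str
    finished: bool
    metrics: Optional[StatsStandIn] = None


out = MiniRequestOutput(request_id="req-0", finished=True, metrics=StatsStandIn(12, 3))
print(out.metrics.num_generation_tokens)  # 3
```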
29 changes: 18 additions & 11 deletions container/Dockerfile.vllm
@@ -10,17 +10,17 @@ ARG BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
ARG BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
ARG ENABLE_KVBM=false
ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04"
ARG CUDA_VERSION="12.8"
ARG RUNTIME_IMAGE_TAG="12.9.0-runtime-ubuntu24.04"
ARG CUDA_VERSION="12.9"

# Make sure to update the dependency version in pyproject.toml when updating this
ARG VLLM_REF="v0.11.0"
# FlashInfer only respected when building vLLM from source, ie when VLLM_REF does not start with 'v' or for arm64 builds
ARG FLASHINF_REF="v0.3.1"
ARG TORCH_BACKEND="cu128"
ARG VLLM_REF="v0.12.0"
# FlashInfer Ref used to install flashinfer-cubin and flashinfer-jit-cache
ARG FLASHINF_REF="v0.5.3"

# If left blank, then we will fallback to vLLM defaults
ARG DEEPGEMM_REF=""
ARG LMCACHE_REF="0.3.10"

# sccache configuration - inherit from base build
ARG USE_SCCACHE
@@ -109,7 +109,7 @@ ARG VLLM_REF
ARG VLLM_GIT_URL
ARG DEEPGEMM_REF
ARG FLASHINF_REF
ARG TORCH_BACKEND
ARG LMCACHE_REF
ARG CUDA_VERSION

ARG MAX_JOBS=16
@@ -143,7 +143,7 @@ RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \
cp /tmp/deps/vllm/install_vllm.sh /tmp/install_vllm.sh && \
chmod +x /tmp/install_vllm.sh && \
/tmp/install_vllm.sh --editable --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} --torch-backend $TORCH_BACKEND --cuda-version $CUDA_VERSION && \
/tmp/install_vllm.sh --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} ${LMCACHE_REF:+--lmcache-ref "$LMCACHE_REF"} --cuda-version $CUDA_VERSION && \
/tmp/use-sccache.sh show-stats "vLLM";

ENV LD_LIBRARY_PATH=\
@@ -206,7 +206,7 @@ RUN apt-get update && \
# prometheus dependencies
ca-certificates \
# DeepGemm uses 'cuobjdump' which does not come with CUDA image
cuda-command-line-tools-12-8 && \
cuda-command-line-tools-12-9 && \
rm -rf /var/lib/apt/lists/*

# Copy CUDA development tools (nvcc, headers, dependencies, etc.) from base devel image
@@ -287,8 +287,14 @@ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi
--requirement /tmp/requirements.txt \
--requirement /tmp/requirements.test.txt

# Copy benchmarks, examples, and tests for CI with correct ownership
COPY --chown=dynamo: . /workspace/
# Copy tests, benchmarks, deploy and components for CI
COPY --chown=dynamo: benchmarks /workspace/benchmarks
COPY --chown=dynamo: tests /workspace/tests
COPY --chown=dynamo: examples /workspace/examples
COPY --chown=dynamo: deploy /workspace/deploy
COPY --chown=dynamo: recipes/ /workspace/recipes/
COPY --chown=dynamo: components/ /workspace/components/
COPY --chown=dynamo: lib/ /workspace/lib/

# Copy attribution files
COPY --chown=dynamo: ATTRIBUTION* LICENSE /workspace/
@@ -373,6 +379,7 @@ COPY --from=dynamo_base /usr/local/cargo /usr/local/cargo

# Install maturin, for maturin develop
# Editable install of dynamo
COPY pyproject.toml README.md hatch_build.py /workspace/
RUN uv pip install maturin[patchelf] && \
uv pip install --no-deps -e .

2 changes: 1 addition & 1 deletion container/build.sh
@@ -106,7 +106,7 @@ VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
# Please check https://github.com/ai-dynamo/dynamo/pull/1065
# for details and reproducer to manually test if the image
# can be updated to later versions.
VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
VLLM_BASE_IMAGE_TAG="25.04-cuda12.9-devel-ubuntu24.04"

NONE_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
NONE_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"