diff --git a/.github/actions/pytest/action.yml b/.github/actions/pytest/action.yml index cca684695b..2c89e0174f 100644 --- a/.github/actions/pytest/action.yml +++ b/.github/actions/pytest/action.yml @@ -130,4 +130,4 @@ runs: path: | test-results/pytest_test_report_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}.xml test-results/test_metadata_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}.json - retention-days: 7 \ No newline at end of file + retention-days: 7 diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py index 57711ba570..3b8f4d1988 100644 --- a/components/src/dynamo/vllm/args.py +++ b/components/src/dynamo/vllm/args.py @@ -192,6 +192,24 @@ def parse_args() -> Config: args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) + # Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor. + # With TP=1, vLLM defaults to UniProcExecutor which runs scheduler and worker in the same + # process. This causes a hot loop in _process_engine_step that doesn't release the GIL, + # blocking NIXL's add_remote_agent from completing. Using "mp" backend forces separate + # processes, avoiding the GIL contention. + # Note: Only apply for NIXL - other connectors (kvbm, lmcache) work fine with UniProcExecutor + # and forcing mp can expose race conditions in vLLM's scheduler. + # See: https://github.com/vllm-project/vllm/issues/29369 + connector_list = [c.lower() for c in args.connector] if args.connector else [] + uses_nixl = "nixl" in connector_list + tp_size = getattr(engine_args, "tensor_parallel_size", None) or 1 + if uses_nixl and tp_size == 1 and engine_args.distributed_executor_backend is None: + logger.info( + "Setting --distributed-executor-backend=mp for TP=1 to avoid " + "UniProcExecutor GIL contention with NIXL connector" + ) + engine_args.distributed_executor_backend = "mp" + if engine_args.enable_prefix_caching is None: logger.debug( "--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)" @@ -443,7 +461,7 @@ def overwrite_args(config): # skip tokenizer initialisation. Setting this to **False** avoids # a NoneType error when the processor accesses the tokenizer. 
"skip_tokenizer_init": False, - "disable_log_requests": True, + "enable_log_requests": False, "disable_log_stats": False, } diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index 0f48b4a32c..c3e4288521 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -328,7 +328,7 @@ def setup_vllm_engine(config, stat_logger=None): vllm_config=vllm_config, usage_context=usage_context, stat_loggers=factory, - disable_log_requests=engine_args.disable_log_requests, + enable_log_requests=engine_args.enable_log_requests, disable_log_stats=engine_args.disable_log_stats, ) if ENABLE_LMCACHE: diff --git a/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py b/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py index eb84c20190..1ee10d02cd 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py @@ -11,7 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.outputs import RequestOutput -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike as AnyTokenizer from dynamo.runtime import Client diff --git a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py index fe8d95dc81..3a693131d9 100644 --- a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py +++ b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py @@ -28,9 +28,22 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_engine import RequestPrompt +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.inputs.data import TokensPrompt from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike as AnyTokenizer + + +class StubEngineClient: + """ + Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion. + Provides the minimal attributes required by OpenAIServingModels. 
+ """ + + def __init__(self, model_config: ModelConfig): + self.model_config = model_config + self.input_processor = None + self.io_processor = None @runtime_checkable @@ -120,12 +133,19 @@ class ChatProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingChat( - engine_client=None, - model_config=model_config, - models=None, - request_logger=None, + engine_client=stub_engine, + models=serving_models, response_role="assistant", + request_logger=None, chat_template=None, chat_template_content_format="auto", ) @@ -186,7 +206,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): yield raw_response @@ -220,7 +239,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): break @@ -267,10 +285,17 @@ class CompletionsProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingCompletion( - engine_client=None, - model_config=model_config, - models=None, + engine_client=stub_engine, + models=serving_models, request_logger=None, ) diff --git a/components/src/dynamo/vllm/multimodal_utils/protocol.py b/components/src/dynamo/vllm/multimodal_utils/protocol.py index ef8d2bea91..c05f6cdeeb 100644 --- a/components/src/dynamo/vllm/multimodal_utils/protocol.py +++ b/components/src/dynamo/vllm/multimodal_utils/protocol.py @@ -26,7 +26,7 @@ from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401 from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import RequestMetrics +from vllm.v1.metrics.stats import RequestStateStats import dynamo.nixl_connect as connect @@ -156,7 +156,7 @@ class MyRequestOutput(BaseModel): https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85 This class is used to serialize the RequestOutput and any recursively defined types - We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses + We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -167,7 +167,7 @@ class MyRequestOutput(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None outputs: List[CompletionOutput] finished: bool - metrics: Optional[RequestMetrics] = None + metrics: Optional[RequestStateStats] = None kv_transfer_params: Optional[dict[str, Any]] = None # lora_request: Optional[LoRARequest] = None # encoder_prompt: Optional[str] = None diff --git a/container/Dockerfile.vllm b/container/Dockerfile.vllm index 
3cb388c3fc..e78f6f2fb8 100644 --- a/container/Dockerfile.vllm +++ b/container/Dockerfile.vllm @@ -10,17 +10,17 @@ ARG BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" ARG BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" ARG ENABLE_KVBM=false ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" -ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" -ARG CUDA_VERSION="12.8" +ARG RUNTIME_IMAGE_TAG="12.9.0-runtime-ubuntu24.04" +ARG CUDA_VERSION="12.9" # Make sure to update the dependency version in pyproject.toml when updating this -ARG VLLM_REF="v0.11.0" -# FlashInfer only respected when building vLLM from source, ie when VLLM_REF does not start with 'v' or for arm64 builds -ARG FLASHINF_REF="v0.3.1" -ARG TORCH_BACKEND="cu128" +ARG VLLM_REF="v0.12.0" +# FlashInfer Ref used to install flashinfer-cubin and flashinfer-jit-cache +ARG FLASHINF_REF="v0.5.3" # If left blank, then we will fallback to vLLM defaults ARG DEEPGEMM_REF="" +ARG LMCACHE_REF="0.3.10" # sccache configuration - inherit from base build ARG USE_SCCACHE @@ -109,7 +109,7 @@ ARG VLLM_REF ARG VLLM_GIT_URL ARG DEEPGEMM_REF ARG FLASHINF_REF -ARG TORCH_BACKEND +ARG LMCACHE_REF ARG CUDA_VERSION ARG MAX_JOBS=16 @@ -143,7 +143,7 @@ RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \ export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ cp /tmp/deps/vllm/install_vllm.sh /tmp/install_vllm.sh && \ chmod +x /tmp/install_vllm.sh && \ - /tmp/install_vllm.sh --editable --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} --torch-backend $TORCH_BACKEND --cuda-version $CUDA_VERSION && \ + /tmp/install_vllm.sh --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} ${LMCACHE_REF:+--lmcache-ref "$LMCACHE_REF"} --cuda-version $CUDA_VERSION && \ /tmp/use-sccache.sh show-stats "vLLM"; ENV LD_LIBRARY_PATH=\ @@ -206,7 +206,7 @@ RUN apt-get update && \ # prometheus dependencies ca-certificates \ # DeepGemm uses 'cuobjdump' which does not come with CUDA image - cuda-command-line-tools-12-8 && \ + cuda-command-line-tools-12-9 && \ rm -rf /var/lib/apt/lists/* # Copy CUDA development tools (nvcc, headers, dependencies, etc.) from base devel image @@ -287,8 +287,14 @@ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi --requirement /tmp/requirements.txt \ --requirement /tmp/requirements.test.txt -# Copy benchmarks, examples, and tests for CI with correct ownership -COPY --chown=dynamo: . /workspace/ +# Copy tests, benchmarks, examples, deploy, recipes, components, and lib for CI +COPY --chown=dynamo: benchmarks /workspace/benchmarks +COPY --chown=dynamo: tests /workspace/tests +COPY --chown=dynamo: examples /workspace/examples +COPY --chown=dynamo: deploy /workspace/deploy +COPY --chown=dynamo: recipes/ /workspace/recipes/ +COPY --chown=dynamo: components/ /workspace/components/ +COPY --chown=dynamo: lib/ /workspace/lib/ # Copy attribution files COPY --chown=dynamo: ATTRIBUTION* LICENSE /workspace/ @@ -373,6 +379,7 @@ COPY --from=dynamo_base /usr/local/cargo /usr/local/cargo # Install maturin, for maturin develop # Editable install of dynamo +COPY pyproject.toml README.md hatch_build.py /workspace/ RUN uv pip install maturin[patchelf] && \ uv pip install --no-deps -e . 
diff --git a/container/build.sh b/container/build.sh index a7d092f047..607e086230 100755 --- a/container/build.sh +++ b/container/build.sh @@ -106,7 +106,7 @@ VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" # Please check https://github.com/ai-dynamo/dynamo/pull/1065 # for details and reproducer to manually test if the image # can be updated to later versions. -VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" +VLLM_BASE_IMAGE_TAG="25.04-cuda12.9-devel-ubuntu24.04" NONE_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" NONE_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index 0ebbb58823..8365deecf6 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -2,18 +2,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# This script is used to install vLLM and its dependencies -# If installing vLLM from a release tag, we will use pip to manage the install -# Otherwise, we will use git to checkout the vLLM source code and build it from source. -# The dependencies are installed in the following order: -# 1. vLLM -# 2. LMCache +# This script installs vLLM and its dependencies from PyPI (release versions only). +# Installation order: +# 1. LMCache (installed first so vLLM's dependencies take precedence) +# 2. vLLM # 3. DeepGEMM # 4. EP kernels set -euo pipefail -VLLM_REF="v0.11.0" +VLLM_REF="v0.12.0" # Basic Configurations ARCH=$(uname -m) @@ -21,34 +19,19 @@ MAX_JOBS=16 INSTALLATION_DIR=/tmp # VLLM and Dependency Configurations -TORCH_BACKEND="cu128" TORCH_CUDA_ARCH_LIST="9.0;10.0" # For EP Kernels DEEPGEMM_REF="" -CUDA_VERSION="12.8" # For DEEPGEMM - -# These flags are applicable when installing vLLM from source code -EDITABLE=true -VLLM_GIT_URL="https://github.com/vllm-project/vllm.git" -FLASHINF_REF="v0.3.1" +CUDA_VERSION="12.9" +FLASHINF_REF="v0.5.3" +# LMCache version - 0.3.9+ required for vLLM 0.11.2 compatibility +LMCACHE_REF="0.3.10" while [[ $# -gt 0 ]]; do case $1 in - --editable) - EDITABLE=true - shift - ;; - --no-editable) - EDITABLE=false - shift - ;; --vllm-ref) VLLM_REF="$2" shift 2 ;; - --vllm-git-url) - VLLM_GIT_URL="$2" - shift 2 - ;; --max-jobs) MAX_JOBS="$2" shift 2 @@ -69,8 +52,8 @@ while [[ $# -gt 0 ]]; do FLASHINF_REF="$2" shift 2 ;; - --torch-backend) - TORCH_BACKEND="$2" + --lmcache-ref) + LMCACHE_REF="$2" shift 2 ;; --torch-cuda-arch-list) @@ -82,19 +65,17 @@ while [[ $# -gt 0 ]]; do shift 2 ;; -h|--help) - echo "Usage: $0 [--editable|--no-editable] [--vllm-ref REF] [--max-jobs NUM] [--arch ARCH] [--deepgemm-ref REF] [--flashinf-ref REF] [--torch-backend BACKEND] [--torch-cuda-arch-list LIST] [--cuda-version VERSION]" + echo "Usage: $0 [--vllm-ref REF] [--max-jobs NUM] [--arch ARCH] [--deepgemm-ref REF] [--flashinf-ref REF] [--lmcache-ref REF] [--torch-cuda-arch-list LIST] [--cuda-version VERSION]" echo "Options:" - echo " --editable Install vllm in editable mode (default)" - echo " --no-editable Install vllm in non-editable mode" - echo " --vllm-ref REF Git reference to checkout (default: ${VLLM_REF})" - echo " --max-jobs NUM Maximum number of parallel jobs (default: ${MAX_JOBS})" - echo " --arch ARCH Architecture (amd64|arm64, default: auto-detect)" - echo " --installation-dir DIR Directory to install vllm (default: ${INSTALLATION_DIR})" - echo " --deepgemm-ref REF Git reference for DeepGEMM (default: ${DEEPGEMM_REF})" - echo " --flashinf-ref REF Git 
reference for Flash Infer (default: ${FLASHINF_REF})" - echo " --torch-backend BACKEND Torch backend to use (default: ${TORCH_BACKEND})" - echo " --torch-cuda-arch-list LIST CUDA architectures to compile for (default: ${TORCH_CUDA_ARCH_LIST})" - echo " --cuda-version VERSION CUDA version to use (default: ${CUDA_VERSION})" + echo " --vllm-ref REF vLLM release version (default: ${VLLM_REF})" + echo " --max-jobs NUM Maximum parallel jobs (default: ${MAX_JOBS})" + echo " --arch ARCH Architecture amd64|arm64 (default: auto-detect)" + echo " --installation-dir DIR Install directory (default: ${INSTALLATION_DIR})" + echo " --deepgemm-ref REF DeepGEMM git ref (default: ${DEEPGEMM_REF})" + echo " --flashinf-ref REF FlashInfer version (default: ${FLASHINF_REF})" + echo " --lmcache-ref REF LMCache version (default: ${LMCACHE_REF})" + echo " --torch-cuda-arch-list LIST CUDA architectures (default: ${TORCH_CUDA_ARCH_LIST})" + echo " --cuda-version VERSION CUDA version (default: ${CUDA_VERSION})" exit 0 ;; *) @@ -114,119 +95,43 @@ fi export MAX_JOBS=$MAX_JOBS export CUDA_HOME=/usr/local/cuda +# Derive torch backend from CUDA version (e.g., "12.9" -> "cu129") +TORCH_BACKEND="cu$(echo $CUDA_VERSION | tr -d '.')" + echo "=== Installing prerequisites ===" uv pip install pip cuda-python echo "\n=== Configuration Summary ===" -echo " VLLM_REF=$VLLM_REF | EDITABLE=$EDITABLE | ARCH=$ARCH" -echo " MAX_JOBS=$MAX_JOBS | TORCH_BACKEND=$TORCH_BACKEND | CUDA_VERSION=$CUDA_VERSION" -echo " TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST" -echo " DEEPGEMM_REF=$DEEPGEMM_REF | FLASHINF_REF=$FLASHINF_REF" -echo " INSTALLATION_DIR=$INSTALLATION_DIR | VLLM_GIT_URL=$VLLM_GIT_URL" +echo " VLLM_REF=$VLLM_REF | ARCH=$ARCH | CUDA_VERSION=$CUDA_VERSION | TORCH_BACKEND=$TORCH_BACKEND" +echo " FLASHINF_REF=$FLASHINF_REF | LMCACHE_REF=$LMCACHE_REF | DEEPGEMM_REF=$DEEPGEMM_REF" +echo " TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST | INSTALLATION_DIR=$INSTALLATION_DIR" + +echo "\n=== Installing LMCache ===" +if [ "$ARCH" = "amd64" ]; then + # LMCache installation currently fails on arm64 due to CUDA dependency issues + # Install LMCache BEFORE vLLM so vLLM's dependencies take precedence + uv pip install lmcache==${LMCACHE_REF} --torch-backend=${TORCH_BACKEND} + echo "✓ LMCache ${LMCACHE_REF} installed" +else + echo "⚠ Skipping LMCache on ARM64 (compatibility issues)" +fi echo "\n=== Cloning vLLM repository ===" -# We need to clone to install dependencies +# Clone needed for DeepGEMM and EP kernels install scripts cd $INSTALLATION_DIR -git clone $VLLM_GIT_URL vllm +git clone https://github.com/vllm-project/vllm.git vllm cd vllm git checkout $VLLM_REF -# TODO leave this here in case we need to do cherry-picks in future -# GIT_COMMITTER_NAME="Container Build" GIT_COMMITTER_EMAIL="container@buildkitsandbox.local" git cherry-pick 740f064 - echo "\n=== Installing vLLM & FlashInfer ===" +echo "Installing vLLM $VLLM_REF from PyPI..." -if [[ $VLLM_REF =~ ^v ]] && { [ "$ARCH" = "amd64" ] || { [ "$ARCH" = "arm64" ] && [ "$TORCH_BACKEND" = "cu129" ]; }; }; then - # VLLM_REF starts with 'v' and either amd64, or arm64 with cu129 backend - use PyPI install - echo "Installing vLLM $VLLM_REF from PyPI... 
(ARCH=$ARCH, TORCH_BACKEND=$TORCH_BACKEND)" - - uv pip install vllm[flashinfer]==$VLLM_REF --torch-backend=$TORCH_BACKEND - -else - # VLLM_REF does not start with 'v' or amd64 - use git checkout path - if [ "$ARCH" = "arm64" ]; then - - # torch 2.8.0 doesn't have a aarch wheel for cu128, vLLM uses torch 2.8.0 nightly wheel builds to compile its aarch wheel against - # nightly can be unstable so we will not use it here - # for now we will use torch 2.7.1+cu128 but this requires a recompilation from source - - echo "Building vLLM from source for ARM64 architecture..." - - # Try to install specific PyTorch version first - echo "Attempting to install pinned PyTorch nightly versions..." - if ! uv pip install torch==2.7.1+cu128 torchaudio==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu128; then - echo "Pinned versions failed" - exit 1 - fi - - # Create constraints file to pin all PyTorch-related versions - echo "Creating constraints file to preserve PyTorch ecosystem versions..." - TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") - TORCHAUDIO_VERSION=$(python -c "import torchaudio; print(torchaudio.__version__)") - TORCHVISION_VERSION=$(python -c "import torchvision; print(torchvision.__version__)") - - rm -rf /tmp/torch_constraints.txt - echo "torch==$TORCH_VERSION" > /tmp/torch_constraints.txt - echo "torchaudio==$TORCHAUDIO_VERSION" >> /tmp/torch_constraints.txt - echo "torchvision==$TORCHVISION_VERSION" >> /tmp/torch_constraints.txt - - echo "Pinned versions:" - echo " - torch==$TORCH_VERSION" - echo " - torchaudio==$TORCHAUDIO_VERSION" - echo " - torchvision==$TORCHVISION_VERSION" - - python use_existing_torch.py - uv pip install -c /tmp/torch_constraints.txt -r requirements/build.txt - - if [ "$EDITABLE" = "true" ]; then - MAX_JOBS=${MAX_JOBS} uv pip install --no-build-isolation -c /tmp/torch_constraints.txt -e . -v - else - MAX_JOBS=${MAX_JOBS} uv pip install --no-build-isolation -c /tmp/torch_constraints.txt . -v - fi - - echo "\n=== Installing FlashInfer from source ===" - cd $INSTALLATION_DIR - git clone https://github.com/flashinfer-ai/flashinfer.git --recursive - cd flashinfer - git checkout $FLASHINF_REF - - # Install with constraints to prevent PyTorch upgrade - uv pip install -v --no-build-isolation -c /tmp/torch_constraints.txt . - - else - echo "Building vLLM from source for AMD64 architecture..." - - # When updating above VLLM_REF make sure precompiled wheel file URL is correct. Run this command: - # aws s3 ls s3://vllm-wheels/${VLLM_REF}/ --region us-west-2 --no-sign-request - export VLLM_PRECOMPILED_WHEEL_LOCATION="https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_REF}/vllm-0.10.2-cp38-abi3-manylinux1_x86_64.whl" - - if [ "$EDITABLE" = "true" ]; then - uv pip install -e . --torch-backend=$TORCH_BACKEND - else - uv pip install . --torch-backend=$TORCH_BACKEND - fi - - echo "\n=== Installing FlashInfer from PyPI ===" - uv pip install flashinfer-python==$FLASHINF_REF - - fi -fi +uv pip install vllm[flashinfer]==$VLLM_REF --torch-backend=${TORCH_BACKEND} +uv pip install flashinfer-cubin==$FLASHINF_REF +uv pip install flashinfer-jit-cache==$FLASHINF_REF --extra-index-url https://flashinfer.ai/whl/${TORCH_BACKEND} echo "✓ vLLM installation completed" -echo "\n=== Installing LMCache ===" -if [ "$ARCH" = "amd64" ]; then - # LMCache installation currently fails on arm64 due to CUDA dependency issues: - # OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root. 
- # TODO: Re-enable for arm64 after verifying lmcache compatibility and resolving the build issue. - - # Alec: Likely lmcache was compiled witha different version of torch and need to install it from source for arm64 - uv pip install lmcache==0.3.7 - echo "✓ LMCache installed" -else - echo "⚠ Skipping LMCache on ARM64 (compatibility issues)" -fi - echo "\n=== Installing DeepGEMM ===" cd $INSTALLATION_DIR/vllm/tools @@ -239,6 +144,7 @@ echo "✓ DeepGEMM installation completed" echo "\n=== Installing EP Kernels (PPLX and DeepEP) ===" cd ep_kernels/ +# TODO we will be able to specify which pplx and deepep commit we want in future TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" bash install_python_libraries.sh echo "\n✅ All installations completed successfully!" diff --git a/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py b/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py index f9c67be7bc..a5962afe17 100644 --- a/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py +++ b/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py @@ -8,7 +8,6 @@ vLLM OpenAI API server vllm serve \ --swap-space 16 \ - --disable-log-requests (TGI backend) ./launch_tgi_server.sh diff --git a/examples/backends/vllm/deploy/agg_kvbm.yaml b/examples/backends/vllm/deploy/agg_kvbm.yaml index 542ab367b9..570c0e6f61 100644 --- a/examples/backends/vllm/deploy/agg_kvbm.yaml +++ b/examples/backends/vllm/deploy/agg_kvbm.yaml @@ -42,7 +42,6 @@ spec: - Qwen/Qwen3-8B - --gpu-memory-utilization - "0.45" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm.yaml b/examples/backends/vllm/deploy/disagg_kvbm.yaml index f5f179d9bc..919f3cfd24 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm.yaml @@ -35,7 +35,6 @@ spec: - Qwen/Qwen3-8B - --gpu-memory-utilization - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -68,7 +67,6 @@ spec: - --is-prefill-worker - --gpu-memory-utilization - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml b/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml index f9aa875781..01915200fa 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml @@ -35,7 +35,6 @@ spec: - Qwen/Qwen3-8B - --gpu-memory-utilization - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -68,7 +67,6 @@ spec: - --is-prefill-worker - --gpu-memory-utilization - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml b/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml index a282825ff6..80752b1f88 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml @@ -37,7 +37,6 @@ spec: - Qwen/Qwen3-8B - --gpu-memory-utilization - "0.23" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -72,7 +71,6 @@ spec: - --is-prefill-worker - --gpu-memory-utilization - "0.23" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/launch/agg_multimodal_epd.sh b/examples/backends/vllm/launch/agg_multimodal_epd.sh index b35de20f89..b7bb3124ed 100755 --- a/examples/backends/vllm/launch/agg_multimodal_epd.sh +++ 
b/examples/backends/vllm/launch/agg_multimodal_epd.sh @@ -77,7 +77,7 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_ # run E/P/D workers CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME & -CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS & +CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS & # Wait for all background processes to complete wait diff --git a/examples/backends/vllm/launch/disagg_multimodal_epd.sh b/examples/backends/vllm/launch/disagg_multimodal_epd.sh index b392a83946..09c1187c68 100755 --- a/examples/backends/vllm/launch/disagg_multimodal_epd.sh +++ b/examples/backends/vllm/launch/disagg_multimodal_epd.sh @@ -80,23 +80,20 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_ # Configure GPU memory optimization for specific models EXTRA_ARGS="" -if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then - EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048" -fi # Start encode worker -echo "Starting encode worker on GPU 1..." -VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' & +echo "Starting encode worker on GPU 0..." +VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' & # Start prefill worker -echo "Starting prefill worker on GPU 2..." +echo "Starting prefill worker on GPU 1..." VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \ -CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & +CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & # Start decode worker -echo "Starting decode worker on GPU 3..." +echo "Starting decode worker on GPU 2..." VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \ -CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & +CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & echo "==================================================" echo "All components started. Waiting for initialization..." diff --git a/examples/multimodal/components/audio_encode_worker.py b/examples/multimodal/components/audio_encode_worker.py new file mode 100644 index 0000000000..4384ec2e9c --- /dev/null +++ b/examples/multimodal/components/audio_encode_worker.py @@ -0,0 +1,307 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import logging +import os +import signal +import sys +from typing import AsyncIterator, Tuple + +import torch +import uvloop +from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + +import dynamo.nixl_connect as connect +from dynamo.runtime import Client, DistributedRuntime, dynamo_worker +from dynamo.runtime.logging import configure_dynamo_logging + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +from utils.args import Config, base_parse_args, parse_endpoint +from utils.audio_loader import AudioLoader +from utils.protocol import MyRequestOutput, vLLMMultimodalRequest + +configure_dynamo_logging() +logger = logging.getLogger(__name__) + +try: + import cupy as array_module + + if not array_module.cuda.is_available(): + raise ImportError("CUDA is not available.") + DEVICE = "cuda" + logger.info("Using cupy for array operations (GPU mode).") +except ImportError as e: + logger.warning(f"Failed to import cupy, falling back to numpy: {e}.") + import numpy as array_module + + DEVICE = "cpu" + +CACHE_SIZE_MAXIMUM = 8 + + +class VllmEncodeWorker: + def __init__( + self, + args: argparse.Namespace, + engine_args: AsyncEngineArgs, + pd_worker_client: Client, + ) -> None: + self.pd_worker_client = pd_worker_client + self.engine_args = engine_args + self.model = self.engine_args.model + + self.audio_loader = AudioLoader(cache_size=CACHE_SIZE_MAXIMUM) + self.audio_processor = AutoProcessor.from_pretrained( + self.model, trust_remote_code=True + ) + self.audio_model = Qwen2AudioForConditionalGeneration.from_pretrained( + self.model, device_map="auto", dtype=torch.float16 + ).eval() + + def get_audio_embeddings(self, audio_features): + input_features, feature_attention_mask = ( + audio_features.input_features, + audio_features.feature_attention_mask, + ) + with torch.no_grad(): + ( + audio_feat_lengths, + audio_output_lengths, + ) = self.audio_model.audio_tower._get_feat_extract_output_lengths( + feature_attention_mask.sum(-1) + ) + batch_size, _, max_mel_seq_len = input_features.shape + max_seq_len = (max_mel_seq_len - 2) // 2 + 1 + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feat_lengths.dtype, + device=audio_feat_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feat_lengths.unsqueeze(1).expand( + batch_size, max_seq_len + ) + # Create mask + padding_mask = seq_range >= lengths_expand + + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len + ).expand(batch_size, 1, max_seq_len, max_seq_len) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.audio_model.audio_tower.conv1.weight.dtype, + device=self.audio_model.audio_tower.conv1.weight.device, + ) + 
audio_attention_mask[audio_attention_mask_] = float("-inf") + + audio_outputs = self.audio_model.audio_tower( + input_features, attention_mask=audio_attention_mask + ) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features = self.audio_model.multi_modal_projector( + selected_audio_feature + ) + + num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_features_mask = torch.arange( + max_audio_tokens, device=audio_output_lengths.device + )[None, :] + audio_features_mask = audio_features_mask < audio_output_lengths[:, None] + audio_features = audio_features[audio_features_mask] + + return audio_features + + def cleanup(self): + pass + + async def generate( + self, request: vLLMMultimodalRequest + ) -> AsyncIterator[MyRequestOutput]: + logger.debug(f"Got raw request: {request}") + if not isinstance(request, vLLMMultimodalRequest): + if isinstance(request, str): + request = vLLMMultimodalRequest.model_validate_json(request) + else: + request = vLLMMultimodalRequest.model_validate(request) + logger.debug(f"Received encode request: {{ id: {request.request_id} }}.") + + request_id = request.request_id + + # The following steps encode the requested audio and provide the resulting embeddings. + # 1. Open the audio from the provided URL. + # 2. Process the audio using the audio processor. + # 3. Run the audio through the audio model's audio tower. + # 4. Run the results of the audio tower through the multi-modal projector. + # 5. Create a descriptor for the embeddings. + # 6. Create a write operation using the serialized request and the descriptor. + # 7. Wait for the write operation to complete. + # 8. Yield the encode response. + + try: + audio, sr = await self.audio_loader.load_audio( + request.multimodal_input.audio_url + ) + + audio_features = self.audio_processor( + text="test<|AUDIO|>", audio=audio, return_tensors="pt", padding=False + ) + with torch.no_grad(): + audio_embeddings = self.get_audio_embeddings(audio_features) + descriptor = connect.Descriptor(audio_embeddings) + with await self._connector.create_readable(descriptor) as readable: + request.serialized_request = readable.metadata() + # Clear the audio URL as a hint that the audio is passed as embeddings. + request.multimodal_input.audio_url = None + request.embeddings_shape = tuple(audio_embeddings.shape) + logger.debug(f"Request: {request.model_dump_json()}") + + response_generator = await self.pd_worker_client.round_robin( + request.model_dump_json() + ) + + await readable.wait_for_completion() + + async for response in response_generator: + output = MyRequestOutput.model_validate_json(response.data()) + yield MyRequestOutput( + request_id=output.request_id, + prompt=output.prompt, + prompt_token_ids=output.prompt_token_ids, + prompt_logprobs=output.prompt_logprobs, + outputs=output.outputs, + finished=output.finished, + ).model_dump_json() + + except Exception as e: + logger.error(f"Error processing request {request_id}: {e}") + raise + + async def async_init(self, runtime: DistributedRuntime): + logger.info("Startup started.") + # Create and initialize a dynamo connector for this worker. + # We'll need this to move data between this worker and remote workers efficiently. 
+ self._connector = connect.Connector() + + logger.info("Startup completed.") + + @classmethod + def parse_args(cls) -> Tuple[argparse.Namespace, Config]: + DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") + DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate" + DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.llm.generate" + + parser = FlexibleArgumentParser( + description="vLLM based encoder for Dynamo LLM." + ) + parser.add_argument( + "--endpoint", + type=str, + default=DEFAULT_ENDPOINT, + help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'", + ) + parser.add_argument( + "--downstream-endpoint", + type=str, + default=DEFAULT_DOWNSTREAM_ENDPOINT, + help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'", + ) + + args, config = base_parse_args(parser) + + return args, config + + +async def graceful_shutdown(runtime): + """ + By calling `runtime.shutdown()`, the endpoints will immediately be unavailable. + However, in-flight requests will still be processed until they are finished. + After all in-flight requests are finished, the `serve_endpoint` functions will return + and the engine will be shutdown by Python's garbage collector. + """ + logging.info("Received shutdown signal, shutting down DistributedRuntime") + runtime.shutdown() + logging.info("DistributedRuntime shutdown complete") + + +@dynamo_worker() +async def worker(runtime: DistributedRuntime): + # Runtime setup + # Set up signal handler for graceful shutdown + loop = asyncio.get_running_loop() + + def signal_handler(): + asyncio.create_task(graceful_shutdown(runtime)) + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + logging.info("Signal handlers set up for graceful shutdown") + + # worker setup + args, config = VllmEncodeWorker.parse_args() + await init(runtime, args, config) + + +async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config): + """ + Instantiate and serve + """ + + component = runtime.namespace(config.namespace).component(config.component) + + generate_endpoint = component.endpoint(config.endpoint) + + parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( + args.downstream_endpoint + ) + pd_worker_client = ( + await runtime.namespace(parsed_namespace) + .component(parsed_component_name) + .endpoint(parsed_endpoint_name) + .client() + ) + + handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client) + await handler.async_init(runtime) + + logger.info("Waiting for PD Worker Instances ...") + await pd_worker_client.wait_for_instances() + + logger.info(f"Starting to serve the {args.endpoint} endpoint...") + + try: + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, metrics_labels=[("model", config.model)] + ), + ) + except Exception as e: + logger.error(f"Failed to serve endpoints: {e}") + raise + finally: + handler.cleanup() + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(worker()) diff --git a/examples/multimodal/components/encode_worker.py b/examples/multimodal/components/encode_worker.py index 1c0f9d4093..297a97b8b3 100644 --- a/examples/multimodal/components/encode_worker.py +++ b/examples/multimodal/components/encode_worker.py @@ -24,7 +24,7 @@ import uvloop from transformers import AutoImageProcessor from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.utils import FlexibleArgumentParser 
+from vllm.utils.argparse_utils import FlexibleArgumentParser import dynamo.nixl_connect as connect from dynamo.runtime import Client, DistributedRuntime, dynamo_worker diff --git a/examples/multimodal/components/processor.py b/examples/multimodal/components/processor.py index b972220f5c..c695862b96 100644 --- a/examples/multimodal/components/processor.py +++ b/examples/multimodal/components/processor.py @@ -29,8 +29,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.outputs import RequestOutput -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import FlexibleArgumentParser +from vllm.tokenizers import TokenizerLike as AnyTokenizer +from vllm.utils.argparse_utils import FlexibleArgumentParser from dynamo.llm import ModelInput, ModelType, register_llm from dynamo.runtime import Client, DistributedRuntime, dynamo_worker diff --git a/examples/multimodal/components/publisher.py b/examples/multimodal/components/publisher.py index c1937fd6c6..19fe18ccff 100644 --- a/examples/multimodal/components/publisher.py +++ b/examples/multimodal/components/publisher.py @@ -38,6 +38,8 @@ def record( scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], engine_idx: int = 0, + *args, + **kwargs, ): pass @@ -74,6 +76,8 @@ def record( scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats], engine_idx: int = 0, + *args, + **kwargs, ): # request_total_slots and kv_total_blocks are properties of model + gpu # we should only publish them once, not every metric update diff --git a/examples/multimodal/components/video_encode_worker.py b/examples/multimodal/components/video_encode_worker.py index 78f66c19d1..f5f71d32e6 100644 --- a/examples/multimodal/components/video_encode_worker.py +++ b/examples/multimodal/components/video_encode_worker.py @@ -28,7 +28,7 @@ import torch import uvloop from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser import dynamo.nixl_connect as connect from dynamo.runtime import Client, DistributedRuntime, dynamo_worker diff --git a/examples/multimodal/components/worker.py b/examples/multimodal/components/worker.py index ad825be5d0..eb10149758 100644 --- a/examples/multimodal/components/worker.py +++ b/examples/multimodal/components/worker.py @@ -15,7 +15,7 @@ from vllm.distributed.kv_events import ZmqEventPublisher from vllm.inputs.data import TokensPrompt from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.engine.async_llm import AsyncLLM import dynamo.nixl_connect as connect @@ -142,7 +142,7 @@ def setup_vllm_engine(self, component: Component, endpoint: Endpoint): vllm_config=vllm_config, usage_context=usage_context, stat_loggers=[self.stats_logger], - disable_log_requests=self.engine_args.disable_log_requests, + enable_log_requests=self.engine_args.enable_log_requests, disable_log_stats=self.engine_args.disable_log_stats, ) @@ -251,7 +251,6 @@ async def async_init(self, runtime: DistributedRuntime): # We'll needs this to move data between this worker and remote workers efficiently. 
parsed_namespace, _, _ = parse_endpoint(self.endpoint) self._connector = connect.Connector() - await self._connector.initialize() self.image_loader = ImageLoader() diff --git a/examples/multimodal/launch/audio_agg.sh b/examples/multimodal/launch/audio_agg.sh new file mode 100755 index 0000000000..0ea01066f0 --- /dev/null +++ b/examples/multimodal/launch/audio_agg.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +set -e +trap 'echo Cleaning up...; kill 0' EXIT + +# Default values +MODEL_NAME="Qwen/Qwen2-Audio-7B-Instruct" +PROMPT_TEMPLATE="" +PROVIDED_PROMPT_TEMPLATE="" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL_NAME=$2 + shift 2 + ;; + --prompt-template) + PROVIDED_PROMPT_TEMPLATE=$2 + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --model Specify the model to use (default: $MODEL_NAME)" + echo " --prompt-template