diff --git a/tests/serve/conftest.py b/tests/serve/conftest.py
index 2456865a09..29cc71f649 100644
--- a/tests/serve/conftest.py
+++ b/tests/serve/conftest.py
@@ -8,6 +8,7 @@
 from pytest_httpserver import HTTPServer
 
 from dynamo.common.utils.paths import WORKSPACE_DIR
+from tests.serve.lora_utils import MinioLoraConfig, MinioService
 
 # Shared constants for multimodal testing
 IMAGE_SERVER_PORT = 8765
@@ -50,3 +51,47 @@ def test_multimodal(image_server):
     )
 
     return httpserver
+
+
+@pytest.fixture(scope="function")
+def minio_lora_service():
+    """
+    Provide a MinIO service with a pre-uploaded LoRA adapter for testing.
+
+    This fixture:
+    1. Starts a MinIO Docker container
+    2. Creates the required S3 bucket
+    3. Downloads the LoRA adapter from Hugging Face Hub
+    4. Uploads it to MinIO
+    5. Yields the MinioLoraConfig with connection details
+    6. Cleans up after the test
+
+    Usage:
+        def test_lora(minio_lora_service):
+            config = minio_lora_service
+            # Use config.get_env_vars() for environment setup
+            # Use config.get_s3_uri() to get the S3 URI for loading LoRA
+    """
+    config = MinioLoraConfig()
+    service = MinioService(config)
+
+    try:
+        # Start MinIO
+        service.start()
+
+        # Create bucket
+        service.create_bucket()
+
+        # Download and upload LoRA
+        local_path = service.download_lora()
+        service.upload_lora(local_path)
+
+        # Clean up downloaded files (keep MinIO running)
+        service.cleanup_temp()
+
+        yield config
+
+    finally:
+        # Stop MinIO and clean up
+        service.stop()
+        service.cleanup_temp()
diff --git a/tests/serve/lora_utils.py b/tests/serve/lora_utils.py
new file mode 100644
index 0000000000..cc8eaa1f59
--- /dev/null
+++ b/tests/serve/lora_utils.py
@@ -0,0 +1,274 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+import os
+import shutil
+import subprocess
+import tempfile
+import time
+from dataclasses import dataclass
+from typing import Optional
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+# LoRA testing constants
+MINIO_ENDPOINT = "http://localhost:9000"
+MINIO_ACCESS_KEY = "minioadmin"
+MINIO_SECRET_KEY = "minioadmin"
+MINIO_BUCKET = "my-loras"
+DEFAULT_LORA_REPO = "codelion/Qwen3-0.6B-accuracy-recovery-lora"
+DEFAULT_LORA_NAME = "codelion/Qwen3-0.6B-accuracy-recovery-lora"
+
+
+@dataclass
+class MinioLoraConfig:
+    """Configuration for MinIO and LoRA setup"""
+
+    endpoint: str = MINIO_ENDPOINT
+    access_key: str = MINIO_ACCESS_KEY
+    secret_key: str = MINIO_SECRET_KEY
+    bucket: str = MINIO_BUCKET
+    lora_repo: str = DEFAULT_LORA_REPO
+    lora_name: str = DEFAULT_LORA_NAME
+    data_dir: Optional[str] = None
+
+    def get_s3_uri(self) -> str:
+        """Get the S3 URI for the LoRA adapter"""
+        return f"s3://{self.bucket}/{self.lora_name}"
+
+    def get_env_vars(self) -> dict:
+        """Get environment variables for AWS/MinIO access"""
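+        # AWS_ALLOW_HTTP lets S3 clients talk to the plain-HTTP MinIO endpoint;
+        # DYN_LORA_PATH is the local directory used for LoRA storage during the test.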
+        return {
+            "AWS_ENDPOINT": self.endpoint,
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_REGION": "us-east-1",
+            "AWS_ALLOW_HTTP": "true",
+            "DYN_LORA_ENABLED": "true",
+            "DYN_LORA_PATH": "/tmp/dynamo_loras_minio_test",
+        }
+
+
+class MinioService:
+    """Manages MinIO Docker container lifecycle for tests"""
+
+    CONTAINER_NAME = "dynamo-minio-test"
+
+    def __init__(self, config: MinioLoraConfig):
+        self.config = config
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._temp_dir: Optional[str] = None
+
+    def start(self) -> None:
+        """Start MinIO container"""
+        self._logger.info("Starting MinIO container...")
+
+        # Create data directory
+        if self.config.data_dir:
+            data_dir = self.config.data_dir
+        else:
+            data_dir = tempfile.mkdtemp(prefix="minio_test_")
+            self.config.data_dir = data_dir
+
+        # Stop existing container if running
+        self.stop()
+
+        # Start MinIO container
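+        # Port 9000 serves the S3 API (it must match MINIO_ENDPOINT);
+        # port 9001 serves the MinIO web console.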
+        cmd = [
+            "docker",
+            "run",
+            "-d",
+            "--name",
+            self.CONTAINER_NAME,
+            "-p",
+            "9000:9000",
+            "-p",
+            "9001:9001",
+            "-v",
+            f"{data_dir}:/data",
+            "quay.io/minio/minio",
+            "server",
+            "/data",
+            "--console-address",
+            ":9001",
+        ]
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to start MinIO: {result.stderr}")
+
+        # Wait for MinIO to be ready
+        self._wait_for_ready()
+        self._logger.info("MinIO started successfully")
+
+    def _wait_for_ready(self, timeout: int = 30) -> None:
+        """Wait for MinIO to be ready"""
+        health_url = f"{self.config.endpoint}/minio/health/live"
+        start_time = time.time()
+
+        while time.time() - start_time < timeout:
+            try:
+                response = requests.get(health_url, timeout=2)
+                if response.status_code == 200:
+                    return
+            except requests.RequestException:
+                pass
+            time.sleep(1)
+
+        raise RuntimeError(f"MinIO did not become ready within {timeout}s")
+
+    def stop(self) -> None:
+        """Stop and remove MinIO container"""
+        self._logger.info("Stopping MinIO container...")
+
+        # Stop container
+        subprocess.run(
+            ["docker", "stop", self.CONTAINER_NAME],
+            capture_output=True,
+        )
+
+        # Remove container
+        subprocess.run(
+            ["docker", "rm", self.CONTAINER_NAME],
+            capture_output=True,
+        )
+
+    def create_bucket(self) -> None:
+        """Create the S3 bucket using AWS CLI"""
+        env = os.environ.copy()
+        env.update(
+            {
+                "AWS_ACCESS_KEY_ID": self.config.access_key,
+                "AWS_SECRET_ACCESS_KEY": self.config.secret_key,
+            }
+        )
+
+        # Check if bucket exists
+        result = subprocess.run(
+            [
+                "aws",
+                "--endpoint-url",
+                self.config.endpoint,
+                "s3",
+                "ls",
+                f"s3://{self.config.bucket}",
+            ],
+            capture_output=True,
+            text=True,
+            env=env,
+        )
+
+        if result.returncode != 0:
+            # Create bucket
+            self._logger.info(f"Creating bucket: {self.config.bucket}")
+            result = subprocess.run(
+                [
+                    "aws",
+                    "--endpoint-url",
+                    self.config.endpoint,
+                    "s3",
+                    "mb",
+                    f"s3://{self.config.bucket}",
+                ],
+                capture_output=True,
+                text=True,
+                env=env,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Failed to create bucket: {result.stderr}")
+
+    def download_lora(self) -> str:
+        """Download LoRA from Hugging Face Hub, returns temp directory path"""
+        self._temp_dir = tempfile.mkdtemp(prefix="lora_download_")
+        self._logger.info(
+            f"Downloading LoRA {self.config.lora_repo} to {self._temp_dir}"
+        )
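+        # --local-dir-use-symlinks False materializes real files instead of
+        # symlinks into the HF cache, so the later `aws s3 sync` uploads the
+        # actual adapter weights.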
+        result = subprocess.run(
+            [
+                "huggingface-cli",
+                "download",
+                self.config.lora_repo,
+                "--local-dir",
+                self._temp_dir,
+                "--local-dir-use-symlinks",
+                "False",
+            ],
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to download LoRA: {result.stderr}")
+
+        # Clean up cache directory
+        cache_dir = os.path.join(self._temp_dir, ".cache")
+        if os.path.exists(cache_dir):
+            shutil.rmtree(cache_dir)
+
+        return self._temp_dir
+
+    def upload_lora(self, local_path: str) -> None:
+        """Upload LoRA to MinIO"""
+        self._logger.info(
+            f"Uploading LoRA to s3://{self.config.bucket}/{self.config.lora_name}"
+        )
+
+        env = os.environ.copy()
+        env.update(
+            {
+                "AWS_ACCESS_KEY_ID": self.config.access_key,
+                "AWS_SECRET_ACCESS_KEY": self.config.secret_key,
+            }
+        )
+
+        result = subprocess.run(
+            [
+                "aws",
+                "--endpoint-url",
+                self.config.endpoint,
+                "s3",
+                "sync",
+                local_path,
+                f"s3://{self.config.bucket}/{self.config.lora_name}",
+                "--exclude",
+                "*.git*",
+            ],
+            capture_output=True,
+            text=True,
+            env=env,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to upload LoRA: {result.stderr}")
+
+    def cleanup_temp(self) -> None:
+        """Clean up temporary directories"""
+        if self._temp_dir and os.path.exists(self._temp_dir):
+            shutil.rmtree(self._temp_dir)
+            self._temp_dir = None
+
+        if self.config.data_dir and os.path.exists(self.config.data_dir):
+            shutil.rmtree(self.config.data_dir, ignore_errors=True)
+
+
+def load_lora_adapter(
+    system_port: int, lora_name: str, s3_uri: str, timeout: int = 60
+) -> None:
+    """Load a LoRA adapter via the system API"""
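+    # POST /v1/loras on the worker's system port registers the adapter from the
+    # given S3 URI; callers poll /v1/models afterwards to confirm availability
+    # (see LoraTestChatPayload._ensure_lora_loaded).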
+    url = f"http://localhost:{system_port}/v1/loras"
+    payload = {"lora_name": lora_name, "source": {"uri": s3_uri}}
+
+    logger.info(f"Loading LoRA adapter: {lora_name} from {s3_uri}")
+
+    response = requests.post(url, json=payload, timeout=timeout)
+    if response.status_code != 200:
+        raise RuntimeError(
+            f"Failed to load LoRA adapter: {response.status_code} - {response.text}"
+        )
+
+    logger.info(f"LoRA adapter loaded successfully: {response.json()}")
diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py
index 55873d07d0..5e62fc365b 100644
--- a/tests/serve/test_vllm.py
+++ b/tests/serve/test_vllm.py
@@ -6,6 +6,7 @@
 import os
 import random
 from dataclasses import dataclass, field
+from typing import Optional
 
 import pytest
 
@@ -15,6 +16,7 @@
     run_serve_deployment,
 )
 from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL
+from tests.serve.lora_utils import MinioLoraConfig, load_lora_adapter
 from tests.utils.engine_process import EngineConfig
 from tests.utils.payload_builder import (
     chat_payload,
@@ -22,7 +24,7 @@
     completion_payload_default,
     metric_payload_default,
 )
-from tests.utils.payloads import ToolCallingChatPayload
+from tests.utils.payloads import ChatPayload, ToolCallingChatPayload
 
 logger = logging.getLogger(__name__)
 
@@ -581,3 +583,240 @@ def test_multimodal_b64(request, runtime_services, predownload_models):
     )
 
     run_serve_deployment(config, request)
+
+
+# LoRA Test Directory
+lora_dir = os.path.join(vllm_dir, "launch/lora")
+
+
+class LoraTestChatPayload(ChatPayload):
+    """
+    Chat payload that loads a LoRA adapter before sending inference requests.
+
+    This payload first loads the specified LoRA adapter via the system API,
+    then sends chat completion requests using the LoRA model.
+    """
+
+    def __init__(
+        self,
+        body: dict,
+        lora_name: str,
+        s3_uri: str,
+        system_port: int = 8081,
+        repeat_count: int = 1,
+        expected_response: Optional[list] = None,
+        expected_log: Optional[list] = None,
+        timeout: int = 60,
+    ):
+        super().__init__(
+            body=body,
+            repeat_count=repeat_count,
+            expected_response=expected_response or [],
+            expected_log=expected_log or [],
+            timeout=timeout,
+        )
+        self.system_port = system_port
+        self.lora_name = lora_name
+        self.s3_uri = s3_uri
+        self._lora_loaded = False
+
+    def _ensure_lora_loaded(self) -> None:
+        """Ensure the LoRA adapter is loaded before making inference requests"""
+        if not self._lora_loaded:
+            import time
+
+            import requests
+
+            load_lora_adapter(
+                system_port=self.system_port,
+                lora_name=self.lora_name,
+                s3_uri=self.s3_uri,
+                timeout=self.timeout,
+            )
+
+            # Wait for the LoRA model to appear in /v1/models
+            models_url = f"http://{self.host}:{self.port}/v1/models"
+            start_time = time.time()
+            max_wait = 60  # 1 minute timeout
+
+            logger.info(
+                f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..."
+            )
+
+            while time.time() - start_time < max_wait:
+                try:
+                    response = requests.get(models_url, timeout=5)
+                    if response.status_code == 200:
+                        data = response.json()
+                        models = data.get("data", [])
+                        model_ids = [m.get("id", "") for m in models]
+
+                        if self.lora_name in model_ids:
+                            logger.info(
+                                f"LoRA model '{self.lora_name}' is now available"
+                            )
+                            self._lora_loaded = True
+                            return
+
+                        logger.debug(
+                            f"Available models: {model_ids}, waiting for '{self.lora_name}'..."
+                        )
+                except requests.RequestException as e:
+                    logger.debug(f"Error checking /v1/models: {e}")
+
+                time.sleep(1)
+
+            raise RuntimeError(
+                f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {max_wait}s"
+            )
+
+    def url(self) -> str:
+        """Load LoRA before first request, then return URL"""
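+        # Overriding url() guarantees the adapter is loaded before the first
+        # request goes out; _lora_loaded makes the load a one-time operation.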
+        self._ensure_lora_loaded()
+        return super().url()
+
+
+def lora_chat_payload(
+    lora_name: str,
+    s3_uri: str,
+    system_port: int = 8081,
+    repeat_count: int = 2,
+    expected_response: Optional[list] = None,
+    expected_log: Optional[list] = None,
+    max_tokens: int = 100,
+    temperature: float = 0.0,
+) -> LoraTestChatPayload:
+    """Create a LoRA-enabled chat payload for testing"""
+    return LoraTestChatPayload(
+        body={
+            "model": lora_name,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "What is deep learning? Answer in one sentence.",
+                }
+            ],
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "stream": False,
+        },
+        lora_name=lora_name,
+        s3_uri=s3_uri,
+        system_port=system_port,
+        repeat_count=repeat_count,
+        expected_response=expected_response
+        or ["learning", "neural", "network", "AI", "model"],
+        expected_log=expected_log or [],
+    )
+
+
+@pytest.mark.vllm
+@pytest.mark.e2e
+@pytest.mark.gpu_1
+@pytest.mark.model("Qwen/Qwen3-0.6B")
+@pytest.mark.timeout(600)
+@pytest.mark.nightly
+def test_lora_aggregated(
+    request, runtime_services, predownload_models, minio_lora_service
+):
+    """
+    Test LoRA inference with aggregated vLLM deployment.
+
+    This test:
+    1. Uses MinIO fixture to provide S3-compatible storage with uploaded LoRA
+    2. Starts vLLM with LoRA support enabled
+    3. Loads the LoRA adapter via system API
+    4. Runs inference with the LoRA model
+    """
+    minio_config: MinioLoraConfig = minio_lora_service
+
+    # Create payload that loads LoRA and tests inference
+    lora_payload = lora_chat_payload(
+        lora_name=minio_config.lora_name,
+        s3_uri=minio_config.get_s3_uri(),
+        system_port=8081,
+        repeat_count=2,
+    )
+
+    # Create test config with MinIO environment variables
+    config = VLLMConfig(
+        name="test_lora_aggregated",
+        directory=vllm_dir,
+        script_name="lora/agg_lora.sh",
+        marks=[],  # markers at function-level
+        model="Qwen/Qwen3-0.6B",
+        timeout=600,
+        env=minio_config.get_env_vars(),
+        request_payloads=[lora_payload],
+    )
+
+    run_serve_deployment(config, request, extra_env=minio_config.get_env_vars())
+
+
+@pytest.mark.vllm
+@pytest.mark.e2e
+@pytest.mark.gpu_2
+@pytest.mark.model("Qwen/Qwen3-0.6B")
+@pytest.mark.timeout(600)
+@pytest.mark.nightly
+def test_lora_aggregated_router(
+    request, runtime_services, predownload_models, minio_lora_service
+):
+    """
+    Test LoRA inference with aggregated vLLM deployment using KV router.
+
+    This test:
+    1. Uses MinIO fixture to provide S3-compatible storage with uploaded LoRA
+    2. Starts multiple vLLM workers with LoRA support and KV router
+    3. Loads the LoRA adapter on both workers via system API
+    4. Runs inference with the LoRA model, verifying KV cache routing
+    """
+    minio_config: MinioLoraConfig = minio_lora_service
+
+    # Create payloads that load LoRA on both workers and test inference
+    # Worker 1 (port 8081)
+    lora_payload_worker1 = lora_chat_payload(
+        lora_name=minio_config.lora_name,
+        s3_uri=minio_config.get_s3_uri(),
+        system_port=8081,
+        repeat_count=1,
+    )
+
+    # Worker 2 (port 8082)
+    lora_payload_worker2 = lora_chat_payload(
+        lora_name=minio_config.lora_name,
+        s3_uri=minio_config.get_s3_uri(),
+        system_port=8082,
+        repeat_count=1,
+    )
+
+    # Additional inference payload to test routing (LoRA already loaded)
+    inference_payload = chat_payload(
+        content="Explain machine learning in simple terms.",
+        repeat_count=2,
+        expected_response=["learn", "data", "algorithm", "model", "pattern"],
+        max_tokens=150,
+        temperature=0.0,
+    ).with_model(minio_config.lora_name)
+
+    # Add env vars including PYTHONHASHSEED for deterministic KV event IDs
+    env_vars = minio_config.get_env_vars()
+    env_vars["PYTHONHASHSEED"] = "0"
+
+    # Create test config with MinIO environment variables
+    config = VLLMConfig(
+        name="test_lora_aggregated_router",
+        directory=vllm_dir,
+        script_name="lora/agg_lora_router.sh",
+        marks=[],  # markers at function-level
+        model="Qwen/Qwen3-0.6B",
+        timeout=600,
+        env=env_vars,
+        request_payloads=[
+            lora_payload_worker1,
+            lora_payload_worker2,
+            inference_payload,
+        ],
+    )
+
+    run_serve_deployment(config, request, extra_env=env_vars)
diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py
index 917ad36c0b..668a0690d8 100644
--- a/tests/utils/payloads.py
+++ b/tests/utils/payloads.py
@@ -20,7 +20,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional
 
-from dynamo import prometheus_names
+from dynamo import prometheus_names  # type: ignore[attr-defined]
 
 logger = logging.getLogger(__name__)