Skip to content

Commit f77fa70

Browse files
biswapandazxue2
authored and committed
test: add ci tests for lora (agg and router) (ai-dynamo#4817)
1 parent d914a08 commit f77fa70

File tree

4 files changed

+560
-2
lines changed

4 files changed

+560
-2
lines changed

tests/serve/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytest_httpserver import HTTPServer
99

1010
from dynamo.common.utils.paths import WORKSPACE_DIR
11+
from tests.serve.lora_utils import MinioLoraConfig, MinioService
1112

1213
# Shared constants for multimodal testing
1314
IMAGE_SERVER_PORT = 8765
@@ -50,3 +51,47 @@ def test_multimodal(image_server):
5051
)
5152

5253
return httpserver
54+
55+
56+
@pytest.fixture(scope="function")
def minio_lora_service():
    """
    Provide a MinIO service with a pre-uploaded LoRA adapter for testing.

    This fixture:
    1. Starts a MinIO Docker container
    2. Creates the required S3 bucket
    3. Downloads the LoRA adapter from Hugging Face Hub
    4. Uploads it to MinIO
    5. Removes the local download (MinIO's data directory is left in place)
    6. Yields the MinioLoraConfig with connection details
    7. Stops MinIO and removes the remaining temporary directories

    Usage:
        def test_lora(minio_lora_service):
            config = minio_lora_service
            # Use config.get_env_vars() for environment setup
            # Use config.get_s3_uri() to get the S3 URI for loading LoRA
    """
    # Function-scope import: nothing else in this fixture's module is assumed
    # to need it.
    import shutil

    config = MinioLoraConfig()
    service = MinioService(config)

    try:
        service.start()
        service.create_bucket()

        local_path = service.download_lora()
        service.upload_lora(local_path)

        # Remove only the local download here. The MinIO data directory is
        # mounted into the running container and holds the uploaded adapter,
        # so it must survive until the container is stopped; the full cleanup
        # therefore waits for the `finally` block below.
        shutil.rmtree(local_path, ignore_errors=True)

        yield config

    finally:
        # Stop MinIO first, then remove every temporary directory.
        service.stop()
        service.cleanup_temp()

tests/serve/lora_utils.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
import logging
6+
import os
7+
import shutil
8+
import subprocess
9+
import tempfile
10+
import time
11+
from dataclasses import dataclass
12+
from typing import Optional
13+
14+
import requests
15+
16+
logger = logging.getLogger(__name__)

# Defaults shared by the LoRA tests: where MinIO is reached, the credentials
# used against it, the bucket adapters are stored in, and the adapter itself.
MINIO_ENDPOINT = "http://localhost:9000"
MINIO_ACCESS_KEY = "minioadmin"
MINIO_SECRET_KEY = "minioadmin"
MINIO_BUCKET = "my-loras"
DEFAULT_LORA_REPO = "codelion/Qwen3-0.6B-accuracy-recovery-lora"
DEFAULT_LORA_NAME = "codelion/Qwen3-0.6B-accuracy-recovery-lora"


@dataclass
class MinioLoraConfig:
    """Settings describing a MinIO endpoint and the LoRA adapter stored in it.

    Every field defaults to the matching module-level constant; ``data_dir``
    defaults to ``None``.
    """

    endpoint: str = MINIO_ENDPOINT
    access_key: str = MINIO_ACCESS_KEY
    secret_key: str = MINIO_SECRET_KEY
    bucket: str = MINIO_BUCKET
    lora_repo: str = DEFAULT_LORA_REPO
    lora_name: str = DEFAULT_LORA_NAME
    data_dir: Optional[str] = None

    def get_s3_uri(self) -> str:
        """Return the adapter's location as ``s3://<bucket>/<lora_name>``."""
        return f"s3://{self.bucket}/{self.lora_name}"

    def get_env_vars(self) -> dict:
        """Return the environment variables for AWS/MinIO access and LoRA.

        The mapping is freshly built on every call, so callers may mutate it.
        """
        storage_env = {
            "AWS_ENDPOINT": self.endpoint,
            "AWS_ACCESS_KEY_ID": self.access_key,
            "AWS_SECRET_ACCESS_KEY": self.secret_key,
            "AWS_REGION": "us-east-1",
            "AWS_ALLOW_HTTP": "true",
        }
        lora_env = {
            "DYN_LORA_ENABLED": "true",
            "DYN_LORA_PATH": "/tmp/dynamo_loras_minio_test",
        }
        return {**storage_env, **lora_env}
54+
55+
56+
class MinioService:
    """Manages MinIO Docker container lifecycle for tests.

    The service shells out to the ``docker``, ``aws`` and ``huggingface-cli``
    executables. It owns two temporary locations: the directory a LoRA is
    downloaded into (``_temp_dir``) and the directory mounted into the
    container as ``/data`` (``config.data_dir``).
    """

    CONTAINER_NAME = "dynamo-minio-test"

    def __init__(self, config: "MinioLoraConfig"):
        """Remember ``config``; nothing is started or created yet."""
        self.config = config
        self._logger = logging.getLogger(self.__class__.__name__)
        self._temp_dir: Optional[str] = None
        # True from a successful `docker run` until the next stop(). While it
        # is set, config.data_dir is the live storage of the container and
        # cleanup_temp() must leave it alone.
        self._container_running = False

    def start(self) -> None:
        """Start the MinIO container and block until it reports healthy.

        Any container left over under ``CONTAINER_NAME`` is removed first.

        Raises:
            RuntimeError: if ``docker run`` fails or MinIO does not become
                ready in time.
        """
        self._logger.info("Starting MinIO container...")

        if self.config.data_dir:
            data_dir = self.config.data_dir
            # Create a caller-supplied directory ourselves instead of leaving
            # a missing bind-mount source for the Docker daemon to create.
            os.makedirs(data_dir, exist_ok=True)
        else:
            data_dir = tempfile.mkdtemp(prefix="minio_test_")
            self.config.data_dir = data_dir

        # Stop existing container if running
        self.stop()

        cmd = [
            "docker",
            "run",
            "-d",
            "--name",
            self.CONTAINER_NAME,
            "-p",
            "9000:9000",
            "-p",
            "9001:9001",
            "-v",
            f"{data_dir}:/data",
            "quay.io/minio/minio",
            "server",
            "/data",
            "--console-address",
            ":9001",
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to start MinIO: {result.stderr}")
        self._container_running = True

        self._wait_for_ready()
        self._logger.info("MinIO started successfully")

    def _wait_for_ready(self, timeout: int = 30) -> None:
        """Poll the MinIO liveness URL once per second until it returns 200.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Raises:
            RuntimeError: if no 200 response arrives within ``timeout``.
        """
        health_url = f"{self.config.endpoint}/minio/health/live"
        # Monotonic clock: the deadline must not move if the wall clock is
        # adjusted while we wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                response = requests.get(health_url, timeout=2)
                if response.status_code == 200:
                    return
            except requests.RequestException:
                # Server not reachable yet; retry after the sleep below.
                pass
            time.sleep(1)

        raise RuntimeError(f"MinIO did not become ready within {timeout}s")

    def stop(self) -> None:
        """Stop and remove the MinIO container (best effort).

        The exit status of both docker commands is ignored on purpose, so
        calling this when no container exists is harmless.
        """
        self._logger.info("Stopping MinIO container...")

        subprocess.run(
            ["docker", "stop", self.CONTAINER_NAME],
            capture_output=True,
        )
        subprocess.run(
            ["docker", "rm", self.CONTAINER_NAME],
            capture_output=True,
        )
        self._container_running = False

    def _aws_env(self) -> dict:
        """Return a copy of the process environment plus the MinIO credentials."""
        env = os.environ.copy()
        env.update(
            {
                "AWS_ACCESS_KEY_ID": self.config.access_key,
                "AWS_SECRET_ACCESS_KEY": self.config.secret_key,
            }
        )
        return env

    def _run_aws(self, *s3_args: str) -> subprocess.CompletedProcess:
        """Run ``aws --endpoint-url <endpoint> s3 <s3_args...>`` and return the result."""
        return subprocess.run(
            ["aws", "--endpoint-url", self.config.endpoint, "s3", *s3_args],
            capture_output=True,
            text=True,
            env=self._aws_env(),
        )

    def create_bucket(self) -> None:
        """Create the S3 bucket using AWS CLI unless listing it already succeeds.

        Raises:
            RuntimeError: if the bucket has to be created and ``aws s3 mb`` fails.
        """
        bucket_uri = f"s3://{self.config.bucket}"

        # A successful listing means the bucket is already there.
        if self._run_aws("ls", bucket_uri).returncode == 0:
            return

        self._logger.info("Creating bucket: %s", self.config.bucket)
        result = self._run_aws("mb", bucket_uri)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to create bucket: {result.stderr}")

    def download_lora(self) -> str:
        """Download LoRA from Hugging Face Hub, returns temp directory path.

        Raises:
            RuntimeError: if ``huggingface-cli download`` fails.
        """
        self._temp_dir = tempfile.mkdtemp(prefix="lora_download_")
        self._logger.info(
            "Downloading LoRA %s to %s", self.config.lora_repo, self._temp_dir
        )

        result = subprocess.run(
            [
                "huggingface-cli",
                "download",
                self.config.lora_repo,
                "--local-dir",
                self._temp_dir,
                "--local-dir-use-symlinks",
                "False",
            ],
            capture_output=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"Failed to download LoRA: {result.stderr}")

        # Drop the ".cache" directory, if the download left one, so that it
        # is not uploaded together with the adapter files.
        cache_dir = os.path.join(self._temp_dir, ".cache")
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        return self._temp_dir

    def upload_lora(self, local_path: str) -> None:
        """Upload LoRA to MinIO.

        Args:
            local_path: Directory whose contents are synced to
                ``s3://<bucket>/<lora_name>``; paths matching ``*.git*`` are
                excluded.

        Raises:
            RuntimeError: if ``aws s3 sync`` fails.
        """
        target_uri = f"s3://{self.config.bucket}/{self.config.lora_name}"
        self._logger.info("Uploading LoRA to %s", target_uri)

        result = self._run_aws("sync", local_path, target_uri, "--exclude", "*.git*")
        if result.returncode != 0:
            raise RuntimeError(f"Failed to upload LoRA: {result.stderr}")

    def cleanup_temp(self) -> None:
        """Clean up temporary directories.

        The download directory is always removed. The MinIO data directory
        is removed only when this service has no started container: while
        MinIO is running that directory is its storage, and deleting it
        would discard the uploaded objects. Call ``stop()`` first to remove
        everything.
        """
        if self._temp_dir and os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir)
        self._temp_dir = None

        if self._container_running:
            return

        if self.config.data_dir and os.path.exists(self.config.data_dir):
            shutil.rmtree(self.config.data_dir, ignore_errors=True)
257+
258+
259+
def load_lora_adapter(
    system_port: int, lora_name: str, s3_uri: str, timeout: int = 60
) -> None:
    """Ask the system API on localhost to load a LoRA adapter.

    Sends ``POST http://localhost:<system_port>/v1/loras`` with the adapter
    name and its source URI.

    Args:
        system_port: Port of the system API on localhost.
        lora_name: Name sent as ``lora_name`` in the request body.
        s3_uri: URI sent as ``source.uri`` in the request body.
        timeout: Request timeout in seconds.

    Raises:
        RuntimeError: if the response status is anything other than 200.
    """
    logger.info(f"Loading LoRA adapter: {lora_name} from {s3_uri}")

    response = requests.post(
        f"http://localhost:{system_port}/v1/loras",
        json={"lora_name": lora_name, "source": {"uri": s3_uri}},
        timeout=timeout,
    )

    if response.status_code == 200:
        logger.info(f"LoRA adapter loaded successfully: {response.json()}")
        return

    raise RuntimeError(
        f"Failed to load LoRA adapter: {response.status_code} - {response.text}"
    )

0 commit comments

Comments
 (0)