|
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 |
|
5 | | -import logging |
6 | 5 | import os |
7 | | -import shutil |
8 | | -import subprocess |
9 | | -import tempfile |
10 | | -import time |
11 | | -from dataclasses import dataclass |
12 | | -from typing import Optional |
13 | 6 |
|
14 | 7 | import pytest |
15 | | -import requests |
16 | 8 | from pytest_httpserver import HTTPServer |
17 | 9 |
|
18 | 10 | from dynamo.common.utils.paths import WORKSPACE_DIR |
19 | | - |
20 | | -logger = logging.getLogger(__name__) |
| 11 | +from tests.serve.lora_utils import MinioLoraConfig, MinioService |
21 | 12 |
|
22 | 13 | # Shared constants for multimodal testing |
23 | 14 | IMAGE_SERVER_PORT = 8765 |
|
26 | 17 | ) |
27 | 18 | MULTIMODAL_IMG_URL = f"http://localhost:{IMAGE_SERVER_PORT}/llm-graphic.png" |
28 | 19 |
|
29 | | -# LoRA testing constants |
30 | | -MINIO_ENDPOINT = "http://localhost:9000" |
31 | | -MINIO_ACCESS_KEY = "minioadmin" |
32 | | -MINIO_SECRET_KEY = "minioadmin" |
33 | | -MINIO_BUCKET = "my-loras" |
34 | | -DEFAULT_LORA_REPO = "codelion/Qwen3-0.6B-accuracy-recovery-lora" |
35 | | -DEFAULT_LORA_NAME = "codelion/Qwen3-0.6B-accuracy-recovery-lora" |
36 | | - |
37 | | - |
38 | | -@dataclass |
39 | | -class MinioLoraConfig: |
40 | | - """Configuration for MinIO and LoRA setup""" |
41 | | - |
42 | | - endpoint: str = MINIO_ENDPOINT |
43 | | - access_key: str = MINIO_ACCESS_KEY |
44 | | - secret_key: str = MINIO_SECRET_KEY |
45 | | - bucket: str = MINIO_BUCKET |
46 | | - lora_repo: str = DEFAULT_LORA_REPO |
47 | | - lora_name: str = DEFAULT_LORA_NAME |
48 | | - data_dir: Optional[str] = None |
49 | | - |
50 | | - def get_s3_uri(self) -> str: |
51 | | - """Get the S3 URI for the LoRA adapter""" |
52 | | - return f"s3://{self.bucket}/{self.lora_name}" |
53 | | - |
54 | | - def get_env_vars(self) -> dict: |
55 | | - """Get environment variables for AWS/MinIO access""" |
56 | | - return { |
57 | | - "AWS_ENDPOINT": self.endpoint, |
58 | | - "AWS_ACCESS_KEY_ID": self.access_key, |
59 | | - "AWS_SECRET_ACCESS_KEY": self.secret_key, |
60 | | - "AWS_REGION": "us-east-1", |
61 | | - "AWS_ALLOW_HTTP": "true", |
62 | | - "DYN_LORA_ENABLED": "true", |
63 | | - "DYN_LORA_PATH": "/tmp/dynamo_loras_minio_test", |
64 | | - } |
65 | | - |
66 | | - |
67 | | -class MinioService: |
68 | | - """Manages MinIO Docker container lifecycle for tests""" |
69 | | - |
70 | | - CONTAINER_NAME = "dynamo-minio-test" |
71 | | - |
72 | | - def __init__(self, config: MinioLoraConfig): |
73 | | - self.config = config |
74 | | - self._logger = logging.getLogger(self.__class__.__name__) |
75 | | - self._temp_dir: Optional[str] = None |
76 | | - |
77 | | - def start(self) -> None: |
78 | | - """Start MinIO container""" |
79 | | - self._logger.info("Starting MinIO container...") |
80 | | - |
81 | | - # Create data directory |
82 | | - if self.config.data_dir: |
83 | | - data_dir = self.config.data_dir |
84 | | - else: |
85 | | - data_dir = tempfile.mkdtemp(prefix="minio_test_") |
86 | | - self.config.data_dir = data_dir |
87 | | - |
88 | | - # Stop existing container if running |
89 | | - self.stop() |
90 | | - |
91 | | - # Start MinIO container |
92 | | - cmd = [ |
93 | | - "docker", |
94 | | - "run", |
95 | | - "-d", |
96 | | - "--name", |
97 | | - self.CONTAINER_NAME, |
98 | | - "-p", |
99 | | - "9000:9000", |
100 | | - "-p", |
101 | | - "9001:9001", |
102 | | - "-v", |
103 | | - f"{data_dir}:/data", |
104 | | - "quay.io/minio/minio", |
105 | | - "server", |
106 | | - "/data", |
107 | | - "--console-address", |
108 | | - ":9001", |
109 | | - ] |
110 | | - |
111 | | - result = subprocess.run(cmd, capture_output=True, text=True) |
112 | | - if result.returncode != 0: |
113 | | - raise RuntimeError(f"Failed to start MinIO: {result.stderr}") |
114 | | - |
115 | | - # Wait for MinIO to be ready |
116 | | - self._wait_for_ready() |
117 | | - self._logger.info("MinIO started successfully") |
118 | | - |
119 | | - def _wait_for_ready(self, timeout: int = 30) -> None: |
120 | | - """Wait for MinIO to be ready""" |
121 | | - health_url = f"{self.config.endpoint}/minio/health/live" |
122 | | - start_time = time.time() |
123 | | - |
124 | | - while time.time() - start_time < timeout: |
125 | | - try: |
126 | | - response = requests.get(health_url, timeout=2) |
127 | | - if response.status_code == 200: |
128 | | - return |
129 | | - except requests.RequestException: |
130 | | - pass |
131 | | - time.sleep(1) |
132 | | - |
133 | | - raise RuntimeError(f"MinIO did not become ready within {timeout}s") |
134 | | - |
135 | | - def stop(self) -> None: |
136 | | - """Stop and remove MinIO container""" |
137 | | - self._logger.info("Stopping MinIO container...") |
138 | | - |
139 | | - # Stop container |
140 | | - subprocess.run( |
141 | | - ["docker", "stop", self.CONTAINER_NAME], |
142 | | - capture_output=True, |
143 | | - ) |
144 | | - |
145 | | - # Remove container |
146 | | - subprocess.run( |
147 | | - ["docker", "rm", self.CONTAINER_NAME], |
148 | | - capture_output=True, |
149 | | - ) |
150 | | - |
151 | | - def create_bucket(self) -> None: |
152 | | - """Create the S3 bucket using AWS CLI""" |
153 | | - env = os.environ.copy() |
154 | | - env.update( |
155 | | - { |
156 | | - "AWS_ACCESS_KEY_ID": self.config.access_key, |
157 | | - "AWS_SECRET_ACCESS_KEY": self.config.secret_key, |
158 | | - } |
159 | | - ) |
160 | | - |
161 | | - # Check if bucket exists |
162 | | - result = subprocess.run( |
163 | | - [ |
164 | | - "aws", |
165 | | - "--endpoint-url", |
166 | | - self.config.endpoint, |
167 | | - "s3", |
168 | | - "ls", |
169 | | - f"s3://{self.config.bucket}", |
170 | | - ], |
171 | | - capture_output=True, |
172 | | - env=env, |
173 | | - ) |
174 | | - |
175 | | - if result.returncode != 0: |
176 | | - # Create bucket |
177 | | - self._logger.info(f"Creating bucket: {self.config.bucket}") |
178 | | - result = subprocess.run( |
179 | | - [ |
180 | | - "aws", |
181 | | - "--endpoint-url", |
182 | | - self.config.endpoint, |
183 | | - "s3", |
184 | | - "mb", |
185 | | - f"s3://{self.config.bucket}", |
186 | | - ], |
187 | | - capture_output=True, |
188 | | - env=env, |
189 | | - ) |
190 | | - if result.returncode != 0: |
191 | | - raise RuntimeError(f"Failed to create bucket: {result.stderr}") |
192 | | - |
193 | | - def download_lora(self) -> str: |
194 | | - """Download LoRA from Hugging Face Hub, returns temp directory path""" |
195 | | - self._temp_dir = tempfile.mkdtemp(prefix="lora_download_") |
196 | | - self._logger.info( |
197 | | - f"Downloading LoRA {self.config.lora_repo} to {self._temp_dir}" |
198 | | - ) |
199 | | - |
200 | | - result = subprocess.run( |
201 | | - [ |
202 | | - "huggingface-cli", |
203 | | - "download", |
204 | | - self.config.lora_repo, |
205 | | - "--local-dir", |
206 | | - self._temp_dir, |
207 | | - "--local-dir-use-symlinks", |
208 | | - "False", |
209 | | - ], |
210 | | - capture_output=True, |
211 | | - text=True, |
212 | | - ) |
213 | | - |
214 | | - if result.returncode != 0: |
215 | | - raise RuntimeError(f"Failed to download LoRA: {result.stderr}") |
216 | | - |
217 | | - # Clean up cache directory |
218 | | - cache_dir = os.path.join(self._temp_dir, ".cache") |
219 | | - if os.path.exists(cache_dir): |
220 | | - shutil.rmtree(cache_dir) |
221 | | - |
222 | | - return self._temp_dir |
223 | | - |
224 | | - def upload_lora(self, local_path: str) -> None: |
225 | | - """Upload LoRA to MinIO""" |
226 | | - self._logger.info( |
227 | | - f"Uploading LoRA to s3://{self.config.bucket}/{self.config.lora_name}" |
228 | | - ) |
229 | | - |
230 | | - env = os.environ.copy() |
231 | | - env.update( |
232 | | - { |
233 | | - "AWS_ACCESS_KEY_ID": self.config.access_key, |
234 | | - "AWS_SECRET_ACCESS_KEY": self.config.secret_key, |
235 | | - } |
236 | | - ) |
237 | | - |
238 | | - result = subprocess.run( |
239 | | - [ |
240 | | - "aws", |
241 | | - "--endpoint-url", |
242 | | - self.config.endpoint, |
243 | | - "s3", |
244 | | - "sync", |
245 | | - local_path, |
246 | | - f"s3://{self.config.bucket}/{self.config.lora_name}", |
247 | | - "--exclude", |
248 | | - "*.git*", |
249 | | - ], |
250 | | - capture_output=True, |
251 | | - env=env, |
252 | | - ) |
253 | | - |
254 | | - if result.returncode != 0: |
255 | | - raise RuntimeError(f"Failed to upload LoRA: {result.stderr}") |
256 | | - |
257 | | - def cleanup_temp(self) -> None: |
258 | | - """Clean up temporary directories""" |
259 | | - if self._temp_dir and os.path.exists(self._temp_dir): |
260 | | - shutil.rmtree(self._temp_dir) |
261 | | - self._temp_dir = None |
262 | | - |
263 | | - if self.config.data_dir and os.path.exists(self.config.data_dir): |
264 | | - shutil.rmtree(self.config.data_dir, ignore_errors=True) |
265 | | - |
266 | | - |
267 | | -def load_lora_adapter( |
268 | | - system_port: int, lora_name: str, s3_uri: str, timeout: int = 60 |
269 | | -) -> None: |
270 | | - """Load a LoRA adapter via the system API""" |
271 | | - url = f"http://localhost:{system_port}/v1/loras" |
272 | | - payload = {"lora_name": lora_name, "source": {"uri": s3_uri}} |
273 | | - |
274 | | - logger.info(f"Loading LoRA adapter: {lora_name} from {s3_uri}") |
275 | | - |
276 | | - response = requests.post(url, json=payload, timeout=timeout) |
277 | | - if response.status_code != 200: |
278 | | - raise RuntimeError( |
279 | | - f"Failed to load LoRA adapter: {response.status_code} - {response.text}" |
280 | | - ) |
281 | | - |
282 | | - logger.info(f"LoRA adapter loaded successfully: {response.json()}") |
283 | | - |
284 | 20 |
|
285 | 21 | @pytest.fixture(scope="session") |
286 | 22 | def httpserver_listen_address(): |
|
0 commit comments