Skip to content

Commit dc342da

Browse files
committed
small nits
1 parent 7062123 commit dc342da

File tree

4 files changed

+42
-43
lines changed

4 files changed

+42
-43
lines changed

tests/fault_tolerance/deploy/client.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ def run_aiperf(
363363
# Note: For continuous load, we only run once and expect SIGINT to stop it
364364
max_attempts = 1 if continuous_load else (max_retries if max_retries > 0 else 1)
365365
success = False
366-
all_results = []
367366

368367
for attempt in range(max_attempts):
369368
if continuous_load:
@@ -394,16 +393,6 @@ def run_aiperf(
394393
f.write("\n\n=== STDERR ===\n")
395394
f.write(result.stderr)
396395

397-
all_results.append(
398-
{
399-
"attempt": attempt + 1,
400-
"returncode": result.returncode,
401-
"stdout": result.stdout,
402-
"stderr": result.stderr,
403-
}
404-
)
405-
406-
# Even with continuous load, with SIGINT, aiperf should return 0 and create the profile_export_aiperf.json file
407396
if result.returncode == 0:
408397
# AI-Perf returns 0 even if all requests failed, so we need to check the output
409398
json_path = attempt_dir / "profile_export_aiperf.json"
@@ -440,7 +429,6 @@ def run_aiperf(
440429
)
441430
except Exception as e:
442431
logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
443-
all_results.append({"attempt": attempt + 1, "error": str(e)})
444432

445433
# Sleep before next attempt (if not the last attempt and not continuous load)
446434
if not success and attempt < max_attempts - 1 and not continuous_load:
@@ -468,9 +456,9 @@ def run_aiperf_with_signal_handling(
468456
"""
469457
Run aiperf with signal handling for graceful shutdown.
470458
471-
Handles SIGINT forwarding and timeout when running with subprocess.Popen.
472-
This ensures that Ctrl-C and SIGINT are properly forwarded to the subprocess
473-
so it can clean up gracefully and write results files.
459+
Handles SIGINT and SIGTERM forwarding and timeout when running with subprocess.Popen.
460+
This ensures that Ctrl-C (SIGINT) and graceful termination signals (SIGTERM)
461+
are properly forwarded to the subprocess so it can clean up gracefully and write results files.
474462
"""
475463
proc = subprocess.Popen(
476464
cmd_attempt,
@@ -480,15 +468,20 @@ def run_aiperf_with_signal_handling(
480468
stdin=subprocess.DEVNULL,
481469
)
482470

483-
# Set up signal handler to forward SIGINT to subprocess
484471
def signal_handler(signum, frame):
485-
logger.info(f"Received signal {signum}, forwarding to aiperf subprocess")
472+
signal_names = {
473+
signal.SIGINT: "SIGINT",
474+
signal.SIGTERM: "SIGTERM",
475+
}
476+
signal_name = signal_names.get(signum, f"signal {signum}")
477+
logger.info(f"Received {signal_name}, forwarding to aiperf subprocess")
486478
try:
487-
proc.send_signal(signal.SIGINT)
479+
proc.send_signal(signum)
488480
except ProcessLookupError:
489481
pass # Process already terminated
490482

491483
signal.signal(signal.SIGINT, signal_handler)
484+
signal.signal(signal.SIGTERM, signal_handler)
492485

493486
try:
494487
stdout, stderr = proc.communicate(timeout=timeout)

tests/fault_tolerance/deploy/scenarios.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import asyncio
1617
import logging
1718
import re
18-
import time
1919
from abc import ABC, abstractmethod
2020
from dataclasses import dataclass
2121
from enum import Enum, auto
@@ -188,7 +188,7 @@ async def execute(
188188

189189
await deployment._wait_for_ready(timeout=1800) # 30 minute timeout
190190

191-
time.sleep(
191+
await asyncio.sleep(
192192
self.time
193193
) # have some requests processed after the rolling upgrade has completed
194194

tests/fault_tolerance/deploy/test_deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import asyncio
45
import logging
56
import multiprocessing
67
import os
78
import re
89
import signal
9-
import time
1010
from contextlib import contextmanager
1111
from multiprocessing.context import SpawnProcess
1212

@@ -224,7 +224,7 @@ async def _inject_failures(
224224
deployment: ManagedDeployment,
225225
): # noqa: F811
226226
for failure in failures:
227-
time.sleep(failure.time)
227+
await asyncio.sleep(failure.time)
228228

229229
logger.info(f"Injecting failure for: {failure}")
230230

tests/utils/managed_deployment.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ def image(self) -> Optional[str]:
5757
except KeyError:
5858
return None
5959

60+
@property
61+
def envs(self) -> list[dict[str, str]]:
62+
"""Environment variables for the service"""
63+
return self._spec.get("envs", [])
64+
65+
@envs.setter
66+
def envs(self, value: list[dict[str, str]]):
67+
self._spec["envs"] = value
68+
6069
@image.setter
6170
def image(self, value: str):
6271
if "extraPodSpec" not in self._spec:
@@ -318,13 +327,9 @@ def set_service_env_var(self, service_name: str, name: str, value: str):
318327
"""
319328
Set an environment variable for a specific service
320329
"""
321-
# Check service exists
322-
if service_name not in self._deployment_spec["spec"]["services"]:
323-
raise ValueError(f"Service '{service_name}' not found in deployment spec")
324-
325-
service = self._deployment_spec["spec"]["services"][service_name]
326-
if "envs" not in service:
327-
service["envs"] = []
330+
service = self.get_service(service_name)
331+
if service.envs is None:
332+
service.envs = []
328333

329334
# if env var already exists, update it
330335
for env in service["envs"]:
@@ -342,11 +347,8 @@ def get_service_env_vars(self, service_name: str) -> list[dict]:
342347
Returns:
343348
List of environment variable dicts (e.g., [{"name": "VAR", "value": "val"}])
344349
"""
345-
# Check service exists
346-
if service_name not in self._deployment_spec["spec"]["services"]:
347-
raise ValueError(f"Service '{service_name}' not found in deployment spec")
348-
349-
return self._deployment_spec["spec"]["services"][service_name].get("envs", [])
350+
service = self.get_service(service_name)
351+
return service.envs
350352

351353
@property
352354
def services(self) -> list[ServiceSpec]:
@@ -374,11 +376,7 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str):
374376
arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len")
375377
arg_value: Argument value (e.g., "1024")
376378
"""
377-
# Get the service
378-
if service_name not in self._deployment_spec["spec"]["services"]:
379-
raise ValueError(f"Service '{service_name}' not found in deployment spec")
380-
381-
service = self._deployment_spec["spec"]["services"][service_name]
379+
service = self.get_service(service_name)
382380

383381
# Ensure args list exists
384382
if "extraPodSpec" not in service:
@@ -418,15 +416,23 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str):
418416
# Add new argument
419417
args_list.extend([arg_name, arg_value])
420418

421-
def set_service_replicas(self, service_name: str, replicas: int):
419+
def get_service(self, service_name: str) -> ServiceSpec:
422420
"""
423-
Set the number of replicas for a specific service
421+
Get a specific service from the deployment spec
424422
"""
425-
# Check service exists
426423
if service_name not in self._deployment_spec["spec"]["services"]:
427424
raise ValueError(f"Service '{service_name}' not found in deployment spec")
428425

429-
self._deployment_spec["spec"]["services"][service_name]["replicas"] = replicas
426+
return ServiceSpec(
427+
service_name, self._deployment_spec["spec"]["services"][service_name]
428+
)
429+
430+
def set_service_replicas(self, service_name: str, replicas: int):
431+
"""
432+
Set the number of replicas for a specific service
433+
"""
434+
service = self.get_service(service_name)
435+
service.replicas = replicas
430436

431437
def save(self, out_file: str):
432438
"""Save updated deployment to file"""

0 commit comments

Comments
 (0)