diff --git a/tests/fault_tolerance/deploy/client.py b/tests/fault_tolerance/deploy/client.py index 03432c0d38..89dd7daec7 100644 --- a/tests/fault_tolerance/deploy/client.py +++ b/tests/fault_tolerance/deploy/client.py @@ -18,12 +18,14 @@ import json import logging import os +import signal import subprocess import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import requests +from kr8s.objects import Pod from tests.utils.managed_deployment import ManagedDeployment @@ -44,7 +46,7 @@ def get_frontend_port( deployment_spec: Any, pod_ports: Dict[str, Any], logger: logging.Logger, -) -> Tuple[Optional[str], Optional[int], Optional[str]]: +) -> Tuple[Optional[str], Optional[int], Optional[Pod]]: """ Select a frontend pod using round-robin and setup port forwarding. @@ -60,7 +62,7 @@ def get_frontend_port( Returns: Tuple of (pod_name, local_port, pod_instance) or (None, None, None) if failed """ - pods = managed_deployment.get_pods(managed_deployment.frontend_service_name) + pods = managed_deployment.get_pods([managed_deployment.frontend_service_name]) port = 0 pod_name = None @@ -270,6 +272,7 @@ def run_aiperf( logger: logging.Logger, max_retries: int = 1, retry_delay: float = 1, + continuous_load: bool = False, ) -> bool: """ Execute AI-Perf with specified parameters. 
@@ -280,13 +283,14 @@ def run_aiperf( model: Model name pod_name: Selected pod name for logging port: Local port number - requests_per_client: Number of requests to send + requests_per_client: Number of requests to send (used if continuous load not enabled) input_token_length: Input token count output_token_length: Output token count output_dir: Directory for AI-Perf artifacts logger: Logger instance max_retries: Maximum number of retry attempts (default: 1) retry_delay: Delay in seconds between retries (default: 1) + continuous_load: If True, use continuous load instead of fixed request count Returns: True if successful, False otherwise @@ -315,8 +319,6 @@ def run_aiperf( # Enable streaming for TTFT and ITL metrics "--streaming", # Request parameters - "--request-count", - str(requests_per_client), # Required: how many requests "--concurrency", "1", # Optional: we set to 1 for sequential # Token configuration @@ -338,8 +340,13 @@ def run_aiperf( "100", # For reproducible results ] - # Calculate timeout (same as legacy would for all requests) - timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes + if continuous_load: + cmd.extend(["--benchmark-duration", "1800"]) # 30 minutes for continuous load + logger.info("Using continuous load with duration: 30 minutes") + timeout = 1860 # 31 minutes default for duration-based tests (30 minutes + 1 minute buffer) + else: + cmd.extend(["--request-count", str(requests_per_client)]) + timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes # Log execution logger.info(f"Starting AI-Perf for Pod {pod_name} Local Port {port}") @@ -354,15 +361,19 @@ def run_aiperf( logger.info(f"Command: {' '.join(cmd)}") # Retry logic for fault tolerance - retry FULL request count until success - - max_attempts = max_retries if max_retries > 0 else 1 + # Note: For continuous load, we only run once and expect SIGINT to stop it + max_attempts = 1 if continuous_load else (max_retries if max_retries > 0 else 1) success = 
False - all_results = [] for attempt in range(max_attempts): - logger.info( - f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests" - ) + if continuous_load: + logger.info( + "AI-Perf continuous load (will run until interrupted by SIGINT)" + ) + else: + logger.info( + f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests" + ) # Update output directory for this attempt attempt_dir = output_dir / f"attempt_{attempt}" @@ -374,13 +385,7 @@ def run_aiperf( cmd_attempt[artifact_dir_idx] = str(attempt_dir) try: - result = subprocess.run( - cmd_attempt, - capture_output=True, - text=True, - timeout=timeout, - stdin=subprocess.DEVNULL, # Prevent stdin reading which can cause process suspension - ) + result = run_aiperf_with_signal_handling(cmd_attempt, logger, timeout) # Save logs for this attempt with open(attempt_dir / "genai_perf.log", "w") as f: @@ -389,15 +394,6 @@ def run_aiperf( f.write("\n\n=== STDERR ===\n") f.write(result.stderr) - all_results.append( - { - "attempt": attempt + 1, - "returncode": result.returncode, - "stdout": result.stdout, - "stderr": result.stderr, - } - ) - if result.returncode == 0: # AI-Perf returns 0 even if all requests failed, so we need to check the output json_path = attempt_dir / "profile_export_aiperf.json" @@ -412,6 +408,19 @@ def run_aiperf( ) if success: break # Success - exit the retry loop + ## TODO: bug with aiperf git+https://github.com/ai-dynamo/aiperf.git@4d3fa29403c8f75da22a14f1f7b3aeb27db9288f + ## where sending a SIGINT on Mac can sometimes have an error code of -9 (SIGKILL) which results in profile_export_aiperf.json not being created + elif result.returncode == -9 and continuous_load: + logger.warning( + f""" + Attempt {attempt + 1} failed with return code {result.returncode} + This is a known bug with aiperf on Mac where sending a SIGINT can sometimes have an error code of -9 (SIGKILL) + which results in profile_export_aiperf.json not being created + """ + ) + 
logger.debug( + f"Stderr: {result.stderr[:500] if result.stderr else 'No stderr'}" + ) else: logger.warning( f"Attempt {attempt + 1} failed with return code {result.returncode}" @@ -421,22 +430,84 @@ def run_aiperf( ) except Exception as e: logger.error(f"Error in attempt {attempt + 1}: {str(e)}") - all_results.append({"attempt": attempt + 1, "error": str(e)}) - # Sleep before next attempt (if not the last attempt) - if not success and attempt < max_attempts - 1: + # Sleep before next attempt (if not the last attempt and not continuous load) + if not success and attempt < max_attempts - 1 and not continuous_load: time.sleep(retry_delay) - if success: + if success and not continuous_load: logger.info( f"AI-Perf successfully completed all {requests_per_client} requests for {pod_name}" ) + elif success and continuous_load: + logger.info( + f"AI-Perf sustained continuous load for {pod_name} and exited successfully" + ) else: logger.error(f"AI-Perf failed all {max_attempts} attempts for {pod_name}") return success +# TODO: use file redirection and wait() instead of pipes and communicate +def run_aiperf_with_signal_handling( + cmd_attempt: List[str], + logger: logging.Logger, + timeout: int, +) -> subprocess.CompletedProcess: + """ + Run aiperf with signal handling for graceful shutdown. + + Handles SIGINT and SIGTERM forwarding and timeout when running with subprocess.Popen. + This ensures that Ctrl-C (SIGINT) and graceful termination signals (SIGTERM) + are properly forwarded to the subprocess so it can clean up gracefully and write results files. 
+ """ + proc = subprocess.Popen( + cmd_attempt, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + stdin=subprocess.DEVNULL, + ) + + def signal_handler(signum, frame): + signal_names = { + signal.SIGINT: "SIGINT", + signal.SIGTERM: "SIGTERM", + } + signal_name = signal_names.get(signum, f"signal {signum}") + logger.info(f"Received {signal_name}, forwarding to aiperf subprocess") + try: + proc.send_signal(signum) + except ProcessLookupError: + pass # Process already terminated + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + stdout, stderr = proc.communicate(timeout=timeout) + returncode = proc.returncode + except subprocess.TimeoutExpired: + logger.warning(f"AI-Perf subprocess timed out after {timeout}s") + proc.kill() + stdout, stderr = proc.communicate() + returncode = proc.returncode + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, sending SIGINT to aiperf subprocess") + proc.send_signal(signal.SIGINT) + try: + stdout, stderr = proc.communicate(timeout=30) # Give it time to clean up + returncode = proc.returncode + except subprocess.TimeoutExpired: + logger.warning("Subprocess didn't terminate gracefully, killing it") + proc.kill() + stdout, stderr = proc.communicate() + returncode = proc.returncode + + return subprocess.CompletedProcess(cmd_attempt, returncode, stdout, stderr) + + def log_summary_metrics( output_dir: Path, logger: logging.Logger, pod_name: str, port: int ) -> None: @@ -513,6 +584,7 @@ def client( output_token_length: int, max_retries: int, retry_delay: float = 1, + continuous_load: bool = False, ): """ Generate load using AI-Perf for fault tolerance testing. 
@@ -527,11 +599,12 @@ def client( model: Model name log_dir: Directory for output logs and AI-Perf artifacts index: Client index used for round-robin pod selection - requests_per_client: Number of requests to generate + requests_per_client: Number of requests to generate (used if continuous load not enabled) input_token_length: Number of input tokens per request output_token_length: Number of output tokens per request max_retries: Maximum retry attempts for AI-Perf execution retry_delay: Delay in seconds between retry attempts + continuous_load: If True, use continuous load instead of fixed request count """ logger = logging.getLogger(f"CLIENT: {index}") logging.getLogger("httpx").setLevel(logging.WARNING) @@ -578,6 +651,7 @@ def client( logger=logger, max_retries=max_retries, retry_delay=retry_delay, + continuous_load=continuous_load, ) if not success: diff --git a/tests/fault_tolerance/deploy/client_factory.py b/tests/fault_tolerance/deploy/client_factory.py index 936122f082..d8f8e3f99f 100644 --- a/tests/fault_tolerance/deploy/client_factory.py +++ b/tests/fault_tolerance/deploy/client_factory.py @@ -42,6 +42,7 @@ def get_client_function(client_type: str) -> Callable: output_token_length, max_retries, retry_delay_or_rate, # Differs between implementations + continuous_load, ) Raises: diff --git a/tests/fault_tolerance/deploy/conftest.py b/tests/fault_tolerance/deploy/conftest.py index 70545b9526..2fb85fb5ad 100644 --- a/tests/fault_tolerance/deploy/conftest.py +++ b/tests/fault_tolerance/deploy/conftest.py @@ -35,6 +35,13 @@ def pytest_addoption(parser): help="Include tests that require custom builds (e.g., MoE models). " "By default, these tests are excluded.", ) + parser.addoption( + "--skip-service-restart", + action="store_true", + default=False, + help="Skip restarting NATS and etcd services before deployment. 
" + "By default, these services are restarted.", + ) def pytest_generate_tests(metafunc): @@ -109,3 +116,9 @@ def namespace(request): def client_type(request): """Get client type from command line or use scenario default.""" return request.config.getoption("--client-type") + + +@pytest.fixture +def skip_service_restart(request): + """Get skip restart services flag from command line.""" + return request.config.getoption("--skip-service-restart") diff --git a/tests/fault_tolerance/deploy/legacy_client.py b/tests/fault_tolerance/deploy/legacy_client.py index 5cb4df4557..668145838c 100644 --- a/tests/fault_tolerance/deploy/legacy_client.py +++ b/tests/fault_tolerance/deploy/legacy_client.py @@ -192,6 +192,7 @@ def client( max_retries, max_request_rate, retry_delay=1, + continuous_load=False, ): """Legacy custom client for fault tolerance testing. @@ -211,7 +212,11 @@ def client( max_retries: Maximum retry attempts per request max_request_rate: Maximum requests per second (for rate limiting) retry_delay: Delay in seconds between retries + continuous_load: If True, use continuous load instead of fixed request count """ + if continuous_load: + raise ValueError("Continuous load is not supported for legacy client") + logger = logging.getLogger(f"CLIENT: {index}") logging.getLogger("httpx").setLevel(logging.WARNING) @@ -228,7 +233,7 @@ def client( for i in range(requests_per_client): # Get available pods pods = managed_deployment.get_pods( - managed_deployment.frontend_service_name + [managed_deployment.frontend_service_name] ) port = 0 pod_name = None diff --git a/tests/fault_tolerance/deploy/parse_results.py b/tests/fault_tolerance/deploy/parse_results.py index 66bc967e9b..00c1839468 100644 --- a/tests/fault_tolerance/deploy/parse_results.py +++ b/tests/fault_tolerance/deploy/parse_results.py @@ -341,6 +341,7 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: Returns: Dictionary with aggregated metrics and client count """ + logger = 
logging.getLogger(__name__) all_metrics: Dict[str, Any] = { "total_requests": 0, "successful_requests": 0, @@ -382,22 +383,28 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: with open(profile_json) as f: client_metrics = json.load(f) - # AI-Perf format has "records" dictionary at the top level + # AI-Perf format can have "records" dictionary or metrics at top level + # Try records first (older format), then fall back to top level (newer format) records = client_metrics.get("records", {}) - # Extract successful request count - request_count_record = records.get("request_count", {}) + # Extract successful request count - check both locations + request_count_record = records.get( + "request_count" + ) or client_metrics.get("request_count", {}) successful_count = ( int(request_count_record.get("avg", 0)) - if request_count_record + if request_count_record and isinstance(request_count_record, dict) else 0 ) - # Extract error request count - error_request_count_record = records.get("error_request_count", {}) + # Extract error request count - check both locations + error_request_count_record = records.get( + "error_request_count" + ) or client_metrics.get("error_request_count", {}) error_request_count = ( int(error_request_count_record.get("avg", 0)) if error_request_count_record + and isinstance(error_request_count_record, dict) else 0 ) @@ -418,9 +425,17 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: # Sum up actual error counts from each error type error_count = sum(error.get("count", 0) for error in error_summary) - # Check if test was cancelled + # Log if test was cancelled (expected for continuous load mode) if client_metrics.get("was_cancelled", False): - error_count = request_count # Mark all as failed if cancelled + logger.info( + f"AI-Perf client {item} was cancelled - anticipated if running with continuous load mode. " + f"Completed {request_count} requests before cancellation." 
+ ) + + # Note: If test was cancelled (was_cancelled=True), we still count the requests + # that were successfully completed before cancellation. The request_count + # represents successful requests, and error_count represents actual errors. + # We don't mark cancelled requests as failed - they were just interrupted. # Validate data consistency if request_count < error_count: diff --git a/tests/fault_tolerance/deploy/scenarios.py b/tests/fault_tolerance/deploy/scenarios.py index 0dc93e384c..817f28394d 100644 --- a/tests/fault_tolerance/deploy/scenarios.py +++ b/tests/fault_tolerance/deploy/scenarios.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import logging import re +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, auto from typing import TYPE_CHECKING, Dict, List, Optional, Pattern -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict -from tests.utils.managed_deployment import DeploymentSpec +from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment if TYPE_CHECKING: from tests.fault_tolerance.deploy.base_checker import BaseChecker @@ -54,8 +57,8 @@ class DeploymentInfo(TypedDict, total=False): is_moe: Optional flag indicating if this is a Mixture-of-Experts model """ - spec: DeploymentSpec - backend: str + spec: Required[DeploymentSpec] + backend: Required[str] model: str is_moe: bool @@ -155,14 +158,144 @@ class Load: overflow_request_count: int = 15 # Number of overflow requests normal_request_count: int = 15 # Number of normal requests after overflow + continuous_load: bool = ( + False # If True, use continuous load instead of fixed request count + ) + @dataclass -class Failure: +class Failure(ABC): + """Base class for all failure types.""" + + # time to wait in seconds before the failure is injected time: int - pod_name: str - command: str - signal: str 
= "SIGINT" - replicas: int = 1 + + # names of DGD services to inject the failure into the corresponding pods for + service_names: list[str] + + @abstractmethod + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute the failure injection. + + Args: + deployment: The managed deployment to inject the failure into + logger: Logger instance for logging failure injection + + Returns: List of affected pod names + """ + pass + + @abstractmethod + def get_failure_key(self) -> str: + """Get the failure key for the failure.""" + pass + + +@dataclass +class RollingUpgradeFailure(Failure): + """Failure type for triggering rolling upgrades.""" + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute rolling upgrade failure injection.""" + await deployment.trigger_rolling_upgrade(self.service_names) + + # Need to wait for the deployment to be unready so we know the rolling upgrade has started + await deployment.wait_for_unready(timeout=60, log_interval=10) + + await deployment._wait_for_ready(timeout=1800) # 30 minute timeout + + await asyncio.sleep( + self.time + ) # have some requests processed after the rolling upgrade has completed + + return await deployment.get_pod_names(self.service_names) + + def get_failure_key(self) -> str: + """Get the failure key for the rolling upgrade failure.""" + return f"rolling_upgrade:{','.join(self.service_names)}" + + +@dataclass +class DeletePodFailure(Failure): + """Failure type for deleting pods.""" + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute pod deletion failure injection.""" + service_pod_dict = deployment.get_pods(self.service_names) + pod_names: list[str] = [] + for service_name, pods in service_pod_dict.items(): + for pod in pods: + deployment.get_pod_manifest_logs_metrics( + service_name, pod, ".before_delete" + ) + pod.delete(force=True) # force 
means no graceful termination + pod_names.append(pod.name) + + return pod_names + + def get_failure_key(self) -> str: + """Get the failure key for the delete pod failure.""" + return f"delete_pod:{','.join(self.service_names)}" + + +class TerminateProcessFailure(Failure): + """Failure type for terminating specific processes by name.""" + + def __init__( + self, + time: int, + service_names: list[str], + signal: str = "SIGINT", + process_name: str = "", + ): + """Initialize TerminateProcessFailure. + + Args: + time: Time to wait in seconds before the failure is injected + service_names: Names of DGD services to inject the failure into + signal: Signal to send (default: "SIGINT") + process_name: Name of the process to terminate (required) + end_condition: End condition for failure (e.g., "dgd_ready") + """ + super().__init__( + time=time, + service_names=service_names, + ) + if not process_name or not signal: + raise ValueError( + "process_name and signal are required for TerminateProcessFailure" + ) + self.process_name = process_name + self.signal = signal + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute process termination failure injection.""" + service_pod_dict = deployment.get_pods(self.service_names) + pod_names: list[str] = [] + for service_name, pods in service_pod_dict.items(): + for pod in pods: + processes = deployment.get_processes(pod) + for process in processes: + if self.process_name in process.command: + logger.info( + f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}" + ) + process.kill(self.signal) + pod_names.append(pod.name) + + return pod_names + + def get_failure_key(self) -> str: + """Get the failure key for the terminate process failure.""" + return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}" @dataclass @@ -182,13 +315,25 @@ def __init__( ): super().__init__( time=time, - pod_name="Client", - 
command="token_overflow", + service_names=["Client"], ) self.max_seq_len = max_seq_len self.overflow_multiplier = overflow_multiplier self.overflow_token_count = int(max_seq_len * overflow_multiplier) + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Token overflow is handled client-side, so this is a no-op.""" + # The actual overflow is handled by the client configuration + # which uses the input_token_length from the Load config + # This is just a placeholder for the abstract method + return [] + + def get_failure_key(self) -> str: + """Get the failure key for the token overflow failure.""" + return f"token_overflow:{self.overflow_token_count}" + @dataclass class Scenario: @@ -206,7 +351,7 @@ class Scenario: # Helper functions to create deployment specs -def _create_deployment_spec(backend: str, yaml_path: str) -> DeploymentInfo: +def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo: """Create a deployment spec with backend information. 
Args: @@ -240,7 +385,9 @@ def _set_replicas(deployment_spec, backend, deploy_type, replicas): spec[WORKER_MAP[backend]["prefill"]].replicas = replicas -def _set_tensor_parallel(deployment_spec, backend, deploy_type, tp_size): +def _set_tensor_parallel( + deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int +): """Set tensor parallel size for worker components.""" spec = deployment_spec["spec"] @@ -308,7 +455,7 @@ def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]: scenario_name = "-".join(name_parts) # Create and configure the deployment - deployment = _create_deployment_spec(backend, yaml_files[deploy_type]) + deployment = _create_deployment_info(backend, yaml_files[deploy_type]) if tp_size > 1: _set_tensor_parallel(deployment, backend, deploy_type, tp_size) if dp_replicas > 1: @@ -397,34 +544,69 @@ def _create_backend_failures(backend, deploy_type="disagg"): process_name = f"dynamo.{backend}" failures = { - "frontend": [Failure(30, "Frontend", "dynamo.frontend")], - "frontend_pod": [Failure(30, "Frontend", "delete_pod")], - "decode_worker": [Failure(30, decode_worker, process_name, "SIGKILL")], - "decode_worker_pod": [Failure(30, decode_worker, "delete_pod")], - "prefill_worker": [Failure(30, prefill_worker, process_name, "SIGKILL")], - "prefill_worker_pod": [Failure(30, prefill_worker, "delete_pod")], + "frontend": [ + TerminateProcessFailure( + 30, ["Frontend"], "SIGINT", process_name="dynamo.frontend" + ) + ], + "frontend_pod": [DeletePodFailure(30, ["Frontend"])], + "decode_worker": [ + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name=process_name + ) + ], + "decode_worker_pod": [DeletePodFailure(30, [decode_worker])], + "prefill_worker": [ + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name=process_name + ) + ], + "prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])], "none": [], } if backend == "vllm": failures["vllm_decode_engine_core"] = [ - 
Failure(30, decode_worker, "VLLM::EngineCore", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="VLLM::EngineCore" + ) ] failures["vllm_prefill_engine_core"] = [ - Failure(30, prefill_worker, "VLLM::EngineCore", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="VLLM::EngineCore" + ) ] elif backend == "sglang": failures["sglang_decode_scheduler"] = [ - Failure(30, decode_worker, "sglang::scheduler", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="sglang::scheduler" + ) ] failures["sglang_decode_detokenizer"] = [ - Failure(30, decode_worker, "sglang::detokenizer", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="sglang::detokenizer" + ) ] failures["sglang_prefill_scheduler"] = [ - Failure(30, prefill_worker, "sglang::scheduler", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="sglang::scheduler" + ) ] failures["sglang_prefill_detokenizer"] = [ - Failure(30, prefill_worker, "sglang::detokenizer", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="sglang::detokenizer" + ) + ] + elif backend == "trtllm": + failures["trtllm_decode_engine_core"] = [ + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="TRTLLM::EngineCore" + ) + ] + failures["trtllm_prefill_engine_core"] = [ + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="TRTLLM::EngineCore" + ) ] return failures @@ -533,7 +715,7 @@ def create_legacy_load( # Populate Scenarios -scenarios = {} +scenarios: dict[str, Scenario] = {} # Map of backend+deploy_type to failure definitions backend_failure_map = {} @@ -729,5 +911,59 @@ def add_token_overflow_scenarios(): ) +def add_rolling_upgrade_scenarios(): + for backend in ["vllm", "sglang", "trtllm"]: + for worker_mode in ["agg", "disagg"]: + yaml_files = { + "agg": f"examples/backends/{backend}/deploy/agg.yaml", 
+ "disagg": f"examples/backends/{backend}/deploy/disagg.yaml", + } + deployment_info = _create_deployment_info(backend, yaml_files[worker_mode]) + deployment_spec: DeploymentSpec = deployment_info["spec"] + + service_names: list[str] = [] + + # setting replicas to 2 so we have availability of 1 replica at a time + if worker_mode == "agg" and backend == "trtllm": + service_names.append(WORKER_MAP[backend]["decode_agg"]) + else: + service_names.append(WORKER_MAP[backend]["decode"]) + + if worker_mode == "disagg": + service_names.append(WORKER_MAP[backend]["prefill"]) + + for service_name in service_names: + deployment_spec.set_service_replicas(service_name, 2) + + load = Load( + clients=10, + input_token_length=100, + output_token_length=100, + max_retries=1, + client_type="aiperf", + max_request_rate=1.0, + success_threshold=100.0, + continuous_load=True, + ) + + scenario_name = f"{backend}-{worker_mode}-rolling-upgrade" + model = "Qwen/Qwen3-0.6B" + + failure = RollingUpgradeFailure( + time=30, + service_names=service_names, + ) + scenarios[scenario_name] = Scenario( + deployment=deployment_info["spec"], + load=load, + failures=[failure], + model=model, + backend=backend, + ) + + # Add the token overflow scenarios add_token_overflow_scenarios() + +# Add the rolling upgrade scenarios +add_rolling_upgrade_scenarios() diff --git a/tests/fault_tolerance/deploy/test_deployment.py b/tests/fault_tolerance/deploy/test_deployment.py index caeb9039a2..8fe12dba20 100644 --- a/tests/fault_tolerance/deploy/test_deployment.py +++ b/tests/fault_tolerance/deploy/test_deployment.py @@ -1,12 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import asyncio import logging import multiprocessing +import os import re -import time +import signal from contextlib import contextmanager -from typing import Any +from multiprocessing.context import SpawnProcess +from typing import Any, Optional import pytest @@ -17,11 +20,12 @@ from tests.fault_tolerance.deploy.scenarios import ( OVERFLOW_SUFFIX, RECOVERY_SUFFIX, + Failure, Load, - TokenOverflowFailure, + Scenario, scenarios, ) -from tests.utils.managed_deployment import ManagedDeployment +from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment @pytest.fixture @@ -55,18 +59,18 @@ def scenario(scenario_name, client_type): @contextmanager def _clients( - logger, - request, - deployment_spec, - namespace, - model, + logger: logging.Logger, + log_dir: str, + deployment_spec: DeploymentSpec, + namespace: str, + model: str, load_config: Load, ): """Start client processes using factory pattern for client selection. Args: logger: Logger instance - request: Pytest request fixture + log_dir: Log directory for output logs and client logs/artifacts deployment_spec: Deployment specification namespace: Kubernetes namespace model: Model name to test @@ -79,7 +83,7 @@ def _clients( f"Starting {load_config.clients} clients using '{load_config.client_type}' client" ) - procs = [] + procs: list[SpawnProcess] = [] ctx = multiprocessing.get_context("spawn") # Determine retry_delay_or_rate based on client type @@ -90,6 +94,9 @@ def _clients( # AI-Perf client uses retry_delay between attempts (default 5s) retry_delay_or_rate = 5 + # Check if this is a continuous load test (rolling upgrade scenarios) + continuous_load = getattr(load_config, "continuous_load", False) + # Check if this is a mixed token test (overflow + recovery) # If mixed_token_test is True, run two phases; otherwise run normally if hasattr(load_config, "mixed_token_test") and load_config.mixed_token_test: @@ -108,13 +115,14 @@ def _clients( deployment_spec, 
namespace, model, - request.node.name + OVERFLOW_SUFFIX, + f"{log_dir}{OVERFLOW_SUFFIX}", i, load_config.overflow_request_count, # 15 overflow requests load_config.overflow_token_length, # 2x max_seq_len tokens load_config.output_token_length, load_config.max_retries, retry_delay_or_rate, + continuous_load, ), ) proc_overflow.start() @@ -128,7 +136,7 @@ def _clients( logger.info("Overflow requests completed. Starting recovery phase...") # Second phase: Send normal requests to test recovery - procs_recovery = [] + procs_recovery: list[SpawnProcess] = [] for i in range(load_config.clients): proc_normal = ctx.Process( target=client_func, @@ -136,7 +144,7 @@ def _clients( deployment_spec, namespace, model, - request.node.name + RECOVERY_SUFFIX, + f"{log_dir}{RECOVERY_SUFFIX}", i, load_config.normal_request_count, # 15 normal requests load_config.input_token_length, # Normal token count @@ -161,13 +169,14 @@ def _clients( deployment_spec, namespace, model, - request.node.name, + log_dir, i, load_config.requests_per_client, load_config.input_token_length, load_config.output_token_length, load_config.max_retries, retry_delay_or_rate, + continuous_load, # Pass continuous_load flag ), ) ) @@ -182,65 +191,50 @@ def _clients( logger.debug(f"{proc} joined") -def _inject_failures(failures, logger, deployment: ManagedDeployment): # noqa: F811 - """Inject failures and return info about affected pods. 
- - Returns: - Dict mapping failure info to list of affected pod names - Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]} +def _terminate_client_processes( + client_procs: list[SpawnProcess], + logger: logging.Logger, +): """ - affected_pods: dict[str, list] = {} - - for failure in failures: - time.sleep(failure.time) - - # Handle TokenOverflowFailure differently - it's a client-side injection - if isinstance(failure, TokenOverflowFailure): - # The actual overflow is handled by the client configuration - # which uses the input_token_length from the Load config - # This is just logging for visibility - continue - - pods = deployment.get_pods(failure.pod_name)[failure.pod_name] - - num_pods = len(pods) + Terminate client processes. + """ + # Send SIGINT to client processes to stop continuous load + if client_procs: + logger.info(f"Sending SIGINT to {len(client_procs)} client processes...") + for proc in client_procs: + if proc.is_alive(): + try: + if proc.pid is not None: + logger.debug(f"Sending SIGINT to client process {proc.pid}") + os.kill(proc.pid, signal.SIGINT) + else: + raise ValueError(f"Process {proc} has no PID") + except ProcessLookupError: + logger.debug(f"Process {proc.pid} already terminated") + except Exception as e: + logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}") + logger.info( + "SIGINT sent to all client processes, waiting for graceful shutdown..." 
+ ) + else: + logger.warning("No client processes provided to terminate") - if not pods: - continue - replicas = failure.replicas +async def _inject_failures( + failures: list[Failure], + logger: logging.Logger, + deployment: ManagedDeployment, +) -> dict[str, list]: # noqa: F811 + affected_pods: dict[str, list] = {} - if not replicas: - replicas = num_pods + for failure in failures: + await asyncio.sleep(failure.time) logger.info(f"Injecting failure for: {failure}") - # Track which pods were affected by this failure - failure_key = f"{failure.pod_name}:{failure.command}" - if failure_key not in affected_pods: - affected_pods[failure_key] = [] - - for x in range(replicas): - pod = pods[x % num_pods] - - # Capture the exact pod name before we kill it - pod_name = pod.name - affected_pods[failure_key].append(pod_name) - - logger.info(f"Target pod for failure: {pod_name}") - - if failure.command == "delete_pod": - deployment.get_pod_logs(failure.pod_name, pod, ".before_delete") - logger.info(f"Deleting pod: {pod_name}") - pod.delete(force=True) - else: - processes = deployment.get_processes(pod) - for process in processes: - if failure.command in process.command: - logger.info( - f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}" - ) - process.kill(failure.signal) + affected_pods[failure.get_failure_key()] = await failure.execute( + deployment, logger + ) return affected_pods @@ -445,11 +439,12 @@ def results_summary(): @pytest.mark.slow @pytest.mark.filterwarnings("ignore::DeprecationWarning") async def test_fault_scenario( - scenario, # noqa: F811 + scenario: Scenario, # noqa: F811 request, - image, - namespace, + image: str, + namespace: str, validation_context, # noqa: F811 # Shared context for passing data to validation + skip_service_restart: bool, ): """ Test dynamo serve deployments with injected failures @@ -468,6 +463,7 @@ async def test_fault_scenario( if image: scenario.deployment.set_image(image) + model: 
Optional[str] = None if scenario.model: scenario.deployment.set_model(scenario.model) model = scenario.model @@ -500,6 +496,7 @@ async def test_fault_scenario( namespace=namespace, log_dir=request.node.name, deployment_spec=scenario.deployment, + skip_service_restart=skip_service_restart, ) as deployment: # Populate shared context for validation validation_context["deployment"] = deployment @@ -507,14 +504,17 @@ async def test_fault_scenario( with _clients( logger, - request, + request.node.name, scenario.deployment, namespace, model, scenario.load, # Pass entire Load config object - ): + ) as client_procs: # Inject failures and capture which pods were affected - affected_pods = _inject_failures(scenario.failures, logger, deployment) - validation_context["affected_pods"] = affected_pods - + affected_pods = await _inject_failures( + scenario.failures, logger, deployment + ) logger.info(f"Affected pods during test: {affected_pods}") + + if scenario.load.continuous_load: + _terminate_client_processes(client_procs, logger) diff --git a/tests/utils/managed_deployment.py b/tests/utils/managed_deployment.py index 8dd008a61a..5ee541833d 100644 --- a/tests/utils/managed_deployment.py +++ b/tests/utils/managed_deployment.py @@ -5,18 +5,18 @@ import logging import os import re +import secrets import shlex import time from dataclasses import dataclass, field from typing import Any, List, Optional import kr8s -import kubernetes import requests import yaml -from kr8s.objects import Pod as kr8s_Pod -from kr8s.objects import Service as kr8s_Service +from kr8s.objects import Pod, Service from kubernetes_asyncio import client, config +from kubernetes_asyncio.client import exceptions def _get_workspace_dir() -> str: @@ -65,6 +65,15 @@ def image(self, value: str): self._spec["extraPodSpec"]["mainContainer"] = {} self._spec["extraPodSpec"]["mainContainer"]["image"] = value + @property + def envs(self) -> list[dict[str, str]]: + """Environment variables for the service""" + return 
self._spec.get("envs", []) + + @envs.setter + def envs(self, value: list[dict[str, str]]): + self._spec["envs"] = value + # ----- Replicas ----- @property def replicas(self) -> int: @@ -314,8 +323,36 @@ def get_logging_config(self) -> dict: return {"jsonl_enabled": jsonl_enabled, "log_level": log_level} + def set_service_env_var(self, service_name: str, name: str, value: str): + """ + Set an environment variable for a specific service + """ + service = self.get_service(service_name) + envs = service.envs if service.envs is not None else [] + + # if env var already exists, update it + for env in envs: + if env["name"] == name: + env["value"] = value + service.envs = envs # Save back to trigger the setter + return + + # if env var does not exist, add it + envs.append({"name": name, "value": value}) + service.envs = envs # Save back to trigger the setter + + def get_service_env_vars(self, service_name: str) -> list[dict]: + """ + Get all environment variables for a specific service + + Returns: + List of environment variable dicts (e.g., [{"name": "VAR", "value": "val"}]) + """ + service = self.get_service(service_name) + return service.envs + @property - def services(self) -> list: + def services(self) -> list[ServiceSpec]: """List of ServiceSpec objects""" return [ ServiceSpec(svc, spec) @@ -340,28 +377,25 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str): arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len") arg_value: Argument value (e.g., "1024") """ - # Get the service - if service_name not in self._deployment_spec["spec"]["services"]: - raise ValueError(f"Service '{service_name}' not found in deployment spec") - - service = self._deployment_spec["spec"]["services"][service_name] + service = self.get_service(service_name) + service_spec = service._spec # Ensure args list exists - if "extraPodSpec" not in service: - service["extraPodSpec"] = {"mainContainer": {}} - if "mainContainer" not in service["extraPodSpec"]: - 
service["extraPodSpec"]["mainContainer"] = {} - if "args" not in service["extraPodSpec"]["mainContainer"]: - service["extraPodSpec"]["mainContainer"]["args"] = [] + if "extraPodSpec" not in service_spec: + service_spec["extraPodSpec"] = {"mainContainer": {}} + if "mainContainer" not in service_spec["extraPodSpec"]: + service_spec["extraPodSpec"]["mainContainer"] = {} + if "args" not in service_spec["extraPodSpec"]["mainContainer"]: + service_spec["extraPodSpec"]["mainContainer"]["args"] = [] - args_list = service["extraPodSpec"]["mainContainer"]["args"] + args_list = service_spec["extraPodSpec"]["mainContainer"]["args"] # Convert to list if needed (sometimes it's a single string) if isinstance(args_list, str): import shlex args_list = shlex.split(args_list) - service["extraPodSpec"]["mainContainer"]["args"] = args_list + service_spec["extraPodSpec"]["mainContainer"]["args"] = args_list # Find existing argument arg_index = None @@ -384,6 +418,24 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str): # Add new argument args_list.extend([arg_name, arg_value]) + def get_service(self, service_name: str) -> ServiceSpec: + """ + Get a specific service from the deployment spec + """ + if service_name not in self._deployment_spec["spec"]["services"]: + raise ValueError(f"Service '{service_name}' not found in deployment spec") + + return ServiceSpec( + service_name, self._deployment_spec["spec"]["services"][service_name] + ) + + def set_service_replicas(self, service_name: str, replicas: int): + """ + Set the number of replicas for a specific service + """ + service = self.get_service(service_name) + service.replicas = replicas + def save(self, out_file: str): """Save updated deployment to file""" with open(out_file, "w") as f: @@ -391,7 +443,7 @@ def save(self, out_file: str): class PodProcess: - def __init__(self, pod: kr8s_Pod, line: str): + def __init__(self, pod: Pod, line: str): self.pid = int(re.split(r"\s+", line)[1]) self.command = " 
".join( re.split(r"\s+", line)[10:] @@ -439,10 +491,13 @@ class ManagedDeployment: log_dir: str deployment_spec: DeploymentSpec namespace: str - frontend_service_name: Optional[str] = "Frontend" + # TODO: this should be determined by the deployment_spec + # the service containing component_type: Frontend determines what is actually the frontend service + frontend_service_name: str = "Frontend" + skip_service_restart: bool = False - _custom_api: Optional[Any] = None - _core_api: Optional[Any] = None + _custom_api: Optional[client.CustomObjectsApi] = None + _core_api: Optional[client.CoreV1Api] = None _in_cluster: bool = False _logger: logging.Logger = logging.getLogger() _port_forward: Optional[Any] = None @@ -457,7 +512,7 @@ async def _init_kubernetes(self): """Initialize kubernetes client""" try: # Try in-cluster config first (for pods with service accounts) - await config.load_incluster_config() + config.load_incluster_config() self._in_cluster = True except Exception: # Fallback to kube config file (for local development) @@ -511,6 +566,17 @@ async def _restart_stateful(self, name, label): self._logger.info(f"Restarted {name} {label}") + async def wait_for_unready(self, timeout: int = 1800, sleep=1, log_interval=60): + """ + Wait for the custom resource to be unready. + + Args: + timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while) + """ + return await self._wait_for_condition( + timeout, sleep, log_interval, False, "pending" + ) + async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): """ Wait for the custom resource to be ready. 
@@ -518,9 +584,23 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): Args: timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while) """ + return await self._wait_for_condition( + timeout, sleep, log_interval, True, "successful" + ) + + async def _wait_for_condition( + self, + timeout: int = 1800, + sleep=1, + log_interval=60, + desired_ready_condition_val: bool = True, + desired_state_val: str = "successful", + ): start_time = time.time() - self._logger.info(f"Waiting for Deployment {self._deployment_name}") + self._logger.info( + f"Waiting for Deployment {self._deployment_name} to have Ready condition {desired_ready_condition_val} and state {desired_state_val}" + ) attempt = 0 @@ -528,7 +608,7 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): try: attempt += 1 assert self._custom_api is not None, "Kubernetes API not initialized" - status = await self._custom_api.get_namespaced_custom_object( + status = await self._custom_api.get_namespaced_custom_object( # type: ignore[awaitable-is-not-coroutine] group="nvidia.com", version="v1alpha1", namespace=self.namespace, @@ -538,29 +618,34 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): # Check both conditions: # 1. Ready condition is True # 2. 
State is successful - status_obj = status.get("status", {}) - conditions = status_obj.get("conditions", []) - current_state = status_obj.get("state", "unknown") + status_obj = status.get("status", {}) # type: ignore[attr-defined] + conditions = status_obj.get("conditions", []) # type: ignore[attr-defined] + current_state = status_obj.get("state", "unknown") # type: ignore[attr-defined] - ready_condition = False + observed_ready_condition_val = "" for condition in conditions: - if ( - condition.get("type") == "Ready" - and condition.get("status") == "True" - ): - ready_condition = True - break - - state_successful = status_obj.get("state") == "successful" - - if ready_condition and state_successful: + if condition.get("type") == "Ready": + observed_ready_condition_val = condition.get("status") + if observed_ready_condition_val == str( + desired_ready_condition_val + ): + break + + observed_state_val = status_obj.get("state") # type: ignore[attr-defined] + + if ( + observed_ready_condition_val == str(desired_ready_condition_val) + and observed_state_val == desired_state_val + ): self._logger.info(f"Current deployment state: {current_state}") self._logger.info(f"Current conditions: {conditions}") self._logger.info( f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" ) - self._logger.info(f"Deployment {self._deployment_name} is ready") + self._logger.info( + f"Deployment {self._deployment_name} has Ready condition {desired_ready_condition_val} and state {desired_state_val}" + ) return True else: if attempt % log_interval == 0: @@ -570,10 +655,10 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" ) self._logger.info( - f"Deployment not ready yet - Ready condition: {ready_condition}, State successful: {state_successful}" + f"Deployment has Ready condition {observed_ready_condition_val} and state {observed_state_val}, desired condition {desired_ready_condition_val} and 
state {desired_state_val}" ) - except kubernetes.client.rest.ApiException as e: + except exceptions.ApiException as e: self._logger.info( f"API Exception while checking deployment status: {e}" ) @@ -624,7 +709,7 @@ async def _create_deployment(self): ) self._logger.info(self.deployment_spec.spec()) self._logger.info(f"Deployment Started {self._deployment_name}") - except kubernetes.client.rest.ApiException as e: + except exceptions.ApiException as e: if e.status == 409: # Already exists self._logger.info(f"Deployment {self._deployment_name} already exists") else: @@ -633,7 +718,64 @@ async def _create_deployment(self): ) raise - def get_processes(self, pod) -> list: + async def trigger_rolling_upgrade(self, service_names: list[str]): + """ + Triggers a rolling update for a list of services + This is a dummy update - sets an env var on the service + """ + + if not service_names: + raise ValueError( + "service_names cannot be empty for trigger_rolling_upgrade" + ) + + patch_body: dict[str, Any] = {"spec": {"services": {}}} + + for service_name in service_names: + self.deployment_spec.set_service_env_var( + service_name, "TEST_ROLLING_UPDATE_TRIGGER", secrets.token_hex(8) + ) + + updated_envs = self.deployment_spec.get_service_env_vars(service_name) + patch_body["spec"]["services"][service_name] = {"envs": updated_envs} + + try: + assert self._custom_api is not None, "Kubernetes API not initialized" + await self._custom_api.patch_namespaced_custom_object( + group="nvidia.com", + version="v1alpha1", + namespace=self.namespace, + plural="dynamographdeployments", + name=self._deployment_name, + body=patch_body, + _content_type="application/merge-patch+json", + ) + except exceptions.ApiException as e: + self._logger.info( + f"Failed to patch deployment {self._deployment_name}: {e}" + ) + raise + + async def get_pod_names(self, service_names: list[str] | None = None) -> list[str]: + if not service_names: + service_names = [service.name for service in 
self.deployment_spec.services] + + pod_names: list[str] = [] + + for service_name in service_names: + label_selector = ( + f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}" + ) + assert self._core_api is not None, "Kubernetes API not initialized" + pods: client.V1PodList = await self._core_api.list_namespaced_pod( + self.namespace, label_selector=label_selector + ) + for pod in pods.items: + pod_names.append(pod.metadata.name) + + return pod_names + + def get_processes(self, pod: Pod) -> list[PodProcess]: """Get list of processes in the given pod""" result = pod.exec(["ps", "-aux"]) lines = result.stdout.decode().splitlines() @@ -646,38 +788,34 @@ def get_service(self, service_name=None): service_name = "" full_service_name = f"{self._deployment_name}-{service_name.lower()}" - return kr8s_Service.get(full_service_name, namespace=self.namespace) + return Service.get(full_service_name, namespace=self.namespace) - def get_pods(self, service_name=None): - result = {} + def get_pods(self, service_names: list[str] | None = None) -> dict[str, list[Pod]]: + result: dict[str, list[Pod]] = {} - service_list = [] + if not service_names: + service_names = [service.name for service in self.deployment_spec.services] - if not service_name: - service_list = [service.name for service in self.deployment_spec.services] - else: - service_list = [service_name] - - for service in service_list: + for service_name in service_names: # List pods for this service using the selector label # nvidia.com/selector: deployment-name-service label_selector = ( - f"nvidia.com/selector={self._deployment_name}-{service.lower()}" + f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}" ) - pods = [] + pods: list[Pod] = [] for pod in kr8s.get( "pods", namespace=self.namespace, label_selector=label_selector ): - pods.append(pod) + pods.append(pod) # type: ignore[arg-type] - result[service] = pods + result[service_name] = pods return result - def get_pod_logs(self, 
service, pod, suffix=""): - directory = os.path.join(self.log_dir, service) + def get_pod_manifest_logs_metrics(self, service_name: str, pod: Pod, suffix=""): + directory = os.path.join(self.log_dir, service_name) os.makedirs(directory, exist_ok=True) try: @@ -699,16 +837,20 @@ def get_pod_logs(self, service, pod, suffix=""): except Exception as e: self._logger.debug(e) - self._get_pod_metrics(pod, service, suffix) + self._get_pod_metrics(pod, service_name, suffix) def _get_service_logs(self, service_name=None, suffix=""): - service_pods = self.get_pods(service_name) + service_names = None + if service_name: + service_names = [service_name] + + service_pods = self.get_pods(service_names) for service, pods in service_pods.items(): - for i, pod in enumerate(pods): - self.get_pod_logs(service, pod, suffix) + for pod in pods: + self.get_pod_manifest_logs_metrics(service, pod, suffix) - def _get_pod_metrics(self, pod, service_name, suffix=""): + def _get_pod_metrics(self, pod: Pod, service_name: str, suffix=""): directory = os.path.join(self.log_dir, service_name) os.makedirs(directory, exist_ok=True) port = None @@ -757,11 +899,13 @@ async def _delete_deployment(self): plural="dynamographdeployments", name=self._deployment_name, ) - except client.exceptions.ApiException as e: + except exceptions.ApiException as e: if e.status != 404: # Ignore if already deleted raise - def port_forward(self, pod, remote_port, max_connection_attempts=3): + def port_forward( + self, pod: Pod, remote_port: int, max_connection_attempts: int = 3 + ): """Attempt to connect to a pod and return the port-forward object on success. Note: Port forwards run in background threads. 
When pods are terminated, @@ -866,9 +1010,13 @@ async def __aenter__(self): self._deployment_name = self.deployment_spec.name logging.getLogger("httpx").setLevel(logging.WARNING) await self._init_kubernetes() - await self._delete_deployment() - await self._restart_etcd() - await self._restart_nats() + + # Run delete deployment and service restarts in parallel + tasks = [self._delete_deployment()] + if not self.skip_service_restart: + tasks.extend([self._restart_etcd(), self._restart_nats()]) + await asyncio.gather(*tasks) + await self._create_deployment() await self._wait_for_ready()