Skip to content

Commit 3b0dc3e

Browse files
committed
fix mypy errors
1 parent 3de77cd commit 3b0dc3e

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

tests/fault_tolerance/deploy/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import Any, Dict, List, Optional, Tuple
2626

2727
import requests
28+
from kr8s.objects import Pod
2829

2930
from tests.utils.managed_deployment import ManagedDeployment
3031

@@ -45,7 +46,7 @@ def get_frontend_port(
4546
deployment_spec: Any,
4647
pod_ports: Dict[str, Any],
4748
logger: logging.Logger,
48-
) -> Tuple[Optional[str], Optional[int], Optional[str]]:
49+
) -> Tuple[Optional[str], Optional[int], Optional[Pod]]:
4950
"""
5051
Select a frontend pod using round-robin and setup port forwarding.
5152

tests/fault_tolerance/deploy/scenarios.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class RollingUpgradeFailure(Failure):
199199

200200
async def execute(
201201
self, deployment: ManagedDeployment, logger: logging.Logger
202-
) -> None:
202+
) -> list[str]:
203203
"""Execute rolling upgrade failure injection."""
204204
await deployment.trigger_rolling_upgrade(self.service_names)
205205

@@ -225,7 +225,7 @@ class DeletePodFailure(Failure):
225225

226226
async def execute(
227227
self, deployment: ManagedDeployment, logger: logging.Logger
228-
) -> None:
228+
) -> list[str]:
229229
"""Execute pod deletion failure injection."""
230230
service_pod_dict = deployment.get_pods(self.service_names)
231231
pod_names: list[str] = []
@@ -276,7 +276,7 @@ def __init__(
276276

277277
async def execute(
278278
self, deployment: ManagedDeployment, logger: logging.Logger
279-
) -> None:
279+
) -> list[str]:
280280
"""Execute process termination failure injection."""
281281
service_pod_dict = deployment.get_pods(self.service_names)
282282
pod_names: list[str] = []
@@ -323,7 +323,7 @@ def __init__(
323323

324324
async def execute(
325325
self, deployment: ManagedDeployment, logger: logging.Logger
326-
) -> None:
326+
) -> list[str]:
327327
"""Token overflow is handled client-side, so this is a no-op."""
328328
# The actual overflow is handled by the client configuration
329329
# which uses the input_token_length from the Load config

tests/fault_tolerance/deploy/test_deployment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import signal
1010
from contextlib import contextmanager
1111
from multiprocessing.context import SpawnProcess
12-
from typing import Any
12+
from typing import Any, Optional
1313

1414
import pytest
1515

@@ -463,6 +463,7 @@ async def test_fault_scenario(
463463
if image:
464464
scenario.deployment.set_image(image)
465465

466+
model: Optional[str] = None
466467
if scenario.model:
467468
scenario.deployment.set_model(scenario.model)
468469
model = scenario.model

tests/utils/managed_deployment.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def image(self) -> Optional[str]:
5757
except KeyError:
5858
return None
5959

60+
@image.setter
61+
def image(self, value: str):
62+
if "extraPodSpec" not in self._spec:
63+
self._spec["extraPodSpec"] = {"mainContainer": {}}
64+
if "mainContainer" not in self._spec["extraPodSpec"]:
65+
self._spec["extraPodSpec"]["mainContainer"] = {}
66+
self._spec["extraPodSpec"]["mainContainer"]["image"] = value
67+
6068
@property
6169
def envs(self) -> list[dict[str, str]]:
6270
"""Environment variables for the service"""
@@ -66,14 +74,6 @@ def envs(self) -> list[dict[str, str]]:
6674
def envs(self, value: list[dict[str, str]]):
6775
self._spec["envs"] = value
6876

69-
@image.setter
70-
def image(self, value: str):
71-
if "extraPodSpec" not in self._spec:
72-
self._spec["extraPodSpec"] = {"mainContainer": {}}
73-
if "mainContainer" not in self._spec["extraPodSpec"]:
74-
self._spec["extraPodSpec"]["mainContainer"] = {}
75-
self._spec["extraPodSpec"]["mainContainer"]["image"] = value
76-
7777
# ----- Replicas -----
7878
@property
7979
def replicas(self) -> int:
@@ -729,7 +729,7 @@ async def trigger_rolling_upgrade(self, service_names: list[str]):
729729
"service_names cannot be empty for trigger_rolling_upgrade"
730730
)
731731

732-
patch_body = {"spec": {"services": {}}}
732+
patch_body: dict[str, Any] = {"spec": {"services": {}}}
733733

734734
for service_name in service_names:
735735
self.deployment_spec.set_service_env_var(

0 commit comments

Comments
 (0)