Skip to content

Commit 9b268c9

Browse files
committed
fix rolling upgrade test
1 parent 540f274 commit 9b268c9

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/utils/managed_deployment.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,18 @@ def set_service_env_var(self, service_name: str, name: str, value: str):
328328
Set an environment variable for a specific service
329329
"""
330330
service = self.get_service(service_name)
331-
if service.envs is None:
332-
service.envs = []
331+
envs = service.envs if service.envs is not None else []
333332

334333
# if env var already exists, update it
335-
for env in service.envs:
334+
for env in envs:
336335
if env["name"] == name:
337336
env["value"] = value
337+
service.envs = envs # Save back to trigger the setter
338338
return
339339

340340
# if env var does not exist, add it
341-
service.envs.append({"name": name, "value": value})
341+
envs.append({"name": name, "value": value})
342+
service.envs = envs # Save back to trigger the setter
342343

343344
def get_service_env_vars(self, service_name: str) -> list[dict]:
344345
"""
@@ -754,7 +755,7 @@ async def trigger_rolling_upgrade(self, service_names: list[str]):
754755
f"Failed to patch deployment {self._deployment_name}: {e}"
755756
)
756757
raise
757-
758+
758759
async def get_pod_names(self, service_names: list[str] | None = None) -> list[str]:
759760
if not service_names:
760761
service_names = [service.name for service in self.deployment_spec.services]
@@ -766,10 +767,12 @@ async def get_pod_names(self, service_names: list[str] | None = None) -> list[st
766767
f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}"
767768
)
768769
assert self._core_api is not None, "Kubernetes API not initialized"
769-
pods: client.V1PodList = await self._core_api.list_namespaced_pod(self.namespace, label_selector=label_selector)
770+
pods: client.V1PodList = await self._core_api.list_namespaced_pod(
771+
self.namespace, label_selector=label_selector
772+
)
770773
for pod in pods.items:
771774
pod_names.append(pod.metadata.name)
772-
775+
773776
return pod_names
774777

775778
def get_processes(self, pod: Pod) -> list[PodProcess]:

0 commit comments

Comments
 (0)