
Commit 31163f7

Merge branch 'main' into bis/dep-681-add-agg-lora-tests

2 parents 4f6394d + 0ce7280

File tree

8 files changed: +396 −24 lines changed

.github/workflows/templates/akamai-eccu-flush.xslt

Lines changed: 16 additions & 1 deletion
@@ -1,4 +1,19 @@
-<?xml version="1.0" encoding="UTF-8"?>
+<?xml-stylesheet type="text/xsl" href="akamai-eccu-flush.xslt"?>
+<!--
+Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
 <!--
 Akamai ECCU (Edge Content Control Utility) XML Generator

components/src/dynamo/planner/kube.py

Lines changed: 55 additions & 4 deletions
@@ -78,11 +78,48 @@ def get_graph_deployment(self, graph_deployment_name: str) -> dict:
             )
             raise

-    def update_graph_replicas(
-        self, graph_deployment_name: str, component_name: str, replicas: int
+    def update_service_replicas(
+        self, graph_deployment_name: str, service_name: str, replicas: int
+    ) -> None:
+        """
+        Update replicas for a service using Scale subresource when DGDSA exists.
+        Falls back to DGD patch for backward compatibility with older operators.
+
+        Args:
+            graph_deployment_name: Name of the DynamoGraphDeployment
+            service_name: Name of the service in DGD.spec.services
+            replicas: Desired number of replicas
+        """
+        # DGDSA naming convention: <dgd-name>-<lowercase-service-name>
+        adapter_name = f"{graph_deployment_name}-{service_name.lower()}"
+
+        try:
+            # Try to scale via DGDSA Scale subresource
+            self.custom_api.patch_namespaced_custom_object_scale(
+                group="nvidia.com",
+                version="v1alpha1",
+                namespace=self.current_namespace,
+                plural="dynamographdeploymentscalingadapters",
+                name=adapter_name,
+                body={"spec": {"replicas": replicas}},
+            )
+            logger.info(f"Scaled DGDSA {adapter_name} to {replicas} replicas")
+
+        except client.ApiException as e:
+            if e.status == 404:
+                # DGDSA doesn't exist - fall back to DGD patch (old operator)
+                logger.info(
+                    f"DGDSA {adapter_name} not found, falling back to DGD update"
+                )
+                self._update_dgd_replicas(graph_deployment_name, service_name, replicas)
+            else:
+                raise
+
+    def _update_dgd_replicas(
+        self, graph_deployment_name: str, service_name: str, replicas: int
     ) -> None:
-        """Update the replicas count for a component in a DynamoGraphDeployment"""
-        patch = {"spec": {"services": {component_name: {"replicas": replicas}}}}
+        """Update replicas directly in DGD (fallback for old operators)"""
+        patch = {"spec": {"services": {service_name: {"replicas": replicas}}}}
         self.custom_api.patch_namespaced_custom_object(
             group="nvidia.com",
             version="v1alpha1",
@@ -91,6 +128,20 @@ def update_graph_replicas(
             name=graph_deployment_name,
             body=patch,
         )
+        logger.info(
+            f"Updated DGD {graph_deployment_name} service {service_name} to {replicas} replicas"
+        )
+
+    def update_graph_replicas(
+        self, graph_deployment_name: str, component_name: str, replicas: int
+    ) -> None:
+        """
+        Update replicas for a service. Now uses DGDSA when available.
+
+        Deprecated: Use update_service_replicas() instead for clarity.
+        This method is kept for backward compatibility.
+        """
+        self.update_service_replicas(graph_deployment_name, component_name, replicas)

     def is_deployment_ready(self, deployment: dict) -> bool:
         """Check if a graph deployment is ready"""

components/src/dynamo/trtllm/request_handlers/handler_base.py

Lines changed: 99 additions & 0 deletions
@@ -106,6 +106,76 @@ def check_error(self, result: dict):
             result["finish_reason"] == "stop" or result["finish_reason"] == "error"
         )

+    @staticmethod
+    def _extract_logprobs(
+        output, num_output_tokens_so_far: int
+    ) -> tuple[list[float] | None, list[list[dict]] | None]:
+        """
+        Extract logprobs from the TRTLLM output for new tokens.
+
+        Args:
+            output: TRTLLM CompletionOutput object
+            num_output_tokens_so_far: Number of tokens already processed
+        Returns:
+            Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
+            - log_probs: List of log probabilities for each new token
+            - top_logprobs: List of top logprobs dicts for each new token
+        """
+        if output.logprobs is None:
+            return None, None
+
+        # Get logprobs for new tokens only
+        new_logprobs = output.logprobs[num_output_tokens_so_far:]
+        if not new_logprobs:
+            return None, None
+
+        # From TRTLLM CompletionOutput API, logprobs: (TokenLogprobs | List[float], optional)
+        # Expect TokenLogprobs output when logprobs is set, check edge case where list[float] is returned instead
+        if isinstance(new_logprobs[0], float):
+            return [float(lp) for lp in new_logprobs], None
+
+        log_probs = []
+        top_logprobs = []
+
+        for token_idx, token_logprobs_dict in enumerate(new_logprobs):
+            if token_logprobs_dict is None:
+                continue
+
+            # Get the actual token_id that was generated at this position
+            actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
+
+            # Extract log probability for the selected token
+            if actual_token_id in token_logprobs_dict:
+                selected_logprob = token_logprobs_dict[actual_token_id]
+                log_probs.append(float(selected_logprob.logprob))
+            else:
+                # Fallback: use the first logprob if selected token not found
+                first_logprob = next(iter(token_logprobs_dict.values()), None)
+                if first_logprob:
+                    log_probs.append(float(first_logprob.logprob))
+
+            # Build top_logprobs list for this token position
+            # NOTE: TRTLLM LogProb API doesn't have decoded_token, will default to None
+            token_top_logprobs = []
+            for tok_id, logprob_info in token_logprobs_dict.items():
+                token_top_logprobs.append(
+                    {
+                        "rank": logprob_info.rank
+                        if hasattr(logprob_info, "rank")
+                        else 0,
+                        "token_id": tok_id,
+                        "token": (
+                            logprob_info.decoded_token
+                            if hasattr(logprob_info, "decoded_token")
+                            else None
+                        ),
+                        "logprob": float(logprob_info.logprob),
+                    }
+                )
+            top_logprobs.append(token_top_logprobs)
+
+        return log_probs if log_probs else None, top_logprobs if top_logprobs else None
+
     async def _handle_cancellation(
         self, generation_result: GenerationResult, context: Context
     ):
@@ -236,6 +306,26 @@ async def generate_locally(
             if hasattr(sampling_params, key):
                 setattr(sampling_params, key, value)

+        # Additional sampling params in output options
+        output_options = request.get("output_options", {})
+        if output_options:
+            logprobs_value = output_options.get("logprobs")
+
+            # Handle logprobs
+            if logprobs_value is not None:
+                if hasattr(sampling_params, "logprobs"):
+                    setattr(
+                        sampling_params, "logprobs", max(1, int(logprobs_value))
+                    )  # If top_logprobs = 0, still want to see chosen token logprob
+
+            # Handle prompt_logprobs
+            prompt_logprobs_value = output_options.get("prompt_logprobs")
+            if prompt_logprobs_value:
+                if hasattr(sampling_params, "prompt_logprobs"):
+                    setattr(
+                        sampling_params, "prompt_logprobs", int(prompt_logprobs_value)
+                    )
+
         max_tokens = request["stop_conditions"]["max_tokens"]
         if max_tokens:
             sampling_params.max_tokens = max_tokens
@@ -302,6 +392,15 @@

         out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}

+        # Extract logprobs from the output
+        log_probs, top_logprobs = self._extract_logprobs(
+            output, num_output_tokens_so_far
+        )
+        if log_probs:
+            out["log_probs"] = log_probs
+        if top_logprobs:
+            out["top_logprobs"] = top_logprobs
+
         if output.finish_reason:
             out["finish_reason"] = output.finish_reason
         if output.stop_reason:
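
To make the conversion concrete, here is a small worked sketch of _extract_logprobs using stand-in objects in place of TRTLLM's CompletionOutput and Logprob types; the HandlerBase class name and the SimpleNamespace stand-ins are assumptions for illustration, not part of the commit:

# Worked sketch with stand-in objects (SimpleNamespace mimics the shape of
# TRTLLM's Logprob entries; HandlerBase is an assumed class name).
from types import SimpleNamespace

output = SimpleNamespace(
    token_ids=[11, 42],  # token 11 was already emitted; 42 is the new token
    logprobs=[
        None,  # position 0: already processed, sliced away inside the helper
        {
            42: SimpleNamespace(logprob=-0.1, rank=1),
            7: SimpleNamespace(logprob=-2.3, rank=2),
        },
    ],
)

log_probs, top_logprobs = HandlerBase._extract_logprobs(output, 1)
# log_probs    -> [-0.1]  (chosen token 42 at the single new position)
# top_logprobs -> [[{"rank": 1, "token_id": 42, "token": None, "logprob": -0.1},
#                   {"rank": 2, "token_id": 7, "token": None, "logprob": -2.3}]]
# "token" is None because the stand-in Logprob has no decoded_token attribute.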

deploy/cloud/helm/platform/components/operator/templates/planner.yaml

Lines changed: 7 additions & 1 deletion
@@ -39,6 +39,9 @@ rules:
 - apiGroups: ["nvidia.com"]
   resources: ["dynamocomponentdeployments", "dynamographdeployments"]
   verbs: ["get", "list", "create", "update", "patch"]
+- apiGroups: ["nvidia.com"]
+  resources: ["dynamographdeploymentscalingadapters/scale"]
+  verbs: ["patch"]
 ---
 apiVersion: rbac.authorization.k8s.io/v1
 kind: RoleBinding
@@ -68,4 +71,7 @@ rules:
 - apiGroups: ["nvidia.com"]
   resources: ["dynamocomponentdeployments", "dynamographdeployments"]
   verbs: ["get", "list", "create", "update", "patch"]
-{{- end }}
+- apiGroups: ["nvidia.com"]
+  resources: ["dynamographdeploymentscalingadapters/scale"]
+  verbs: ["patch"]
+{{- end }}

tests/planner/unit/kube.py

Lines changed: 77 additions & 1 deletion
@@ -76,11 +76,87 @@ def test_get_graph_deployment_from_name(k8s_api, mock_custom_api):
     )


-def test_update_graph_replicas(k8s_api, mock_custom_api):
+def test_update_service_replicas_uses_dgdsa_scale(k8s_api, mock_custom_api):
+    """Test that update_service_replicas uses DGDSA Scale API when available"""
+    mock_custom_api.patch_namespaced_custom_object_scale.return_value = None
+
+    k8s_api.update_service_replicas("test-deployment", "Frontend", 3)
+
+    # Should use Scale subresource with lowercase adapter name
+    mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once_with(
+        group="nvidia.com",
+        version="v1alpha1",
+        namespace=k8s_api.current_namespace,
+        plural="dynamographdeploymentscalingadapters",
+        name="test-deployment-frontend",  # lowercase service name
+        body={"spec": {"replicas": 3}},
+    )
+    # Should NOT fall back to DGD patch
+    mock_custom_api.patch_namespaced_custom_object.assert_not_called()
+
+
+def test_update_service_replicas_fallback_to_dgd(k8s_api, mock_custom_api):
+    """Test that update_service_replicas falls back to DGD when DGDSA not found"""
+    # DGDSA doesn't exist (404)
+    mock_custom_api.patch_namespaced_custom_object_scale.side_effect = (
+        client.ApiException(status=404)
+    )
     mock_custom_api.patch_namespaced_custom_object.return_value = None

+    k8s_api.update_service_replicas("test-deployment", "test-component", 1)
+
+    # Should have tried DGDSA first
+    mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once()
+
+    # Should fall back to DGD patch
+    mock_custom_api.patch_namespaced_custom_object.assert_called_once_with(
+        group="nvidia.com",
+        version="v1alpha1",
+        namespace=k8s_api.current_namespace,
+        plural="dynamographdeployments",
+        name="test-deployment",
+        body={"spec": {"services": {"test-component": {"replicas": 1}}}},
+    )
+
+
+def test_update_service_replicas_propagates_other_errors(k8s_api, mock_custom_api):
+    """Test that update_service_replicas propagates non-404 errors"""
+    mock_custom_api.patch_namespaced_custom_object_scale.side_effect = (
+        client.ApiException(status=500, reason="Internal Server Error")
+    )
+
+    with pytest.raises(client.ApiException) as exc_info:
+        k8s_api.update_service_replicas("test-deployment", "test-component", 1)
+
+    assert exc_info.value.status == 500
+    # Should NOT fall back to DGD
+    mock_custom_api.patch_namespaced_custom_object.assert_not_called()
+
+
+def test_update_graph_replicas_calls_update_service_replicas(k8s_api, mock_custom_api):
+    """Test that deprecated update_graph_replicas calls update_service_replicas"""
+    mock_custom_api.patch_namespaced_custom_object_scale.return_value = None
+
+    # Use the deprecated method
     k8s_api.update_graph_replicas("test-deployment", "test-component", 1)

+    # Should delegate to update_service_replicas which uses Scale API
+    mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once_with(
+        group="nvidia.com",
+        version="v1alpha1",
+        namespace=k8s_api.current_namespace,
+        plural="dynamographdeploymentscalingadapters",
+        name="test-deployment-test-component",
+        body={"spec": {"replicas": 1}},
+    )
+
+
+def test_update_dgd_replicas_directly(k8s_api, mock_custom_api):
+    """Test the internal _update_dgd_replicas method"""
+    mock_custom_api.patch_namespaced_custom_object.return_value = None
+
+    k8s_api._update_dgd_replicas("test-deployment", "test-component", 1)
+
     mock_custom_api.patch_namespaced_custom_object.assert_called_once_with(
         group="nvidia.com",
         version="v1alpha1",

tests/serve/test_trtllm.py

Lines changed: 32 additions & 0 deletions
@@ -14,7 +14,10 @@
 )
 from tests.utils.engine_process import EngineConfig
 from tests.utils.payload_builder import (
+    TEXT_PROMPT,
+    chat_payload,
     chat_payload_default,
+    completion_payload,
     completion_payload_default,
     metric_payload_default,
     multimodal_payload_default,
@@ -91,6 +94,34 @@ class TRTLLMConfig(EngineConfig):
             metric_payload_default(port=8082, min_num_requests=6, backend="trtllm"),
         ],
     ),
+    "aggregated_logprobs": TRTLLMConfig(
+        name="aggregated_logprobs",
+        directory=trtllm_dir,
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm],
+        model="Qwen/Qwen3-0.6B",
+        models_port=8000,
+        request_payloads=[
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
+        ],
+    ),
+    "disaggregated_logprobs": TRTLLMConfig(
+        name="disaggregated_logprobs",
+        directory=trtllm_dir,
+        script_name="disagg.sh",
+        marks=[pytest.mark.gpu_2, pytest.mark.post_merge, pytest.mark.trtllm],
+        model="Qwen/Qwen3-0.6B",
+        models_port=8000,
+        request_payloads=[
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
+        ],
+    ),
     "aggregated_router": TRTLLMConfig(
         name="aggregated_router",
         directory=trtllm_dir,
@@ -159,6 +190,7 @@ class TRTLLMConfig(EngineConfig):
         },
         request_payloads=[
             completion_payload_default(),
+            completion_payload(prompt=TEXT_PROMPT, logprobs=3),
         ],
     ),
 }
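
For reference, the logprobs payloads above correspond to an OpenAI-style chat request along these lines; the URL, prompt text, and max_tokens are illustrative assumptions (the actual value of TEXT_PROMPT lives in tests.utils.payload_builder):

# Illustrative request matching chat_payload(content=TEXT_PROMPT, logprobs=True,
# top_logprobs=5); endpoint, prompt, and max_tokens are assumed values.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Tell me about GPUs"}],
        "logprobs": True,
        "top_logprobs": 5,
        "max_tokens": 16,
    },
)
# Each returned choice should then carry per-token logprobs, with up to
# five alternatives per position.
print(resp.json()["choices"][0].get("logprobs"))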
