Skip to content

Commit 1e5b20b

Browse files
authored
chore: cleanups of passing around prefill and decode worker ids (#4829)
Signed-off-by: PeaBrane <[email protected]>
1 parent: 14321c8 · commit: 1e5b20b

File tree

6 files changed: +152 −87 lines changed

lib/llm/src/kv_router.rs

Lines changed: 21 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt};
2222
use serde::{Deserialize, Serialize};
2323
use serde_json::json;
2424

25+
use crate::protocols::openai::nvext::WorkerIdInfo;
26+
2527
pub mod approx;
2628
pub mod indexer;
2729
pub mod prefill_router;
@@ -646,13 +648,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
646648
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
647649
backend_input.dp_rank = Some(dp_rank);
648650

649-
// Get prefill worker ID if available (stored by PrefillRouter)
650-
// In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
651+
// Get prefill worker ID from prefill_result if available
652+
// In aggregated mode, prefill_result is None, so we use decode_worker_id for both
651653
let decode_worker_id = instance_id;
652-
let prefill_worker_id = context
653-
.get::<u64>("prefill_worker_id")
654-
.ok()
655-
.map(|arc| *arc)
654+
let prefill_worker_id = backend_input
655+
.prefill_result
656+
.as_ref()
657+
.and_then(|prefill_result| {
658+
prefill_result
659+
.disaggregated_params
660+
.get("worker_id")
661+
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
662+
.and_then(|info| info.prefill_worker_id)
663+
})
656664
.or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker
657665

658666
let updated_request = context.map(|_| backend_input);
@@ -699,12 +707,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
699707
continue;
700708
};
701709

702-
// prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
710+
// prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
703711
// decode_worker_id is always the current instance_id
704-
let worker_id_json = json!({
705-
"prefill_worker_id": prefill_worker_id,
706-
"decode_worker_id": decode_worker_id,
707-
});
712+
let worker_id_info = WorkerIdInfo {
713+
prefill_worker_id,
714+
decode_worker_id: Some(decode_worker_id),
715+
};
716+
let worker_id_json = serde_json::to_value(&worker_id_info)
717+
.expect("WorkerIdInfo serialization should not fail");
708718

709719
if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
710720
obj.insert("worker_id".to_string(), worker_id_json);

lib/llm/src/kv_router/prefill_router.rs

Lines changed: 8 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -176,11 +176,11 @@ impl PrefillRouter {
176176
Ok(())
177177
}
178178

179-
/// Call the prefill router and extract structured prefill result and worker ID
179+
/// Call the prefill router and extract structured prefill result
180180
async fn call_prefill(
181181
&self,
182182
request: SingleIn<PreprocessedRequest>,
183-
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
183+
) -> Result<PrefillResult, PrefillError> {
184184
// Get the prefill router, error if not activated
185185
let Some(prefill_router) = self.prefill_router.get() else {
186186
return Err(PrefillError::NotActivated);
@@ -239,21 +239,10 @@ impl PrefillRouter {
239239
));
240240
};
241241

242-
// Extract prefill worker ID from disaggregated_params
243-
let prefill_worker_id = disaggregated_params
244-
.get("worker_id")
245-
.and_then(|worker_id_json| {
246-
worker_id_json
247-
.get("prefill_worker_id")
248-
.and_then(|v| v.as_u64())
249-
});
250-
Ok((
251-
PrefillResult {
252-
disaggregated_params,
253-
prompt_tokens_details,
254-
},
255-
prefill_worker_id,
256-
))
242+
Ok(PrefillResult {
243+
disaggregated_params,
244+
prompt_tokens_details,
245+
})
257246
}
258247
}
259248

@@ -310,7 +299,7 @@ impl
310299

311300
// Handle prefill result
312301
match prefill_result {
313-
Ok((prefill_result, prefill_worker_id)) => {
302+
Ok(prefill_result) => {
314303
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
315304

316305
let mut decode_req = req;
@@ -326,14 +315,8 @@ impl
326315
..existing_override.unwrap_or_default()
327316
});
328317

329-
// Store prefill worker ID in context if available
330-
let mut decode_context = context;
331-
if let Some(worker_id) = prefill_worker_id {
332-
decode_context.insert("prefill_worker_id", worker_id);
333-
}
334-
335318
// Map the modified request through with preserved context
336-
let decode_request = decode_context.map(|_| decode_req);
319+
let decode_request = context.map(|_| decode_req);
337320
next.generate(decode_request).await
338321
}
339322
Err(PrefillError::NotActivated) => {

lib/llm/src/protocols/openai/chat_completions/delta.rs

Lines changed: 9 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,10 @@
44
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
55
use crate::{
66
local_model::runtime_config::ModelRuntimeConfig,
7-
protocols::common::{self},
7+
protocols::{
8+
common,
9+
openai::nvext::{NvExtResponse, WorkerIdInfo},
10+
},
811
types::TokenIdType,
912
};
1013

@@ -363,35 +366,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
363366
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
364367

365368
// Extract worker_id from disaggregated_params and inject into nvext if present
366-
if let Some(worker_id_json) = delta
369+
if let Some(worker_id_info) = delta
367370
.disaggregated_params
368371
.as_ref()
369372
.and_then(|params| params.get("worker_id"))
373+
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
370374
{
371-
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
372-
373-
let prefill_worker_id = worker_id_json
374-
.get("prefill_worker_id")
375-
.and_then(|v| v.as_u64());
376-
let decode_worker_id = worker_id_json
377-
.get("decode_worker_id")
378-
.and_then(|v| v.as_u64());
379-
380-
let worker_id_info = WorkerIdInfo {
381-
prefill_worker_id,
382-
decode_worker_id,
383-
};
384-
385375
let nvext_response = NvExtResponse {
386-
worker_id: Some(worker_id_info),
376+
worker_id: Some(worker_id_info.clone()),
387377
};
388378

389379
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
390380
stream_response.nvext = Some(nvext_json);
391381
tracing::debug!(
392382
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
393-
prefill_worker_id,
394-
decode_worker_id
383+
worker_id_info.prefill_worker_id,
384+
worker_id_info.decode_worker_id
395385
);
396386
}
397387
}

lib/llm/src/protocols/openai/completions/delta.rs

Lines changed: 12 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,13 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
5-
use crate::{protocols::common, types::TokenIdType};
5+
use crate::{
6+
protocols::{
7+
common,
8+
openai::nvext::{NvExtResponse, WorkerIdInfo},
9+
},
10+
types::TokenIdType,
11+
};
612

713
impl NvCreateCompletionRequest {
814
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
@@ -266,35 +272,22 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
266272
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
267273

268274
// Extract worker_id from disaggregated_params and inject into nvext if present
269-
if let Some(worker_id_json) = delta
275+
if let Some(worker_id_info) = delta
270276
.disaggregated_params
271277
.as_ref()
272278
.and_then(|params| params.get("worker_id"))
279+
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
273280
{
274-
use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};
275-
276-
let prefill_worker_id = worker_id_json
277-
.get("prefill_worker_id")
278-
.and_then(|v| v.as_u64());
279-
let decode_worker_id = worker_id_json
280-
.get("decode_worker_id")
281-
.and_then(|v| v.as_u64());
282-
283-
let worker_id_info = WorkerIdInfo {
284-
prefill_worker_id,
285-
decode_worker_id,
286-
};
287-
288281
let nvext_response = NvExtResponse {
289-
worker_id: Some(worker_id_info),
282+
worker_id: Some(worker_id_info.clone()),
290283
};
291284

292285
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
293286
response.inner.nvext = Some(nvext_json);
294287
tracing::debug!(
295288
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
296-
prefill_worker_id,
297-
decode_worker_id
289+
worker_id_info.prefill_worker_id,
290+
worker_id_info.decode_worker_id
298291
);
299292
}
300293
}

0 commit comments

Comments
 (0)