Skip to content

Commit 29ec496

Browse files
authored
chore: add test for diverging prefixes (#4919)
Signed-off-by: Neal Vaidya <[email protected]>
1 parent bbeb280 commit 29ec496

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

tests/router/common.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,11 +1798,16 @@ def _test_router_decisions(
17981798
test_dp_rank: bool = False,
17991799
block_size: int = BLOCK_SIZE,
18001800
):
1801-
"""Validate KV cache prefix reuse and worker routing by sending progressive requests with overlapping prefixes.
1801+
"""Validate KV cache prefix reuse and worker routing by sending requests with diverging prefixes.
18021802
1803-
Assumes engine workers are already initialized. Sends 4 progressive requests where each extends
1804-
the previous tokens by `block_size`. The first request is forced to a specific worker (and optionally
1805-
dp_rank), and subsequent requests should naturally route to the same worker due to prefix reuse.
1803+
Assumes engine workers are already initialized.
1804+
The first request is forced to a specific worker (and optionally dp_rank),
1805+
and subsequent requests should naturally route to the same worker due to prefix reuse.
1806+
1807+
Test sequence:
1808+
1. Request 1: [A, B, C, D] → Forced to Worker 1, caches 4 blocks
1809+
2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1
1810+
3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached)
18061811
18071812
Args:
18081813
engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
@@ -1844,23 +1849,27 @@ async def test_sync():
18441849
else:
18451850
logger.info(f"Will force first request to worker_id={forced_worker_id}")
18461851

1847-
# Send 4 progressive requests with overlapping prefixes
1848-
cumulative_tokens = []
1852+
# Send 3 requests with some shared prefixes and some divergent prefixes
18491853
response_worker_ids: list[dict[str, Optional[int]]] = []
18501854

1851-
for i in range(4):
1852-
# Add `block_size` new random tokens
1853-
new_tokens = [random.randint(1, 10000) for _ in range(block_size)]
1854-
cumulative_tokens.extend(new_tokens)
1855+
num_blocks = 8
1856+
blocks = [
1857+
[random.randint(1, 10000) for _ in range(block_size)]
1858+
for _ in range(num_blocks)
1859+
]
18551860

1861+
requests = [
1862+
blocks[0] + blocks[1] + blocks[2] + blocks[3],
1863+
blocks[0] + blocks[1] + blocks[4] + blocks[5],
1864+
blocks[0] + blocks[1] + blocks[2] + blocks[3] + blocks[6] + blocks[7],
1865+
]
1866+
1867+
for i, request in enumerate(requests):
18561868
# Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
18571869
worker_id_override = forced_worker_id if i == 0 else None
18581870
dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None
18591871

1860-
log_msg = (
1861-
f"Sending request {i + 1}/4 with {len(cumulative_tokens)} tokens "
1862-
f"(added {len(new_tokens)} new tokens)"
1863-
)
1872+
log_msg = f"Sending request {i + 1}/4 with {len(request)} tokens "
18641873
if worker_id_override is not None:
18651874
if test_dp_rank:
18661875
log_msg += f" - FORCING worker_id={worker_id_override}, dp_rank={dp_rank_override}"
@@ -1871,7 +1880,7 @@ async def test_sync():
18711880
result = await send_request_via_python_kv_router(
18721881
kv_python_router=kv_push_router,
18731882
model_name=model_name,
1874-
token_ids=cumulative_tokens.copy(),
1883+
token_ids=request,
18751884
initial_wait=1.0,
18761885
max_retries=8,
18771886
stop_conditions={
@@ -1944,12 +1953,12 @@ async def test_sync():
19441953
f"but found {len(keys_with_events_dp)} with events: {keys_with_events_dp}"
19451954
)
19461955

1947-
# Verify: The routing key with events should have exactly 4 events (one per request)
1956+
# Verify: The routing key with events should have exactly 8 events (one per unique block)
19481957
active_key_dp = keys_with_events_dp[0]
19491958
num_events = len(events_by_key_dp[active_key_dp])
19501959

1951-
assert num_events == 4, (
1952-
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 4 events, "
1960+
assert num_events == 8, (
1961+
f"Expected (worker_id, dp_rank) {active_key_dp} to have exactly 8 events, "
19531962
f"but found {num_events} events"
19541963
)
19551964

@@ -1991,12 +2000,12 @@ async def test_sync():
19912000
f"but found {len(keys_with_events_single)} with events: {keys_with_events_single}"
19922001
)
19932002

1994-
# Verify: The routing key with events should have exactly 4 events (one per request)
2003+
# Verify: The routing key with events should have exactly 8 events (one per unique block)
19952004
active_worker_id = keys_with_events_single[0]
19962005
num_events = len(events_by_key_single[active_worker_id])
19972006

1998-
assert num_events == 4, (
1999-
f"Expected worker_id {active_worker_id} to have exactly 4 events, "
2007+
assert num_events == 8, (
2008+
f"Expected worker_id {active_worker_id} to have exactly 8 events, "
20002009
f"but found {num_events} events"
20012010
)
20022011

tests/router/test_router_e2e_with_sglang.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,8 @@ def test_sglang_kv_router_basic(
333333

334334
@pytest.mark.pre_merge
335335
@pytest.mark.gpu_1
336+
@pytest.mark.skip(reason="Broken by sglang changes")
337+
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
336338
def test_router_decisions_sglang_multiple_workers(
337339
request, runtime_services, predownload_models, set_ucx_tls_no_mm
338340
):

0 commit comments

Comments
 (0)