@@ -1798,11 +1798,16 @@ def _test_router_decisions(
17981798 test_dp_rank : bool = False ,
17991799 block_size : int = BLOCK_SIZE ,
18001800):
1801- """Validate KV cache prefix reuse and worker routing by sending progressive requests with overlapping prefixes.
1801+ """Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
18021802
1803- Assumes engine workers are already initialized. Sends 4 progressive requests where each extends
1804- the previous tokens by `block_size`. The first request is forced to a specific worker (and optionally
1805- dp_rank), and subsequent requests should naturally route to the same worker due to prefix reuse.
1803+ Assumes engine workers are already initialized.
1804+ The first request is forced to a specific worker (and optionally dp_rank),
1805+ and subsequent requests should naturally route to the same worker due to prefix reuse.
1806+
1807+ Test sequence:
1808+ 1. Request 1: [A, B, C, D] → Forces to Worker 1, caches 4 blocks
1809+ 2. Request 2: [A, B, E, F] → Shares [A, B] prefix, diverges from Request 1
1810+ 3. Request 3: [A, B, C, D, G, H] → Should route to Worker 1 (has [A, B, C, D] cached)
18061811
18071812 Args:
18081813 engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__())
@@ -1844,23 +1849,27 @@ async def test_sync():
18441849 else :
18451850 logger .info (f"Will force first request to worker_id={ forced_worker_id } " )
18461851
1847- # Send 4 progressive requests with overlapping prefixes
1848- cumulative_tokens = []
1852+ # Send 3 requests with some shared prefixes and some divergent prefixes
18491853 response_worker_ids : list [dict [str , Optional [int ]]] = []
18501854
1851- for i in range (4 ):
1852- # Add `block_size` new random tokens
1853- new_tokens = [random .randint (1 , 10000 ) for _ in range (block_size )]
1854- cumulative_tokens .extend (new_tokens )
1855+ num_blocks = 8
1856+ blocks = [
1857+ [random .randint (1 , 10000 ) for _ in range (block_size )]
1858+ for _ in range (num_blocks )
1859+ ]
18551860
1861+ requests = [
1862+ blocks [0 ] + blocks [1 ] + blocks [2 ] + blocks [3 ],
1863+ blocks [0 ] + blocks [1 ] + blocks [4 ] + blocks [5 ],
1864+ blocks [0 ] + blocks [1 ] + blocks [2 ] + blocks [3 ] + blocks [6 ] + blocks [7 ],
1865+ ]
1866+
1867+ for i , request in enumerate (requests ):
18561868 # Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally
18571869 worker_id_override = forced_worker_id if i == 0 else None
18581870 dp_rank_override = forced_dp_rank if i == 0 and test_dp_rank else None
18591871
1860- log_msg = (
1861- f"Sending request { i + 1 } /4 with { len (cumulative_tokens )} tokens "
1862- f"(added { len (new_tokens )} new tokens)"
1863- )
1872+ log_msg = f"Sending request { i + 1 } /4 with { len (request )} tokens "
18641873 if worker_id_override is not None :
18651874 if test_dp_rank :
18661875 log_msg += f" - FORCING worker_id={ worker_id_override } , dp_rank={ dp_rank_override } "
@@ -1871,7 +1880,7 @@ async def test_sync():
18711880 result = await send_request_via_python_kv_router (
18721881 kv_python_router = kv_push_router ,
18731882 model_name = model_name ,
1874- token_ids = cumulative_tokens . copy () ,
1883+ token_ids = request ,
18751884 initial_wait = 1.0 ,
18761885 max_retries = 8 ,
18771886 stop_conditions = {
@@ -1944,12 +1953,12 @@ async def test_sync():
19441953 f"but found { len (keys_with_events_dp )} with events: { keys_with_events_dp } "
19451954 )
19461955
1947- # Verify: The routing key with events should have exactly 4 events (one per request )
1956+ # Verify: The routing key with events should have exactly 8 events (one per unique block )
19481957 active_key_dp = keys_with_events_dp [0 ]
19491958 num_events = len (events_by_key_dp [active_key_dp ])
19501959
1951- assert num_events == 4 , (
1952- f"Expected (worker_id, dp_rank) { active_key_dp } to have exactly 4 events, "
1960+ assert num_events == 8 , (
1961+ f"Expected (worker_id, dp_rank) { active_key_dp } to have exactly 8 events, "
19531962 f"but found { num_events } events"
19541963 )
19551964
@@ -1991,12 +2000,12 @@ async def test_sync():
19912000 f"but found { len (keys_with_events_single )} with events: { keys_with_events_single } "
19922001 )
19932002
1994- # Verify: The routing key with events should have exactly 4 events (one per request )
2003+ # Verify: The routing key with events should have exactly 8 events (one per unique block )
19952004 active_worker_id = keys_with_events_single [0 ]
19962005 num_events = len (events_by_key_single [active_worker_id ])
19972006
1998- assert num_events == 4 , (
1999- f"Expected worker_id { active_worker_id } to have exactly 4 events, "
2007+ assert num_events == 8 , (
2008+ f"Expected worker_id { active_worker_id } to have exactly 8 events, "
20002009 f"but found { num_events } events"
20012010 )
20022011
0 commit comments