Skip to content

Commit 2c1af84

Browse files
committed
kvbm: cleanup rd1
Signed-off-by: Ryan Olson <email address redacted>
1 parent a1c3f4e commit 2c1af84

File tree

7 files changed

+57
-485
lines changed

7 files changed

+57
-485
lines changed

lib/bindings/kvbm/python/kvbm/v2/vllm/schedulers/leader.py

Lines changed: 28 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,13 @@
2020
from kvbm._core import v2 as _v2
2121
from kvbm.v2.vllm import KvbmVllmConfig
2222

23+
from ..sched_output import process_scheduler_output
24+
from .worker import NovaPeerMetadata
25+
2326
KvbmRuntime = _v2.KvbmRuntime
2427
ConnectorLeader = _v2.ConnectorLeader
2528
KvbmRequest = _v2.KvbmRequest
2629

27-
# TODO: Re-enable when v2 connector bindings are updated
28-
# These classes need to be updated for v2 API changes in kvbm crate
29-
# KvbmRequest = _v2.KvbmRequest
30-
# RustKvConnectorLeader = _v2.KvConnectorLeader
31-
# RustSchedulerOutput = _v2.RustSchedulerOutput
32-
33-
# Import the handshake metadata type from worker module
34-
from .worker import NovaPeerMetadata
35-
from ..sched_output import process_scheduler_output
3630

3731
if TYPE_CHECKING:
3832
from vllm.config import VllmConfig
@@ -41,8 +35,8 @@
4135
)
4236
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheConfig
4337
from vllm.v1.core.sched.output import SchedulerOutput
44-
from vllm.v1.request import Request
4538
from vllm.v1.outputs import KVConnectorOutput
39+
from vllm.v1.request import Request
4640

4741

4842
class SchedulerConnectorLeader:
@@ -59,7 +53,11 @@ class SchedulerConnectorLeader:
5953
"""
6054

6155
def __init__(
62-
self, vllm_config: VllmConfig, kvbm_config: KvbmVllmConfig, kv_cache_config: KVCacheConfig, **kwargs
56+
self,
57+
vllm_config: VllmConfig,
58+
kvbm_config: KvbmVllmConfig,
59+
kv_cache_config: KVCacheConfig,
60+
**kwargs,
6361
):
6462
"""Initialize the scheduler connector leader."""
6563
print("[KVBM DEBUG] SchedulerConnectorLeader.__init__ START", flush=True)
@@ -90,7 +88,9 @@ def get_num_new_matched_tokens(
9088
num_computed_tokens: int,
9189
) -> tuple[Optional[int], bool]:
9290
self._create_slot(request)
93-
return self.leader.get_num_new_matched_tokens(request.request_id, num_computed_tokens)
91+
return self.leader.get_num_new_matched_tokens(
92+
request.request_id, num_computed_tokens
93+
)
9494

9595
def update_state_after_alloc(
9696
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
@@ -101,7 +101,9 @@ def update_state_after_alloc(
101101
This should never be called with num_external_tokens > 0.
102102
"""
103103
block_ids = [int(block_id) for block_id in blocks.get_block_ids()[0]]
104-
self.leader.update_state_after_alloc(request.request_id, block_ids, num_external_tokens)
104+
self.leader.update_state_after_alloc(
105+
request.request_id, block_ids, num_external_tokens
106+
)
105107

106108
def build_connector_meta(self, scheduler_output: "SchedulerOutput") -> bytes:
107109
"""
@@ -136,12 +138,19 @@ def request_finished(
136138

137139
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
138140
# Convert None to empty sets for Rust binding compatibility
139-
finished_sending = connector_output.finished_sending if connector_output.finished_sending is not None else set()
140-
finished_recving = connector_output.finished_recving if connector_output.finished_recving is not None else set()
141+
finished_sending = (
142+
connector_output.finished_sending
143+
if connector_output.finished_sending is not None
144+
else set()
145+
)
146+
finished_recving = (
147+
connector_output.finished_recving
148+
if connector_output.finished_recving is not None
149+
else set()
150+
)
141151
self.leader.update_connector_output(finished_sending, finished_recving)
142152

143153
def get_finished_count(self) -> Optional[int]:
144-
"""No finished count tracking for Phase 1."""
145154
return None
146155

147156
def set_xfer_handshake_metadata(
@@ -193,7 +202,7 @@ def set_xfer_handshake_metadata(
193202
def _create_slot(self, request: "Request") -> None:
194203
if self.leader.has_slot(request.request_id):
195204
return
196-
205+
197206
if bool(getattr(request, "mm_features", None)) or bool(
198207
getattr(request, "mm_positions", None)
199208
):
@@ -207,7 +216,7 @@ def _create_slot(self, request: "Request") -> None:
207216
else:
208217
# Single-sequence case: already flat
209218
all_token_ids = [int(token) for token in request.all_token_ids]
210-
219+
211220
kv_request = KvbmRequest(
212221
request_id=request.request_id,
213222
tokens=all_token_ids,
@@ -219,5 +228,5 @@ def _create_slot(self, request: "Request") -> None:
219228
else None,
220229
max_tokens=request.max_tokens,
221230
)
222-
231+
223232
self.leader.create_slot(kv_request)

lib/bindings/kvbm/python/kvbm/v2/vllm/schedulers/worker.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
from __future__ import annotations
1515

16-
import json
1716
from dataclasses import dataclass
1817
from typing import TYPE_CHECKING, Optional
1918

@@ -22,14 +21,11 @@
2221
# Import KvbmRuntime and ConnectorWorker from Rust bindings
2322
from kvbm._core import v2 as _v2
2423
from kvbm.v2.vllm import KvbmVllmConfig
25-
2624
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
2725
KVConnectorHandshakeMetadata,
2826
)
2927
from vllm.model_executor.models.utils import extract_layer_index
3028

31-
from ..config import extract_vllm_config_for_kvbm
32-
3329
KvbmRuntime = _v2.KvbmRuntime
3430
ConnectorWorker = _v2.ConnectorWorker
3531

@@ -68,7 +64,11 @@ class SchedulerConnectorWorker:
6864
"""
6965

7066
def __init__(
71-
self, vllm_config: "VllmConfig", kvbm_config: KvbmVllmConfig, kv_cache_config: KVCacheConfig, **kwargs
67+
self,
68+
vllm_config: "VllmConfig",
69+
kvbm_config: KvbmVllmConfig,
70+
kv_cache_config: KVCacheConfig,
71+
**kwargs,
7272
):
7373
"""Initialize the scheduler connector worker."""
7474
self.vllm_config = vllm_config
@@ -134,7 +134,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
134134
# For NHD layout: [2 (K/V), num_blocks, block_size, num_heads, head_size]
135135
# For HND layout: [2 (K/V), num_blocks, num_heads, block_size, head_size]
136136
num_device_blocks = max(shape[0], shape[1])
137-
page_size = self.kvbm_config.block_size()
137+
page_size = self.vllm_config.cache_config.block_size
138138
dtype_width_bytes = self.kvbm_config.cache_dtype_bytes()
139139

140140
config_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks
@@ -222,7 +222,9 @@ def get_finished(
222222
Returns:
223223
(None, None): No finished sends/receives
224224
"""
225-
print(f"SchedulerConnectorWorker.get_finished called with {len(finished_req_ids)} finished requests")
225+
print(
226+
f"SchedulerConnectorWorker.get_finished called with {len(finished_req_ids)} finished requests"
227+
)
226228
return self.worker.get_finished()
227229

228230
def get_block_ids_with_load_errors(self) -> set[int]:

0 commit comments

Comments
 (0)