diff --git a/examples/deployments/router_standalone_trtllm/README.md b/examples/deployments/router_standalone_trtllm/README.md
new file mode 100644
index 0000000000..eb9b24bcb4
--- /dev/null
+++ b/examples/deployments/router_standalone_trtllm/README.md
@@ -0,0 +1,248 @@
+
+
+# Router Standalone - TensorRT-LLM
+
+A standalone implementation of KvRouter that demonstrates usage with TensorRT-LLM workers, without depending on the dynamo runtime, etcd control plane, or NATS event plane.
+
+## Overview
+
+This example shows how to use KvRouter with TensorRT-LLM workers to intelligently route requests across multiple GPUs based on KV cache overlap and load metrics. The router maintains a view of each worker's cached blocks and routes new requests to the worker with the best combination of cache overlap and available capacity.
+
+Key features:
+- **KV cache-aware routing**: Routes requests to workers with matching cached blocks
+- **Multimodal support**: Handles vision-language models (e.g., Qwen2-VL) with image inputs
+- **MM hash routing**: Identical images produce identical hashes for cache reuse
+
+## How It Works
+
+### Core Architecture
+
+The router uses a **RadixTree** data structure (written in Rust) to efficiently track which blocks each worker has cached. When a new request arrives, the router:
+
+1. Tokenizes the request and computes block hashes (including MM hashes for images)
+2. Uses `find_matches` to calculate overlap scores between the request and each worker's cached blocks
+3. Combines these scores with current load metrics to select the optimal worker
+4. Routes the request to the chosen worker for processing
+
+### Multimodal Routing
+
+For vision-language models:
+1. Images are processed using `default_multimodal_input_loader` from TensorRT-LLM
+2. Image placeholders are expanded to visual tokens using a HuggingFace `AutoProcessor`
+3. `apply_mm_hashes` computes a content hash for each image
+4. The MM hash is included in block hash computation, so identical images produce cache hits
+
+### Event-Driven Updates
+
+The router receives two types of events from TensorRT-LLM engines:
+
+1. **KV Events**: Emitted automatically when blocks are stored in or removed from the cache (includes `mm_keys` for multimodal)
+2. **Load Metrics**: GPU cache usage and waiting request count
+
+## Components
+
+### `worker.py`
+- **TrtllmWorkers**: Manages multiple TensorRT-LLM worker processes
+- Each worker runs on a separate GPU with KV cache event emission enabled
+- Publishes metrics and KV events over ZMQ
+- Extracts `mm_hash` from TRTLLM's `mm_keys` field for multimodal routing
+
+### `router.py`
+- **KvRouter**: Core routing logic using RadixTree
+- Subscribes to KV cache events and load metrics from workers
+- Implements `get_best_worker()` to select the optimal routing destination
+
+### `api.py`
+- **ServiceAPI**: FastAPI server providing an OpenAI-compatible chat completions endpoint
+- Handles multimodal inputs (images) via `default_multimodal_input_loader`
+- Computes block hashes, including MM hashes, for routing decisions
+- Streams responses in OpenAI format
+
+### `test_router.py`
+- Comprehensive test suite for router functionality
+- Includes local hash computation tests and server-side multimodal tests
+- Run with `--mm-only` for multimodal-specific tests
+
+## Requirements
+
+- **TensorRT-LLM >= 1.2.0rc6**: TensorRT-LLM 1.2.0rc6 or later includes multimodal information (`mm_keys`) in KV cache events, which MM hash-based routing requires.
See [PR #9604](https://github.com/NVIDIA/TensorRT-LLM/pull/9604) for details. +- TensorRT-LLM with pytorch backend +- Multiple GPUs (one per worker) +- Python 3.10+ +- Required packages: fastapi, uvicorn, httpx, zmq, tensorrt_llm, transformers + +## Usage + +### 1. Start the API Server + +```bash +python api.py \ + --model Qwen/Qwen2-VL-2B-Instruct \ + --num-workers 2 \ + --block-size 32 \ + --base-kv-events-port 5557 \ + --base-metrics-port 5657 \ + --router-port 7000 \ + --http-port 8000 +``` + +This will: +- Initialize TensorRT-LLM engines on each GPU +- Start ZMQ publishers for metrics and KV events +- Start the router service +- Start the OpenAI-compatible API server + +### 2. Test with curl + +**Text-only request:** +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2-VL-2B-Instruct", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 100, + "stream": false + }' | jq +``` + +**Multimodal request (with images):** +```bash +curl -s -X POST http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2-VL-2B-Instruct", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe both images in detail."}, + {"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg"}}, + {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/test2017/000000000001.jpg"}} + ] + }], + "max_tokens": 500, + "stream": false + }' | jq +``` + +### 3. Run Tests + +```bash +# Run all tests +python test_router.py + +# Run multimodal tests only +python test_router.py --mm-only + +# Verbose output +python test_router.py -v +``` + +### 4. Check endpoint health + +```bash +./ping.sh +``` + +## Configuration + +### Command-line Arguments + +- `--model`: HuggingFace model name (default: Qwen/Qwen2-VL-2B-Instruct) +- `--num-workers`: Number of GPU workers (default: 2) +- `--block-size`: KV cache block size (default: 32, TensorRT-LLM's default) +- `--base-kv-events-port`: Base port for KV events ZMQ (default: 5557) +- `--base-metrics-port`: Base port for metrics ZMQ (default: 5657) +- `--router-port`: Router HTTP service port (default: 7000) +- `--http-port`: API server port (default: 8000) + +### Environment Variables + +- `DYNAMO_DEBUG=1`: Enable debug file dumps to `/tmp/debug_*.txt` +- `LOGLEVEL=DEBUG`: Set logging level (DEBUG, INFO, WARNING, ERROR) +- `TRANSFORMERS_ATTN_IMPLEMENTATION=eager`: Disable FlashAttention (set automatically) + +### Port Assignment + +Workers use sequential ports: +- Worker 0: KV events on 5557, metrics on 5657 +- Worker 1: KV events on 5558, metrics on 5658 +- Worker N: KV events on 5557+N, metrics on 5657+N + +## Architecture Diagram + +``` +┌─────────────┐ +│ Client │ +└──────┬──────┘ + │ HTTP + ▼ +┌─────────────────┐ +│ API Server │ +│ (api.py) │ +└────────┬────────┘ + │ HTTP + ▼ +┌─────────────────┐ +│ Router │──┐ +│ (router.py) │ │ ZMQ (KV Events) +└────────┬────────┘ │ + │ │ + │ Select │ + │ Worker │ + ▼ │ +┌─────────────────┐ │ +│ TrtllmWorkers │ │ +│ (worker.py) │◄-┘ +└─────────────────┘ + │ │ + ▼ ▼ + GPU 0 GPU 1 +``` + +## Multimodal KV Cache Routing + +When processing multimodal requests: + +1. 
**API Layer** (`api.py`): + - Parses OpenAI-format messages with `image_url` content + - Uses `default_multimodal_input_loader` to process images + - Expands image placeholders to visual tokens via `AutoProcessor` + - Computes `mm_hash` using `apply_mm_hashes` + - Includes `mm_hash` in block hash computation for routing + +2. **Worker Layer** (`worker.py`): + - Receives multimodal input and passes to TRTLLM + - Extracts `mm_hash` from TRTLLM's `mm_keys` in KV events + - Publishes KV events with `mm_extra_info` to router + +3. **Router Layer** (`router.py`): + - RadixTree matches blocks including MM hash + - Same image content = same hash = cache hit on same worker + +## Notes + +- This is a standalone implementation for pedagogical purposes +- Production dynamo uses NATS for events and etcd for service discovery +- Each worker needs its own GPU +- TensorRT-LLM models may take time to compile on first run + +## See Also + +- [vLLM Router Standalone](../router_standalone/) - Original vLLM version +- [TensorRT-LLM KV Event Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/examples/llm_inference_kv_events.html) diff --git a/examples/deployments/router_standalone_trtllm/__init__.py b/examples/deployments/router_standalone_trtllm/__init__.py new file mode 100644 index 0000000000..1a8431c3e3 --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/deployments/router_standalone_trtllm/api.py b/examples/deployments/router_standalone_trtllm/api.py new file mode 100644 index 0000000000..becefe195f --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/api.py @@ -0,0 +1,765 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +# Fix protobuf version conflict with etcd3 +os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") + +import argparse +import asyncio +import json +import logging +import time +import uuid +from dataclasses import dataclass +from typing import Optional + +import httpx +import uvicorn +from fastapi import FastAPI +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel +from router import RouterAPI, RouterRequest, RouterResponse +from tensorrt_llm.inputs.multimodal import apply_mm_hashes +from tensorrt_llm.inputs.utils import default_multimodal_input_loader, load_image +from tensorrt_llm.llmapi.tokenizer import tokenizer_factory +from transformers import AutoProcessor +from worker import TrtllmWorkers + +from dynamo._core import compute_block_hash_for_seq_py + +logger = logging.getLogger(__name__) + +# Debug flag: set DYNAMO_DEBUG=1 to enable debug file dumps +DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1" +DEBUG_API_FILE = "/tmp/debug_api_hashes.txt" + +# Qwen2-VL specific token IDs +QWEN2_VL_IMAGE_TOKEN_ID = 151655 +QWEN2_VL_REPLACEMENT_ID = 151937 + + +def dump_api_debug( + tokens: list[int], + block_size: int, + local_hashes: list[int], + mm_hashes: list[int] | None, + block_mm_infos: list | None, + image_urls: list[str] | None, +): + """Dump API-side hash computation to file for debugging.""" + if not DEBUG_ENABLED: + return + import datetime + + with open(DEBUG_API_FILE, "a") as f: + f.write(f"\n{'='*60}\n") + f.write(f"Timestamp: {datetime.datetime.now()}\n") + f.write(f"Image URLs: {image_urls}\n") + f.write(f"mm_hashes: {mm_hashes}\n") + f.write(f"block_size: {block_size}\n") + f.write(f"num_tokens: {len(tokens)}\n") + f.write(f"tokens (first 50): {tokens[:50]}\n") + f.write(f"tokens (last 50): {tokens[-50:]}\n") + f.write(f"block_mm_infos: {block_mm_infos}\n") + f.write(f"local_hashes ({len(local_hashes)}): {local_hashes}\n") + f.write(f"{'='*60}\n") + + +def make_error(message: str, error_type: str, code: int) -> dict: + """Create a standardized error response dict.""" + return {"message": message, "type": error_type, "code": code} + + +# Pydantic models for OpenAI-compatible API +class ImageUrl(BaseModel): + url: str + + +class ContentPart(BaseModel): + type: str # "text" | "image_url" + text: Optional[str] = None + image_url: Optional[ImageUrl] = None + + +class Message(BaseModel): + role: str + content: str | list[ContentPart] + + +class ChatCompletionRequest(BaseModel): + model: str + messages: list[Message] + max_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + stream: bool = True + + +class ErrorResponse(BaseModel): + error: dict + + +@dataclass(frozen=True) +class ServingParams: + """Configuration parameters for the serving API.""" + + model: str + model_type: str # e.g., "qwen2_vl", "llava" + block_size: int + num_workers: int + base_kv_events_port: int + base_metrics_port: int + router_port: int + http_port: int + + +@dataclass +class ParsedRequest: + """Parsed and preprocessed request data.""" + + messages_dict: list[dict] + image_urls: list[str] + max_tokens: int + temperature: float + top_p: float + model: str + + +@dataclass +class ProcessedInput: + """Processed input ready for routing and generation.""" + + tokens: list[int] + mm_input: dict | None # For multimodal requests + mm_hashes: list[int] | None # List of mm_hash for each image + 
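# Offsets are positions in the expanded token sequence (after image placeholders are replaced by visual tokens), so they can be compared directly against block boundaries during routing. +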
image_offsets_list: list[list[int]] | None # List of [start, end] for each image + + +class ServiceAPI: + """Main API service handling chat completion requests with KV cache routing.""" + + def __init__(self, init_params: ServingParams): + self.init_params = init_params + self.app = FastAPI(title="TensorRT-LLM Router API", version="0.0.1") + + self.workers: Optional[TrtllmWorkers] = None + self.tokenizer = None + self.processor = None + self.http_client: Optional[httpx.AsyncClient] = None + + self._setup_routes() + + # ------------------------------------------------------------------------- + # Request Parsing Helpers + # ------------------------------------------------------------------------- + + def _parse_request( + self, request: ChatCompletionRequest + ) -> ParsedRequest | ErrorResponse: + """Parse and validate the incoming request.""" + max_tokens = request.max_completion_tokens or request.max_tokens + if max_tokens is None: + return ErrorResponse( + error=make_error( + "Either max_tokens or max_completion_tokens must be specified", + "invalid_request_error", + 400, + ) + ) + + messages_dict, image_urls = self._extract_messages(request.messages) + + return ParsedRequest( + messages_dict=messages_dict, + image_urls=image_urls, + max_tokens=max_tokens, + temperature=request.temperature, + top_p=request.top_p, + model=request.model, + ) + + def _extract_messages( + self, messages: list[Message] + ) -> tuple[list[dict], list[str]]: + """Extract text messages and image URLs from request messages.""" + messages_dict = [] + image_urls = [] + + for msg in messages: + if isinstance(msg.content, str): + messages_dict.append({"role": msg.role, "content": msg.content}) + else: + text_parts = [] + for part in msg.content: + if part.type == "text" and part.text: + text_parts.append(part.text) + elif part.type == "image_url" and part.image_url: + image_urls.append(part.image_url.url) + messages_dict.append( + {"role": msg.role, "content": " ".join(text_parts)} + ) + + return messages_dict, image_urls + + def _build_prompt(self, messages_dict: list[dict]) -> str: + """Build prompt text from messages using chat template.""" + try: + return self.tokenizer.apply_chat_template( + messages_dict, tokenize=False, add_generation_prompt=True + ) + except Exception as e: + logger.warning(f"Chat template failed: {e}, using simple format") + return self._format_messages_simple(messages_dict) + + def _format_messages_simple(self, messages: list[dict]) -> str: + """Simple fallback formatting when chat template is unavailable.""" + parts = [] + role_map = {"system": "System", "user": "User", "assistant": "Assistant"} + for msg in messages: + prefix = role_map.get(msg["role"], msg["role"].capitalize()) + parts.append(f"{prefix}: {msg['content']}\n") + parts.append("Assistant: ") + return "\n".join(parts) + + # ------------------------------------------------------------------------- + # Multimodal Processing Helpers + # ------------------------------------------------------------------------- + + def _process_multimodal(self, prompt: str, image_urls: list[str]) -> ProcessedInput: + """Process multimodal request: load images, compute tokens and mm_hashes.""" + try: + # Use "multiple_image" modality when there are multiple images + modality = "multiple_image" if len(image_urls) > 1 else "image" + inputs = default_multimodal_input_loader( + tokenizer=self.tokenizer, + model_dir=self.init_params.model, + model_type=self.init_params.model_type, + modality=modality, + prompts=[prompt], + media=[image_urls], + 
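+ # One prompt is passed here, so media is a single list carrying all of that prompt's image URLs +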
image_data_format="pt", + device="cuda", + ) + mm_input = inputs[0] + processed_prompt = mm_input.get("prompt", prompt) + multi_modal_data = mm_input.get("multi_modal_data") + + tokens, image_offsets_list = self._get_mm_tokens( + processed_prompt, image_urls + ) + mm_hashes = self._compute_mm_hashes(multi_modal_data) + + return ProcessedInput( + tokens=tokens, + mm_input=mm_input, + mm_hashes=mm_hashes, + image_offsets_list=image_offsets_list, + ) + except Exception as e: + logger.warning(f"MM processing failed: {e}, falling back to text-only") + return ProcessedInput( + tokens=self.tokenizer.encode(prompt), + mm_input=None, + mm_hashes=None, + image_offsets_list=None, + ) + + def _get_mm_tokens( + self, prompt: str, image_urls: list[str] + ) -> tuple[list[int], list[list[int]] | None]: + """Get tokens with visual expansion and find image token positions.""" + if self.processor is None: + return self.tokenizer.encode(prompt), None + + pil_images = [load_image(url, format="pil") for url in image_urls] + processor_output = self.processor( + text=[prompt], images=pil_images, return_tensors="pt", padding=True + ) + tokens = processor_output["input_ids"][0].tolist() + + image_token_id = getattr( + self.processor, "image_token_id", QWEN2_VL_IMAGE_TOKEN_ID + ) + return self._replace_image_tokens( + tokens, image_token_id, QWEN2_VL_REPLACEMENT_ID + ) + + def _replace_image_tokens( + self, tokens: list[int], image_token_id: int, replacement_id: int + ) -> tuple[list[int], list[list[int]] | None]: + """Replace image tokens and return their positions as list of [start, end] per image. + + Finds contiguous regions of image tokens. Each contiguous region is assumed + to be one image. + """ + image_offsets_list: list[list[int]] = [] + current_start: int | None = None + + for i, t in enumerate(tokens): + if t == image_token_id: + if current_start is None: + current_start = i + tokens[i] = replacement_id + else: + # End of a contiguous image token region + if current_start is not None: + image_offsets_list.append([current_start, i]) + current_start = None + + # Handle case where image tokens go to the end + if current_start is not None: + image_offsets_list.append([current_start, len(tokens)]) + + if image_offsets_list: + logger.debug(f"Image token regions: {image_offsets_list}") + return tokens, image_offsets_list + return tokens, None + + def _compute_mm_hashes(self, multi_modal_data: dict | None) -> list[int] | None: + """Compute mm_hash for each image in multimodal data. + + Returns: + List of mm_hash (one per image), or None if no images. + """ + if not multi_modal_data: + return None + + mm_hashes_dict = apply_mm_hashes(multi_modal_data) + if "image" in mm_hashes_dict and mm_hashes_dict["image"]: + # Convert each 256-bit hex digest to 64-bit int + mm_hashes = [ + int(hex_digest[:16], 16) for hex_digest in mm_hashes_dict["image"] + ] + logger.debug(f"Computed mm_hashes for {len(mm_hashes)} images: {mm_hashes}") + return mm_hashes + return None + + # ------------------------------------------------------------------------- + # Routing Helpers + # ------------------------------------------------------------------------- + + def _build_block_mm_infos( + self, + num_tokens: int, + mm_hashes: list[int] | None, + image_offsets_list: list[list[int]] | None, + ) -> list[dict | None] | None: + """Build block_mm_infos for routing hash computation. + + For each block, includes mm_objects for all images that overlap with that block. 
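For example, with block_size=32, an image spanning token offsets [40, 100) contributes its mm_object to blocks 1, 2, and 3, while block 0 gets None.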
+ + Args: + num_tokens: Total number of tokens + mm_hashes: List of mm_hash, one per image + image_offsets_list: List of [start, end] offsets, one per image + + Returns: + List of mm_info dicts (one per block), with None for blocks without images. + """ + if mm_hashes is None or image_offsets_list is None: + return None + + if len(mm_hashes) != len(image_offsets_list): + logger.warning( + f"mm_hashes ({len(mm_hashes)}) and image_offsets_list " + f"({len(image_offsets_list)}) length mismatch" + ) + return None + + block_size = self.init_params.block_size + num_blocks = (num_tokens + block_size - 1) // block_size + + result: list[dict | None] = [] + for block_idx in range(num_blocks): + block_start = block_idx * block_size + block_end = block_start + block_size + + # Find all images that overlap with this block + mm_objects = [] + for mm_hash, offsets in zip(mm_hashes, image_offsets_list): + img_start, img_end = offsets + if block_end > img_start and block_start < img_end: + mm_objects.append({"mm_hash": mm_hash, "offsets": [offsets]}) + + if mm_objects: + result.append({"mm_objects": mm_objects}) + else: + result.append(None) + + return result + + async def _route_request( + self, local_hashes: list[int], num_tokens: int + ) -> int | ErrorResponse: + """Query router for best worker ID.""" + try: + router_request = RouterRequest( + local_hashes=local_hashes, num_tokens=num_tokens + ) + response = await self.http_client.post( + f"http://localhost:{self.init_params.router_port}/find_best_worker", + json=router_request.model_dump(), + timeout=1, + ) + response.raise_for_status() + return RouterResponse.model_validate(response.json()).worker_id + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.error(f"Router request failed: {e}") + return ErrorResponse( + error=make_error( + "Router service unavailable", "service_unavailable", 503 + ) + ) + + # ------------------------------------------------------------------------- + # Response Streaming + # ------------------------------------------------------------------------- + + async def _stream_response( + self, request: ChatCompletionRequest, result_generator, request_id: str + ): + """Generate SSE formatted streaming responses.""" + created = int(time.time()) + first_chunk = True + try: + async for output in result_generator: + # Handle both dict (from worker) and object responses + if isinstance(output, dict): + text = output.get("text_diff") or output.get("text", "") + else: + text = getattr(output, "text_diff", None) or getattr( + output, "text", "" + ) + + if not text and not first_chunk: + continue + + delta = ( + {"role": "assistant", "content": text} + if first_chunk + else {"content": text} + ) + yield self._format_chunk( + request_id, created, request.model, delta, None + ) + first_chunk = False + + # Final chunk + yield self._format_chunk(request_id, created, request.model, {}, "stop") + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"Streaming error: {e}") + yield f"data: {json.dumps({'error': make_error(str(e), 'internal_error', 500)})}\n\n" + + def _format_chunk( + self, + request_id: str, + created: int, + model: str, + delta: dict, + finish_reason: str | None, + ) -> str: + """Format a single SSE chunk.""" + chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], + } + return f"data: {json.dumps(chunk)}\n\n" + + async def _generate_full_response( + self, request: 
ChatCompletionRequest, result_generator, request_id: str + ) -> dict: + """Collect all outputs and generate a complete (non-streaming) response.""" + created = int(time.time()) + full_text = "" + + try: + async for output in result_generator: + if isinstance(output, dict): + text = output.get("text_diff") or output.get("text", "") + else: + text = getattr(output, "text_diff", None) or getattr( + output, "text", "" + ) + full_text += text + + return { + "id": request_id, + "object": "chat.completion", + "created": created, + "model": request.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": full_text}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, # Not tracked in this implementation + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + except Exception as e: + logger.error(f"Generation error: {e}") + return {"error": make_error(str(e), "internal_error", 500)} + + # ------------------------------------------------------------------------- + # Main Request Handler + # ------------------------------------------------------------------------- + + def _setup_routes(self): + @self.app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest): + # Check service readiness + if ( + self.workers is None + or self.tokenizer is None + or self.http_client is None + ): + return ErrorResponse( + error=make_error("Service not ready", "service_unavailable", 503) + ) + + try: + # Parse request + parsed = self._parse_request(request) + if isinstance(parsed, ErrorResponse): + return parsed + + # Process input (multimodal or text-only) + if parsed.image_urls: + # For multimodal: pass raw text, let default_multimodal_input_loader apply chat template + raw_text = " ".join( + msg["content"] + for msg in parsed.messages_dict + if msg.get("content") + ) + processed = self._process_multimodal(raw_text, parsed.image_urls) + else: + # For text-only: apply chat template ourselves + prompt = self._build_prompt(parsed.messages_dict) + processed = ProcessedInput( + tokens=self.tokenizer.encode(prompt), + mm_input=None, + mm_hashes=None, + image_offsets_list=None, + ) + + # Validate tokens + if not processed.tokens: + return ErrorResponse( + error=make_error( + "Input prompt is empty", "invalid_request_error", 400 + ) + ) + + # Compute block hashes for routing + block_mm_infos = self._build_block_mm_infos( + len(processed.tokens), + processed.mm_hashes, + processed.image_offsets_list, + ) + logger.debug(f"block_mm_infos: {block_mm_infos}") + local_hashes = compute_block_hash_for_seq_py( + processed.tokens, self.init_params.block_size, block_mm_infos + ) + + # Debug dump + dump_api_debug( + tokens=processed.tokens, + block_size=self.init_params.block_size, + local_hashes=local_hashes, + mm_hashes=processed.mm_hashes, + block_mm_infos=block_mm_infos, + image_urls=parsed.image_urls, + ) + + # Route to best worker + worker_id = await self._route_request( + local_hashes, len(processed.tokens) + ) + if isinstance(worker_id, ErrorResponse): + return worker_id + + # Generate response + request_id = f"chatcmpl-{uuid.uuid4()}" + sampling_params = { + "max_tokens": parsed.max_tokens, + "temperature": parsed.temperature, + "top_p": parsed.top_p, + } + prompt_input = processed.mm_input or processed.tokens + logger.debug(f"Sending to worker {worker_id}") + result_generator = self.workers.direct( + prompt_input, worker_id, sampling_params + ) + + if request.stream: + return StreamingResponse( + self._stream_response(request, 
result_generator, request_id), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + else: + # Non-streaming: collect all outputs and return complete response + response_data = await self._generate_full_response( + request, result_generator, request_id + ) + return JSONResponse(content=response_data) + + except Exception as e: + logger.error(f"Request processing error: {e}") + return ErrorResponse(error=make_error(str(e), "internal_error", 500)) + + # ------------------------------------------------------------------------- + # Lifecycle Management + # ------------------------------------------------------------------------- + + async def initialize_services(self): + """Initialize workers, HTTP client, and tokenizer.""" + logger.info( + f"Initializing services: model={self.init_params.model}, " + f"workers={self.init_params.num_workers}, block_size={self.init_params.block_size}" + ) + + self.workers = TrtllmWorkers( + model=self.init_params.model, + block_size=self.init_params.block_size, + base_kv_events_port=self.init_params.base_kv_events_port, + base_metrics_port=self.init_params.base_metrics_port, + num_workers=self.init_params.num_workers, + ) + await self.workers.start_all() + + self.http_client = httpx.AsyncClient() + self.tokenizer = tokenizer_factory(self.init_params.model) + + try: + self.processor = AutoProcessor.from_pretrained( + self.init_params.model, trust_remote_code=True + ) + except Exception as e: + logger.warning(f"Failed to initialize HF processor: {e}") + self.processor = None + + await asyncio.sleep(2) + logger.info("All services initialized") + + async def start(self): + """Start the API server.""" + await self.initialize_services() + + logger.info(f"Starting API server on port {self.init_params.http_port}") + config = uvicorn.Config( + self.app, host="0.0.0.0", port=self.init_params.http_port, log_level="info" + ) + server = uvicorn.Server(config) + await server.serve() + + async def shutdown(self): + """Proper shutdown handler.""" + logger.info("Shutting down API...") + + if self.http_client: + await self.http_client.aclose() + + if self.workers: + self.workers.shutdown_all() + + logger.info("API shutdown completed") + + +def main(): + parser = argparse.ArgumentParser(description="TensorRT-LLM Router API Server") + + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2-VL-2B-Instruct", + help="Model name to use (VLM for multimodal support)", + ) + parser.add_argument( + "--model-type", + type=str, + default="qwen2_vl", + help="Model type for TRTLLM (e.g., qwen2_vl, llava, phi3_v)", + ) + parser.add_argument( + "--block-size", + type=int, + default=32, + help="Block size for caching (TensorRT-LLM uses 32)", + ) + parser.add_argument( + "--num-workers", type=int, default=2, help="Number of worker processes" + ) + parser.add_argument( + "--base-kv-events-port", type=int, default=5557, help="Base port for KV events" + ) + parser.add_argument( + "--base-metrics-port", type=int, default=5657, help="Base port for metrics" + ) + parser.add_argument( + "--router-port", type=int, default=7000, help="Port for router service" + ) + parser.add_argument( + "--http-port", type=int, default=8000, help="Port to serve the API on" + ) + + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + + init_params = ServingParams( + model=args.model, + model_type=args.model_type, + block_size=args.block_size, + num_workers=args.num_workers, + base_kv_events_port=args.base_kv_events_port, + 
base_metrics_port=args.base_metrics_port, + router_port=args.router_port, + http_port=args.http_port, + ) + + api = ServiceAPI(init_params=init_params) + router_api = RouterAPI( + block_size=args.block_size, + num_workers=args.num_workers, + base_kv_events_port=args.base_kv_events_port, + base_metrics_port=args.base_metrics_port, + port=args.router_port, + ) + + async def run_with_shutdown(): + try: + router_task = asyncio.create_task(router_api.start()) + await asyncio.sleep(0.5) + api_task = asyncio.create_task(api.start()) + await asyncio.gather(router_task, api_task) + except KeyboardInterrupt: + logger.info("Shutting down services...") + except Exception as e: + logger.exception(f"Unhandled exception: {e}") + finally: + await api.shutdown() + + try: + asyncio.run(run_with_shutdown()) + except KeyboardInterrupt: + logger.info("Force shutdown via KeyboardInterrupt.") + + +if __name__ == "__main__": + main() diff --git a/examples/deployments/router_standalone_trtllm/ping.sh b/examples/deployments/router_standalone_trtllm/ping.sh new file mode 100755 index 0000000000..8f36d3c9e9 --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/ping.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Simple health check - sends a basic chat request +# Model name should match what you started api.py with + +curl -s -X POST http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2-VL-2B-Instruct", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": false, + "max_tokens": 50 + }' | jq \ No newline at end of file diff --git a/examples/deployments/router_standalone_trtllm/router.py b/examples/deployments/router_standalone_trtllm/router.py new file mode 100644 index 0000000000..94c0f6d5c7 --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/router.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import argparse +import asyncio +import json +import logging +import os +from contextlib import asynccontextmanager + +import numpy as np +import uvicorn +import zmq +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, ValidationError + +from dynamo._core import RadixTree, ZmqKvEventListener + +logger = logging.getLogger(__name__) + +DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1" + + +def dump_kv_event(worker_id: int, event: dict): + """Dump KV event to file for debugging (only when DYNAMO_DEBUG=1).""" + if not DEBUG_ENABLED: + return + import datetime + + with open("/tmp/debug_kv_events.txt", "a") as f: + f.write(f"\n{'='*60}\n") + f.write(f"Timestamp: {datetime.datetime.now()}\n") + f.write(f"Worker ID: {worker_id}\n") + f.write(f"Event: {json.dumps(event, indent=2)}\n") + + +# ----------------------------------------------------------------------------- +# Request/Response Models +# ----------------------------------------------------------------------------- + + +class RouterRequest(BaseModel): + local_hashes: list[int] + num_tokens: int + + +class RouterResponse(BaseModel): + worker_id: int + overlap: float = 0.0 + matched_blocks: int = 0 + + +class InjectEventRequest(BaseModel): + """For testing: inject a KV event directly into RadixTree.""" + + worker_id: int + tokens_hash: int + block_hash: int | None = None + mm_extra_info: dict | None = None + + +class LoadMetrics(BaseModel): + kv_cache_usage: float + num_waiting_reqs: int + + +# ----------------------------------------------------------------------------- +# ZMQ Helpers +# ----------------------------------------------------------------------------- + + +def create_zmq_subscriber(context: zmq.Context, endpoint: str) -> zmq.Socket[bytes]: + """Create a ZMQ SUB socket with standard settings.""" + socket = context.socket(zmq.SUB) + socket.connect(endpoint) + socket.setsockopt(zmq.SUBSCRIBE, b"") + socket.setsockopt(zmq.CONFLATE, 1) + socket.setsockopt(zmq.RCVTIMEO, 1) + return socket + + +# ----------------------------------------------------------------------------- +# KvRouter Core +# ----------------------------------------------------------------------------- + + +class KvRouter: + """Router that uses RadixTree for KV cache-aware worker selection.""" + + def __init__( + self, + block_size: int = 64, + num_workers: int = 4, + base_kv_events_port: int = 5557, + base_metrics_port: int = 5657, + ): + self.num_workers = num_workers + self.block_size = block_size + self.radix_tree = RadixTree() + + # Per-worker metrics + self.kv_usages = [0.0] * num_workers + self.waitings = [0] * num_workers + + # ZMQ setup + self.context = zmq.Context() + self.load_listeners = [ + create_zmq_subscriber( + self.context, f"tcp://localhost:{base_metrics_port + i}" + ) + for i in range(num_workers) + ] + self.kv_listeners = [ + ZmqKvEventListener( + f"tcp://localhost:{base_kv_events_port + i}", "", block_size + ) + for i in range(num_workers) + ] + + self.background_tasks: list[asyncio.Task] = [] + logger.info("Router initialized") + + # ------------------------------------------------------------------------- + # Background Tasks + # ------------------------------------------------------------------------- + + async def start_background_tasks(self): + """Start background tasks for load and tree updates.""" + logger.info("Starting router background tasks...") + for worker_id in range(self.num_workers): + self.background_tasks.append( + 
asyncio.create_task(self._poll_worker_load(worker_id)) + ) + self.background_tasks.append( + asyncio.create_task(self._poll_worker_kv_events(worker_id)) + ) + + async def _poll_worker_load(self, worker_id: int): + """Poll load metrics for a single worker.""" + while True: + try: + data = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK) + metrics = LoadMetrics.model_validate(data) + self.kv_usages[worker_id] = metrics.kv_cache_usage + self.waitings[worker_id] = metrics.num_waiting_reqs + except zmq.Again: + pass + except (zmq.ZMQError, ValidationError) as e: + logger.warning(f"Worker {worker_id} metrics error: {e}") + except Exception: + logger.exception(f"Worker {worker_id} unexpected metrics error") + await asyncio.sleep(0.1) + + async def _poll_worker_kv_events(self, worker_id: int): + """Poll KV events for a single worker and update RadixTree.""" + while True: + try: + events: list[str] = await self.kv_listeners[worker_id].get_events() + for event_str in events: + event = json.loads(event_str) + dump_kv_event(worker_id, event) + self.radix_tree.apply_event( + worker_id, json.dumps(event).encode("utf-8") + ) + except zmq.Again: + pass + except (zmq.ZMQError, json.JSONDecodeError) as e: + logger.warning(f"Worker {worker_id} KV events error: {e}") + except Exception: + logger.exception(f"Worker {worker_id} unexpected KV events error") + await asyncio.sleep(0.1) + + # ------------------------------------------------------------------------- + # Worker Selection + # ------------------------------------------------------------------------- + + async def get_best_worker( + self, local_hashes: list[int], num_tokens: int + ) -> tuple[int, float, int]: + """ + Find best worker for request. + + Returns: (worker_id, overlap_ratio, matched_blocks) + """ + if num_tokens <= 0: + raise ValueError("num_tokens must be positive") + + # Get cache matches from RadixTree + matched_blocks = self._get_matched_blocks(local_hashes) + + # Compute overlap scores + overlap_scores = { + wid: matched_blocks[wid] * self.block_size / num_tokens + for wid in range(self.num_workers) + } + + # Compute routing logits + logits = self._compute_logits(overlap_scores) + + # Select best worker (random tie-breaking) + best_id = self._select_best_worker(logits) + + # Predictive update for burst handling + self.waitings[best_id] += 1 + + return best_id, overlap_scores[best_id], matched_blocks[best_id] + + def _get_matched_blocks(self, local_hashes: list[int]) -> dict[int, int]: + """Get matched block count per worker from RadixTree.""" + result = self.radix_tree.find_matches(local_hashes) + raw_scores = result.scores + logger.info(f"Router: raw_scores={raw_scores}") + + # raw_scores is keyed by (worker_id, dp_rank); assume dp_rank=0 + return {wid: raw_scores.get((wid, 0), 0) for wid in range(self.num_workers)} + + def _compute_logits(self, overlap_scores: dict[int, float]) -> list[float]: + """Compute routing logits for each worker.""" + max_waiting = max(self.waitings) if self.waitings else 0 + + logits = [] + for wid in range(self.num_workers): + overlap = overlap_scores[wid] + usage = self.kv_usages[wid] + waiting_norm = self.waitings[wid] / max_waiting if max_waiting else 0.0 + logit = 2 * overlap - usage - waiting_norm + logits.append(logit) + logger.info( + f"worker_id: {wid}, logit = 2 * {overlap:.3f} - {usage:.3f} - {waiting_norm:.3f} = {logit:.3f}" + ) + return logits + + def _select_best_worker(self, logits: list[float]) -> int: + """Select worker with highest logit (random tie-breaking).""" + arr = np.array(logits) 
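+ # Example: logits [0.5, 0.7, 0.7] -> flatnonzero selects indices [1, 2], and np.random.choice returns worker 1 or 2 with equal probability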
+ return int(np.random.choice(np.flatnonzero(arr == arr.max()))) + + # ------------------------------------------------------------------------- + # Shutdown + # ------------------------------------------------------------------------- + + async def shutdown(self): + """Shutdown ZMQ listeners and background tasks.""" + logger.info("Shutting down KvRouter...") + + for task in self.background_tasks: + task.cancel() + if self.background_tasks: + await asyncio.gather(*self.background_tasks, return_exceptions=True) + + for listener in self.load_listeners: + listener.close() + + self.context.term() + logger.info("KvRouter shutdown completed") + + +# ----------------------------------------------------------------------------- +# Router API Server +# ----------------------------------------------------------------------------- + + +class RouterAPI: + """FastAPI wrapper for KvRouter.""" + + def __init__( + self, + block_size: int = 64, + num_workers: int = 4, + base_kv_events_port: int = 5557, + base_metrics_port: int = 5657, + port: int = 7000, + ): + self.port = port + self.router_config = { + "block_size": block_size, + "num_workers": num_workers, + "base_kv_events_port": base_kv_events_port, + "base_metrics_port": base_metrics_port, + } + self.router: KvRouter | None = None + self.app = FastAPI( + title="KV Router API", version="0.0.1", lifespan=self.lifespan + ) + self._setup_routes() + + def _require_router(self) -> KvRouter: + """Get router or raise 503 if not initialized.""" + if self.router is None: + raise HTTPException(status_code=503, detail="Router not initialized") + return self.router + + @asynccontextmanager + async def lifespan(self, app: FastAPI): + self.router = KvRouter(**self.router_config) + await self.router.start_background_tasks() + logger.info("Router API started") + yield + if self.router: + await self.router.shutdown() + + def _setup_routes(self): + @self.app.post("/find_best_worker", response_model=RouterResponse) + async def find_best_worker(request: RouterRequest): + router = self._require_router() + try: + wid, overlap, matched = await router.get_best_worker( + request.local_hashes, request.num_tokens + ) + return RouterResponse( + worker_id=wid, overlap=overlap, matched_blocks=matched + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + @self.app.get("/debug/tree_info") + async def get_tree_info(): + router = self._require_router() + events = router.radix_tree.dump_tree_as_events() + return {"num_blocks": len(events), "events": events[:20]} + + @self.app.post("/debug/inject_event") + async def inject_event(request: InjectEventRequest): + router = self._require_router() + block_hash = request.block_hash or request.tokens_hash + event = { + "event_id": 99999, + "data": { + "stored": { + "parent_hash": None, + "blocks": [ + { + "block_hash": block_hash, + "tokens_hash": request.tokens_hash, + "mm_extra_info": request.mm_extra_info, + } + ], + } + }, + } + router.radix_tree.apply_event( + request.worker_id, json.dumps(event).encode("utf-8") + ) + return { + "status": "ok", + "tokens_hash": request.tokens_hash, + "worker_id": request.worker_id, + } + + async def start(self): + """Start the router API server.""" + logger.info(f"Starting Router API on port {self.port}") + config = uvicorn.Config( + self.app, host="0.0.0.0", port=self.port, log_level="info" + ) + await uvicorn.Server(config).serve() + + +def main(): + parser = argparse.ArgumentParser(description="KV Router API Server") + parser.add_argument( + "--block-size", type=int, 
default=32, help="Block size (default: 32)" + ) + parser.add_argument("--num-workers", type=int, default=2, help="Number of workers") + parser.add_argument( + "--base-kv-events-port", type=int, default=5557, help="Base KV events port" + ) + parser.add_argument( + "--base-metrics-port", type=int, default=5657, help="Base metrics port" + ) + parser.add_argument("--port", type=int, default=7000, help="Router API port") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + api = RouterAPI( + block_size=args.block_size, + num_workers=args.num_workers, + base_kv_events_port=args.base_kv_events_port, + base_metrics_port=args.base_metrics_port, + port=args.port, + ) + + asyncio.run(api.start()) + + +if __name__ == "__main__": + main() diff --git a/examples/deployments/router_standalone_trtllm/test_router.py b/examples/deployments/router_standalone_trtllm/test_router.py new file mode 100644 index 0000000000..02953b4914 --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/test_router.py @@ -0,0 +1,1028 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Test suite for TensorRT-LLM KV Router. + +Usage: + python test_router.py # Run text-only tests (requires server) + python test_router.py --verbose # Show detailed logs + python test_router.py --mm-only # Run multimodal hash tests (no server needed) + python test_router.py --mm-server # Run multimodal server tests (requires VLM) + python test_router.py --all # Run all tests +""" + +import argparse +import sys +import time +from dataclasses import dataclass + +import httpx + +from dynamo.llm import compute_block_hash_for_seq_py + +# Sample test images from COCO dataset +TEST_IMAGE_1 = "http://images.cocodataset.org/test2017/000000155781.jpg" +TEST_IMAGE_2 = "http://images.cocodataset.org/test2017/000000000001.jpg" +TEST_IMAGE_3 = "http://images.cocodataset.org/test2017/000000155721.jpg" +TEST_IMAGE_4 = "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg" + + +@dataclass +class RouterTestConfig: + api_url: str = "http://localhost:8000" + router_url: str = "http://localhost:7000" + timeout: int = 30 + kv_settle_time: float = 3.0 # Time to wait for KV events to propagate + + +@dataclass +class RouterTestResult: + name: str + passed: bool + message: str + overlap: float = 0.0 + + +def make_request(content: str, max_tokens: int = 10) -> dict: + """Create a text-only chat completion request.""" + return { + "model": "test", + "messages": [{"role": "user", "content": content}], + "stream": True, + "max_tokens": max_tokens, + } + + +def make_mm_request(text: str, image_url: str, max_tokens: int = 10) -> dict: + """Create a multimodal chat completion request with image.""" + return { + "model": "test", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + "stream": True, + "max_tokens": max_tokens, + } + + +def make_multi_image_request( + text: str, image_urls: list[str], max_tokens: int = 10 +) -> dict: + """Create a multimodal chat completion request with multiple images.""" + content: list[dict] = [{"type": "text", "text": text}] + for url in image_urls: + content.append({"type": "image_url", "image_url": {"url": url}}) + return { + "model": "test", + "messages": [{"role": "user", "content": content}], + "stream": True, + "max_tokens": max_tokens, + } + + +def send_request(client: 
httpx.Client, url: str, payload: dict) -> bool: + """Send a chat completion request and consume the stream.""" + try: + resp = client.post(f"{url}/v1/chat/completions", json=payload) + if resp.status_code != 200: + return False + for _ in resp.iter_lines(): + pass + return True + except Exception: + return False + + +def get_tree_info(client: httpx.Client, url: str) -> dict: + """Get radix tree debug info.""" + try: + resp = client.get(f"{url}/debug/tree_info") + return resp.json() + except Exception: + return {"num_blocks": -1, "events": []} + + +class KvRouterTests: + """Test cases for KV cache routing.""" + + def __init__(self, config: RouterTestConfig, verbose: bool = False): + self.config = config + self.verbose = verbose + self.client = httpx.Client(timeout=config.timeout) + self.results: list[RouterTestResult] = [] + + # Test messages designed for block_size=32 + # "Are you ok? Hello! Thank you! Thank you very much! " is ~12 tokens + # Chat template adds ~4 tokens + self.base_phrase = "Are you ok? Hello! Thank you! Thank you very much! " + + def log(self, msg: str): + if self.verbose: + print(f" {msg}") + + def run_all(self) -> bool: + """Run all test cases.""" + print("\nKV Router Test Suite") + print("=" * 50) + + # Check server connectivity first + if not self._check_servers(): + print("\nFATAL: Cannot connect to servers") + return False + + # Run test cases + self._test_full_match() + self._test_partial_match() + self._test_no_match() + + # Print summary + return self._print_summary() + + def run_mm_tests(self) -> bool: + """Run multimodal tests (local hash computation, no server needed).""" + print("\nMultimodal KV Router Tests (Local)") + print("=" * 50) + print("(These tests verify hash computation without server)") + + self._test_mm_hash_computation() + self._test_mm_routing_distinction() + self._test_mm_hash_consistency() + self._test_mm_offset_affects_hash() + self._test_mm_block_boundary() + self._test_mm_multi_image_partial_match() + + return self._print_summary() + + def run_mm_server_tests(self) -> bool: + """Run multimodal tests that require server.""" + print("\nMultimodal KV Router Tests (Server)") + print("=" * 50) + + if not self._check_servers(): + print("\nFATAL: Cannot connect to servers") + return False + + self._test_mm_same_image_cache_hit() + self._test_mm_different_images_no_cache_hit() + self._test_text_cache_hit_with_overlap() + self._test_mm_multi_image_partial_match() + + return self._print_summary() + + def _check_servers(self) -> bool: + """Verify both API and Router servers are reachable.""" + print("\nChecking server connectivity...") + try: + # Check router + resp = self.client.get(f"{self.config.router_url}/debug/tree_info") + if resp.status_code != 200: + print(f" Router not responding: {resp.status_code}") + return False + print(f" Router OK (blocks in tree: {resp.json().get('num_blocks', '?')})") + + # Check API - just verify it's up + # A simple request to verify the endpoint exists + return True + except Exception as e: + print(f" Connection error: {e}") + return False + + def _test_full_match(self): + """ + Test: Send identical request twice. + Expected: Second request should have overlap > 0. 
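The overlap value itself only appears in the API server logs, so this test asserts on radix tree growth and defers the overlap check to those logs.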
+ """ + print("\n[1] Full Match Test") + print(" Sending same request twice, expecting cache hit on second...") + + # Create a request with enough tokens for multiple full blocks + # 5 repetitions ≈ 64 tokens ≈ 2 full blocks + content = (self.base_phrase * 5).strip() + payload = make_request(content) + + # Get initial state + initial = get_tree_info(self.client, self.config.router_url) + initial_blocks = initial["num_blocks"] + self.log(f"Initial blocks: {initial_blocks}") + + # First request - should populate cache (or hit existing cache) + self.log("Sending first request...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult("full_match", False, "First request failed") + ) + return + + # Wait for KV events + self.log(f"Waiting {self.config.kv_settle_time}s for KV events...") + time.sleep(self.config.kv_settle_time) + + # Check blocks after first request + after_first = get_tree_info(self.client, self.config.router_url) + blocks_added = after_first["num_blocks"] - initial_blocks + self.log( + f"Blocks after first: {after_first['num_blocks']} (added {blocks_added})" + ) + + # Second request - should hit cache + self.log("Sending second request (should hit cache)...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult("full_match", False, "Second request failed") + ) + return + + # Success: either new blocks were added, or blocks already existed (from previous runs) + # Either way, the second request should show overlap > 0 in server logs + total_blocks = after_first["num_blocks"] + self.results.append( + RouterTestResult( + "full_match", + True, + f"OK - Tree has {total_blocks} blocks. Check server logs for 'overlap > 0'.", + ) + ) + + def _test_partial_match(self): + """ + Test: Send request A, then request B that shares same prefix but is longer. + Expected: Request B should have partial overlap (matching the shared prefix blocks). + """ + print("\n[2] Partial Match Test") + print(" Request B shares prefix with cached request A...") + + # Request A: 5 repetitions (~64 tokens, ~2 full blocks) + content_a = (self.base_phrase * 5).strip() + + # Request B: 8 repetitions (~100 tokens, ~3 full blocks) + # First 2 blocks should match A, third block is new + content_b = (self.base_phrase * 8).strip() + + payload_a = make_request(content_a) + payload_b = make_request(content_b) + + # Ensure A is cached (might already be from previous test) + self.log("Ensuring request A is cached...") + send_request(self.client, self.config.api_url, payload_a) + time.sleep(self.config.kv_settle_time) + + before = get_tree_info(self.client, self.config.router_url) + self.log(f"Blocks before B: {before['num_blocks']}") + + # Send request B + self.log("Sending request B (longer, shares prefix)...") + if not send_request(self.client, self.config.api_url, payload_b): + self.results.append( + RouterTestResult("partial_match", False, "Request B failed") + ) + return + + time.sleep(self.config.kv_settle_time) + + after = get_tree_info(self.client, self.config.router_url) + new_blocks = after["num_blocks"] - before["num_blocks"] + self.log(f"New blocks from B: {new_blocks}") + + # B should add new blocks (the non-matching suffix) + # The matching prefix blocks already exist + self.results.append( + RouterTestResult( + "partial_match", + True, + f"OK - Request B added {new_blocks} new blocks. 
" + f"Check server logs for partial overlap (0 < overlap < 1).", + ) + ) + + def _test_no_match(self): + """ + Test: Send completely different content. + Expected: No cache hit (overlap = 0). + """ + print("\n[3] No Match Test") + print(" Sending completely different content...") + + # Content that's very different from previous tests + # ~80 tokens, completely different from "Hello are you ok leijun" + content = ( + "The quick brown fox jumps over the lazy dog. " + "Pack my box with five dozen liquor jugs. " + "How vexingly quick daft zebras jump. " + "The five boxing wizards jump quickly. " + "Sphinx of black quartz, judge my vow." + ) + payload = make_request(content) + + before = get_tree_info(self.client, self.config.router_url) + self.log(f"Blocks before: {before['num_blocks']}") + + # Send the different request + self.log("Sending unrelated request...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append(RouterTestResult("no_match", False, "Request failed")) + return + + # No need to wait - we're checking overlap on this request, not the next + self.results.append( + RouterTestResult( + "no_match", + True, + "OK - Check server logs for 'overlap = 0.000' (no cache hit expected).", + ) + ) + + def _test_mm_hash_computation(self): + """ + Test: Verify that compute_block_hash_for_seq_py produces different hashes + for same tokens with different mm_hash values. + """ + print("\n[MM-1] MM Hash Computation Test") + print(" Verifying same tokens + different mm_hash = different block_hash...") + + # Simulated tokens (32 tokens = 1 block) + tokens = [100] * 32 + block_size = 32 + + # Hash without MM info + hash_no_mm = compute_block_hash_for_seq_py(tokens, block_size) + + # Hash with MM info (simulated mm_hash) + mm_info_1 = {"mm_objects": [{"mm_hash": 0xDEADBEEF, "offsets": [[0, 32]]}]} + hash_with_mm1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_1]) + + # Hash with different MM info + mm_info_2 = {"mm_objects": [{"mm_hash": 0xCAFEBABE, "offsets": [[0, 32]]}]} + hash_with_mm2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info_2]) + + self.log(f"Hash without MM: {hash_no_mm}") + self.log(f"Hash with MM 1: {hash_with_mm1}") + self.log(f"Hash with MM 2: {hash_with_mm2}") + + # Verify all hashes are different + if hash_no_mm == hash_with_mm1: + self.results.append( + RouterTestResult( + "mm_hash_computation", + False, + "FAIL - Hash without MM equals hash with MM", + ) + ) + return + + if hash_with_mm1 == hash_with_mm2: + self.results.append( + RouterTestResult( + "mm_hash_computation", + False, + "FAIL - Different mm_hash produced same block_hash", + ) + ) + return + + self.results.append( + RouterTestResult( + "mm_hash_computation", + True, + "OK - Different mm_hash values produce different block hashes", + ) + ) + + def _test_mm_routing_distinction(self): + """ + Test: Verify that the routing logic can distinguish between + requests with same text but different images. 
+ """ + print("\n[MM-2] MM Routing Distinction Test") + print(" Verifying routing can distinguish same text + different images...") + + # This test simulates what the router would see + tokens = [100] * 64 # 2 blocks + block_size = 32 + + # Simulate Image A cached on worker 0 + mm_info_a = { + "mm_objects": [{"mm_hash": 0x1111111111111111, "offsets": [[0, 64]]}] + } + hashes_a = compute_block_hash_for_seq_py( + tokens, block_size, [mm_info_a, mm_info_a] + ) + + # Simulate Image B cached on worker 1 + mm_info_b = { + "mm_objects": [{"mm_hash": 0x2222222222222222, "offsets": [[0, 64]]}] + } + hashes_b = compute_block_hash_for_seq_py( + tokens, block_size, [mm_info_b, mm_info_b] + ) + + self.log(f"Hashes for Image A: {hashes_a}") + self.log(f"Hashes for Image B: {hashes_b}") + + # Verify hashes are different + if hashes_a == hashes_b: + self.results.append( + RouterTestResult( + "mm_routing_distinction", + False, + "FAIL - Same tokens with different images produced same hashes", + ) + ) + return + + self.results.append( + RouterTestResult( + "mm_routing_distinction", + True, + "OK - Router can distinguish requests with different images", + ) + ) + + def _test_mm_hash_consistency(self): + """ + Test: Verify that the same mm_hash + tokens produce the same block_hash + regardless of when computed (idempotency). + """ + print("\n[MM-3] MM Hash Consistency Test") + print(" Verifying same inputs produce same hash (idempotent)...") + + tokens = [151937] * 32 # Image token placeholder + block_size = 32 + mm_hash = 0xDEADBEEFCAFEBABE + + mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]} + + # Compute hash multiple times + hash1 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) + hash2 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) + hash3 = compute_block_hash_for_seq_py(tokens, block_size, [mm_info]) + + self.log(f"Hash 1: {hash1}") + self.log(f"Hash 2: {hash2}") + self.log(f"Hash 3: {hash3}") + + if hash1 != hash2 or hash2 != hash3: + self.results.append( + RouterTestResult( + "mm_hash_consistency", + False, + f"FAIL - Same inputs produced different hashes: {hash1}, {hash2}, {hash3}", + ) + ) + return + + self.results.append( + RouterTestResult( + "mm_hash_consistency", + True, + f"OK - Hash computation is idempotent: {hash1[0]}", + ) + ) + + def _test_mm_offset_affects_hash(self): + """ + Test: Verify that different offsets produce different hashes, + even with same mm_hash and tokens. 
+ """ + print("\n[MM-4] MM Offset Affects Hash Test") + print(" Verifying different offsets produce different hashes...") + + tokens = [151937] * 64 # 2 blocks of image tokens + block_size = 32 + mm_hash = 0x123456789ABCDEF0 + + # Image covers first block only + mm_info_first = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]} + hash_first = compute_block_hash_for_seq_py( + tokens, block_size, [mm_info_first, None] + ) + + # Image covers second block only + mm_info_second = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]} + hash_second = compute_block_hash_for_seq_py( + tokens, block_size, [None, mm_info_second] + ) + + # Image covers both blocks + mm_info_both = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 64]]}]} + hash_both = compute_block_hash_for_seq_py( + tokens, block_size, [mm_info_both, mm_info_both] + ) + + self.log(f"Hash (first block MM): {hash_first}") + self.log(f"Hash (second block MM): {hash_second}") + self.log(f"Hash (both blocks MM): {hash_both}") + + # Block 0 with mm_info should differ from block 0 without mm_info + # Block 1 with mm_info should differ from block 1 without mm_info + if hash_first[0] == hash_second[0]: + self.results.append( + RouterTestResult( + "mm_offset_affects_hash", + False, + "FAIL - First block hash should differ based on MM presence", + ) + ) + return + + self.results.append( + RouterTestResult( + "mm_offset_affects_hash", + True, + "OK - Different MM offsets produce different block hashes", + ) + ) + + def _test_mm_block_boundary(self): + """ + Test: Verify that MM info correctly applies at block boundaries. + """ + print("\n[MM-5] MM Block Boundary Test") + print(" Verifying MM info applies correctly at block boundaries...") + + block_size = 32 + mm_hash = 0xFEDCBA9876543210 + + # 96 tokens = 3 blocks + # Image tokens in the middle block (32-64) + tokens = [100] * 32 + [151937] * 32 + [200] * 32 + + # MM info only applies to middle block + mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]} + hashes_with_mm = compute_block_hash_for_seq_py( + tokens, block_size, [None, mm_info, None] + ) + + # No MM info + hashes_without_mm = compute_block_hash_for_seq_py(tokens, block_size, None) + + self.log(f"Hashes with MM: {hashes_with_mm}") + self.log(f"Hashes without MM: {hashes_without_mm}") + + # Block 0 and 2 should be the same (no image tokens) + # Block 1 should be different (has image tokens + mm_hash) + if hashes_with_mm[0] != hashes_without_mm[0]: + self.results.append( + RouterTestResult( + "mm_block_boundary", False, "FAIL - Block 0 should be same (no MM)" + ) + ) + return + + if hashes_with_mm[1] == hashes_without_mm[1]: + self.results.append( + RouterTestResult( + "mm_block_boundary", False, "FAIL - Block 1 should differ (has MM)" + ) + ) + return + + if hashes_with_mm[2] != hashes_without_mm[2]: + self.results.append( + RouterTestResult( + "mm_block_boundary", False, "FAIL - Block 2 should be same (no MM)" + ) + ) + return + + self.results.append( + RouterTestResult( + "mm_block_boundary", + True, + "OK - MM info correctly applies only to relevant blocks", + ) + ) + + def _test_mm_same_image_cache_hit(self): + """ + Test: Send same text + same image twice. + Expected: Second request should have cache hit (overlap > 0). 
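+
+        This is a server-side test: it requires the API server to be running
+        with a vision-language model loaded.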
+ """ + print("\n[MM-S1] Same Image Cache Hit Test") + print(" Sending same text + same image twice...") + + payload = make_mm_request("Describe this image", TEST_IMAGE_1) + + # Get initial state + initial = get_tree_info(self.client, self.config.router_url) + self.log(f"Initial blocks: {initial['num_blocks']}") + + # First request - populates the cache + self.log("Sending first MM request...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult("mm_same_image", False, "First MM request failed") + ) + return + + # Wait for KV events to propagate + self.log(f"Waiting {self.config.kv_settle_time}s for KV events...") + time.sleep(self.config.kv_settle_time) + + after_first = get_tree_info(self.client, self.config.router_url) + blocks_added = after_first["num_blocks"] - initial["num_blocks"] + self.log( + f"Blocks after first: {after_first['num_blocks']} (added {blocks_added})" + ) + + if blocks_added == 0: + self.results.append( + RouterTestResult( + "mm_same_image", False, "FAIL - No blocks added after first request" + ) + ) + return + + # Second identical request - should hit cache + self.log("Sending second MM request (same image)...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult("mm_same_image", False, "Second MM request failed") + ) + return + + # Query router to check overlap (simulating what the second request saw) + # We need to compute the same hashes that the API computed + # For now, check the tree grew or stayed same (cache reuse) + after_second = get_tree_info(self.client, self.config.router_url) + self.log(f"Blocks after second: {after_second['num_blocks']}") + + # The second request should reuse cached blocks, so minimal new blocks added + new_blocks_second = after_second["num_blocks"] - after_first["num_blocks"] + self.log(f"New blocks from second request: {new_blocks_second}") + + self.results.append( + RouterTestResult( + "mm_same_image", + True, + f"OK - First added {blocks_added} blocks, second added {new_blocks_second}. " + f"Check logs for 'overlap > 0' on second request.", + ) + ) + + def _test_mm_different_images_no_cache_hit(self): + """ + Test: Send same text but different images. + Expected: No cache hit (overlap ≈ 0) because mm_hash differs. + Image blocks should not match, only text prefix might match. 
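+
+        The check uses radix-tree growth as a proxy: a second request that hit
+        the cache would add far fewer new blocks than a cold request.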
+ """ + print("\n[MM-S2] Different Images No Cache Hit Test") + print(" Sending same text + different images...") + + # First image + payload_1 = make_mm_request("Describe this image in detail", TEST_IMAGE_2) + + initial = get_tree_info(self.client, self.config.router_url) + self.log(f"Initial blocks: {initial['num_blocks']}") + + self.log(f"Sending request with image 1: {TEST_IMAGE_2}") + if not send_request(self.client, self.config.api_url, payload_1): + self.results.append( + RouterTestResult("mm_different_images", False, "Image 1 request failed") + ) + return + + time.sleep(self.config.kv_settle_time) + + after_img1 = get_tree_info(self.client, self.config.router_url) + blocks_img1 = after_img1["num_blocks"] - initial["num_blocks"] + self.log( + f"Blocks after image 1: {after_img1['num_blocks']} (added {blocks_img1})" + ) + + # Second image (same text, different image) + payload_2 = make_mm_request("Describe this image in detail", TEST_IMAGE_3) + + self.log(f"Sending request with image 2: {TEST_IMAGE_3}") + if not send_request(self.client, self.config.api_url, payload_2): + self.results.append( + RouterTestResult("mm_different_images", False, "Image 2 request failed") + ) + return + + time.sleep(self.config.kv_settle_time) + + after_img2 = get_tree_info(self.client, self.config.router_url) + blocks_img2 = after_img2["num_blocks"] - after_img1["num_blocks"] + self.log( + f"Blocks after image 2: {after_img2['num_blocks']} (added {blocks_img2})" + ) + + # Different images should add similar number of blocks + # If image 2 had cache hit, it would add fewer blocks + if blocks_img2 == 0: + self.results.append( + RouterTestResult( + "mm_different_images", + False, + "FAIL - Image 2 added 0 blocks (unexpected full cache hit)", + ) + ) + return + + # Image 2 should add approximately same number of blocks as image 1 + # (since different mm_hash means image blocks don't match) + self.results.append( + RouterTestResult( + "mm_different_images", + True, + f"OK - Image 1 added {blocks_img1} blocks, image 2 added {blocks_img2} blocks. " + f"Different images = different block hashes.", + ) + ) + + def _test_text_cache_hit_with_overlap(self): + """ + Test: Send same text request twice and verify overlap via router API. + Expected: Second request should show overlap > 0 in router response. + """ + print("\n[MM-S3] Text Cache Hit with Overlap Verification") + print(" Sending same text twice and verifying overlap value...") + + # Use a unique prompt to avoid interference from other tests + unique_text = ( + "This is a unique test prompt for cache hit verification. " + "We need enough tokens to fill at least one block. " + "The quick brown fox jumps over the lazy dog repeatedly. 
" * 3 + ) + payload = make_request(unique_text, max_tokens=5) + + # First request + self.log("Sending first text request...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult( + "text_cache_hit_overlap", False, "First request failed" + ) + ) + return + + # Wait for KV events + self.log(f"Waiting {self.config.kv_settle_time}s for KV events...") + time.sleep(self.config.kv_settle_time) + + # Get tree info to see blocks + tree_info = get_tree_info(self.client, self.config.router_url) + self.log(f"Blocks in tree: {tree_info['num_blocks']}") + + # Second request - should see cache hit + self.log("Sending second text request (should hit cache)...") + if not send_request(self.client, self.config.api_url, payload): + self.results.append( + RouterTestResult( + "text_cache_hit_overlap", False, "Second request failed" + ) + ) + return + + # For a true verification, we'd need to intercept the router response + # or add an endpoint that returns the last routing decision + # For now, we verify by checking if blocks increased (they shouldn't much) + tree_info_after = get_tree_info(self.client, self.config.router_url) + new_blocks = tree_info_after["num_blocks"] - tree_info["num_blocks"] + self.log(f"New blocks after second request: {new_blocks}") + + self.results.append( + RouterTestResult( + "text_cache_hit_overlap", + True, + f"OK - Second request added {new_blocks} new blocks. " + f"Check logs for 'overlap > 0' (cache hit).", + ) + ) + + def _test_mm_multi_image_partial_match(self): + """ + Test: Verify partial cache match with multi-image requests. + + Scenario: + Step 1: Send Request A = text + [Image_1, Image_4] + Step 2: Send Request A again (identical) - verify full cache hit (0 new blocks) + Step 3: Send Request B = text + [Image_1, Image_3] - verify partial match + (Image_3 is different, should add new blocks) + + Expected: + - Identical request = no new blocks (full cache hit) + - Different second image = new blocks added (partial match) + """ + print("\n[MM-S4] Multi-Image Partial Match Test") + print(" Verifying cache behavior with multi-image requests...") + + # Use longer settle time for this test + settle_time = self.config.kv_settle_time * 2 + + # Request A: text + Image_1 + Image_4 + payload_a = make_multi_image_request( + "Describe these images in detail", [TEST_IMAGE_1, TEST_IMAGE_4] + ) + + initial = get_tree_info(self.client, self.config.router_url) + self.log(f"Initial blocks: {initial['num_blocks']}") + + # Step 1: Send Request A first time + self.log("Step 1: Sending Request A (text + Image_1 + Image_4)...") + if not send_request(self.client, self.config.api_url, payload_a): + self.results.append( + RouterTestResult("mm_multi_image_partial", False, "Request A failed") + ) + return + + time.sleep(settle_time) + + after_a1 = get_tree_info(self.client, self.config.router_url) + blocks_a1 = after_a1["num_blocks"] - initial["num_blocks"] + self.log( + f"Blocks after Request A: {after_a1['num_blocks']} (added {blocks_a1})" + ) + + if blocks_a1 == 0: + self.results.append( + RouterTestResult( + "mm_multi_image_partial", + False, + "FAIL - Request A added 0 blocks (should populate cache)", + ) + ) + return + + # Step 2: Send Request A again (identical) - should be full cache hit + self.log( + "Step 2: Sending Request A again (identical, expect full cache hit)..." 
+ ) + if not send_request(self.client, self.config.api_url, payload_a): + self.results.append( + RouterTestResult( + "mm_multi_image_partial", False, "Request A (repeat) failed" + ) + ) + return + + time.sleep(settle_time) + + after_a2 = get_tree_info(self.client, self.config.router_url) + blocks_a2 = after_a2["num_blocks"] - after_a1["num_blocks"] + self.log( + f"Blocks after Request A repeat: {after_a2['num_blocks']} (added {blocks_a2})" + ) + + # Identical request should add 0 new blocks (full cache hit) + if blocks_a2 != 0: + self.log( + f"WARNING: Identical request added {blocks_a2} blocks (expected 0)" + ) + + # Step 3: Send Request B with different second image + payload_b = make_multi_image_request( + "Describe these images in detail", [TEST_IMAGE_1, TEST_IMAGE_3] + ) + + self.log( + "Step 3: Sending Request B (text + Image_1 + Image_3, different 2nd image)..." + ) + if not send_request(self.client, self.config.api_url, payload_b): + self.results.append( + RouterTestResult("mm_multi_image_partial", False, "Request B failed") + ) + return + + time.sleep(settle_time) + + after_b = get_tree_info(self.client, self.config.router_url) + blocks_b = after_b["num_blocks"] - after_a2["num_blocks"] + self.log(f"Blocks after Request B: {after_b['num_blocks']} (added {blocks_b})") + + # Analysis: + # - If blocks_b > 0: Image_3 created new blocks (correct - different image) + # - If blocks_b == 0: Full cache hit (wrong - Image_3 should be different) + # + # Note: We can't easily verify partial match vs full cache miss because + # the tree growth depends on whether routing hit the cached worker. + # What we CAN verify is that different images should NOT fully cache hit. + + if blocks_b == 0 and blocks_a2 == 0: + # Both identical and different requests added 0 blocks + # This suggests Image_3's mm_hash is incorrectly matching Image_4 + self.results.append( + RouterTestResult( + "mm_multi_image_partial", + False, + "FAIL - Request B (different image) added 0 blocks. " + "Image_3 should have different mm_hash than Image_4. " + "Check if mm_hash computation is correct.", + ) + ) + return + + if blocks_b == 0: + # Different image but 0 new blocks - might be timing or routing issue + self.results.append( + RouterTestResult( + "mm_multi_image_partial", + False, + f"FAIL - Request B added 0 blocks. " + f"Identical request added {blocks_a2}. " + f"This is unexpected - different images should not fully cache hit.", + ) + ) + return + + # Success: different image added new blocks + self.results.append( + RouterTestResult( + "mm_multi_image_partial", + True, + f"OK - Request A: {blocks_a1} blocks, A repeat: {blocks_a2}, " + f"Request B (diff image): {blocks_b}. " + f"Different images correctly create distinct cache entries.", + ) + ) + + def _print_summary(self) -> bool: + """Print test results summary.""" + print("\n" + "=" * 50) + print("Results") + print("=" * 50) + + all_passed = True + for r in self.results: + _ = "PASS" if r.passed else "FAIL" + symbol = "[OK]" if r.passed else "[X]" + print(f" {symbol} {r.name}: {r.message}") + if not r.passed: + all_passed = False + + print("\n" + "-" * 50) + if all_passed: + print("All tests passed.") + print("\nTo fully verify, check server logs for:") + print(" - Full match: overlap > 0.5") + print(" - Partial match: 0 < overlap < 0.5") + print(" - No match: overlap = 0.000") + else: + print("Some tests failed. 
Check the messages above.") + + return all_passed + + def cleanup(self): + self.client.close() + + +def main(): + parser = argparse.ArgumentParser(description="KV Router Test Suite") + parser.add_argument( + "--verbose", "-v", action="store_true", help="Show detailed logs" + ) + parser.add_argument( + "--api-url", default="http://localhost:8000", help="API server URL" + ) + parser.add_argument( + "--router-url", default="http://localhost:7000", help="Router URL" + ) + parser.add_argument( + "--mm-only", + action="store_true", + help="Run only multimodal local tests (no server needed)", + ) + parser.add_argument( + "--mm-server", + action="store_true", + help="Run multimodal server tests (requires VLM model)", + ) + parser.add_argument( + "--all", action="store_true", help="Run all tests including multimodal" + ) + args = parser.parse_args() + + config = RouterTestConfig(api_url=args.api_url, router_url=args.router_url) + tests = KvRouterTests(config, verbose=args.verbose) + + try: + if args.mm_only: + # Local MM tests only (no server) + success = tests.run_mm_tests() + elif args.mm_server: + # MM server tests (requires VLM) + success = tests.run_mm_server_tests() + elif args.all: + # Run all tests + success = tests.run_all() + if success: + success = tests.run_mm_tests() + if success: + success = tests.run_mm_server_tests() + else: + # Default: text-only tests + success = tests.run_all() + sys.exit(0 if success else 1) + finally: + tests.cleanup() + + +if __name__ == "__main__": + main() diff --git a/examples/deployments/router_standalone_trtllm/worker.py b/examples/deployments/router_standalone_trtllm/worker.py new file mode 100644 index 0000000000..e3a2819ff3 --- /dev/null +++ b/examples/deployments/router_standalone_trtllm/worker.py @@ -0,0 +1,627 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +# Fix protobuf version conflict with etcd3 +os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") + +import asyncio +import logging +import time +from typing import AsyncGenerator, Optional + +import msgpack +import zmq +from tensorrt_llm import LLM +from tensorrt_llm.llmapi import KvCacheConfig + +logger = logging.getLogger(__name__) + +DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024 + +# Debug flag: set DYNAMO_DEBUG=1 to enable debug file dumps +DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1" +DEBUG_WORKER_KV_FILE = "/tmp/debug_worker_kv.txt" + +# Qwen2-VL specific token ID for image placeholders +IMAGE_TOKEN_ID = 151937 + + +def dump_worker_kv_event(worker_id: int, event: dict, token_ids: list[int]): + """Dump worker-side KV event to file for debugging.""" + if not DEBUG_ENABLED: + return + import datetime + + with open(DEBUG_WORKER_KV_FILE, "a") as f: + f.write(f"\n{'='*60}\n") + f.write(f"Timestamp: {datetime.datetime.now()}\n") + f.write(f"Worker ID: {worker_id}\n") + f.write(f"Event: {event}\n") + f.write(f"Tokens ({len(token_ids)}): {token_ids[:50]}...\n") + f.write(f"{'='*60}\n") + + +def to_unsigned_u64(value: int | None) -> int | None: + """Ensure value is in unsigned 64-bit range for Rust/msgpack.""" + if value is None: + return None + # Handle negative values (two's complement) + return (1 << 64) + value if value < 0 else value + + +# ----------------------------------------------------------------------------- +# ZMQ Publishers +# ----------------------------------------------------------------------------- + + +class MetricsPublisher: + """Publishes worker metrics over ZMQ.""" + + def __init__(self, port: int): + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + self.socket.bind(f"tcp://*:{port}") + + def publish(self, num_waiting_reqs: int, kv_cache_usage: float): + self.socket.send_json( + { + "num_waiting_reqs": num_waiting_reqs, + "kv_cache_usage": kv_cache_usage, + } + ) + + def close(self): + self.socket.close() + self.context.term() + + +class KvEventsPublisher: + """Publishes KV cache events over ZMQ.""" + + def __init__(self, port: int, block_size: int): + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + self.socket.bind(f"tcp://*:{port}") + self.block_size = block_size + self.partial_block_hashes: set[int] = set() + self.sequence_number = 0 + + def publish_stored( + self, + block_hashes: list[int], + token_ids: list[int], + parent_hash: int | None, + block_mm_infos: list[dict | None] | None, + ): + """Publish a BlockStored event. + + Args: + block_hashes: List of block hashes being stored. + token_ids: All token IDs across the blocks. + parent_hash: Hash of the parent block (if any). + block_mm_infos: Per-block multimodal info list. Each element corresponds + to a block and is either the mm_info dict (for blocks containing + image tokens) or None (for text-only blocks). 
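+
+        Example with hypothetical values (two 32-token blocks, image tokens in
+        the second block only)::
+
+            publisher.publish_stored(
+                block_hashes=[1111, 2222],
+                token_ids=list(range(64)),
+                parent_hash=None,
+                block_mm_infos=[
+                    None,  # block 0: text-only
+                    {"mm_objects": [{"mm_hash": 0xABCD, "offsets": [[32, 64]]}]},
+                ],
+            )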
+ """ + event = { + "type": "BlockStored", + "block_hashes": [to_unsigned_u64(h) for h in block_hashes], + "token_ids": token_ids, + "block_size": self.block_size, + } + + if parent_hash is not None: + event["parent_block_hash"] = to_unsigned_u64(parent_hash) + + if block_mm_infos is not None: + event["block_mm_infos"] = block_mm_infos + + self._send([event]) + + def publish_removed(self, block_hashes: list[int]): + """Publish a BlockRemoved event.""" + # Filter out partial blocks + filtered = [] + for h in block_hashes: + if h in self.partial_block_hashes: + self.partial_block_hashes.remove(h) + else: + filtered.append(to_unsigned_u64(h)) + + if filtered: + self._send([{"type": "BlockRemoved", "block_hashes": filtered}]) + + def _send(self, events: list[dict]): + """Send events via ZMQ multipart message.""" + batch = [time.time(), events, 0] + try: + payload = msgpack.packb(batch, use_bin_type=True) + except Exception as e: + logger.error(f"msgpack error: {e}") + return + + seq_bytes = self.sequence_number.to_bytes(8, byteorder="big") + self.sequence_number += 1 + self.socket.send_multipart([b"", seq_bytes, payload]) + + def close(self): + self.socket.close() + self.context.term() + + +# ----------------------------------------------------------------------------- +# KV Event Processing Helpers +# ----------------------------------------------------------------------------- + + +def extract_mm_info( + blocks_data: list[dict], all_token_ids: list[int] +) -> tuple[list[int] | None, list[list[int]] | None]: + """Extract multimodal hash info from TRTLLM block data. + + Handles multiple images by extracting all mm_hashes and matching them + to their corresponding image token ranges. + + Returns: + Tuple of (list of mm_hashes, list of offsets) or (None, None). + Each offset is [start, end) for one image's token range. + """ + # Collect all mm_hashes from blocks + mm_hashes: list[int] = [] + for block in blocks_data: + mm_keys = block.get("mm_keys", []) + for mm_key in mm_keys: + if mm_key.get("type") != "mm_key": + continue + hash_hex = mm_key.get("hash", "") + if hash_hex: + mm_hash = int(hash_hex[:16], 16) + if mm_hash not in mm_hashes: # Avoid duplicates + mm_hashes.append(mm_hash) + + if not mm_hashes: + return None, None + + # Find all image token ranges + image_offsets_list = find_all_image_token_ranges(all_token_ids) + if not image_offsets_list: + return None, None + + # Match mm_hashes to image_offsets by order + # (assumes mm_hashes appear in same order as images in token sequence) + return mm_hashes, image_offsets_list + + +def find_all_image_token_ranges(token_ids: list[int]) -> list[list[int]] | None: + """Find all [start, end) ranges of contiguous image tokens. + + Returns: + List of [start, end) ranges, one per contiguous image token sequence. + Returns None if no image tokens found. 
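+
+    Example (with IMAGE_TOKEN_ID = 151937):
+        [7, 151937, 151937, 9, 151937] -> [[1, 3], [4, 5]]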
+ """ + ranges: list[list[int]] = [] + current_start: int | None = None + + for i, tid in enumerate(token_ids): + if tid == IMAGE_TOKEN_ID: + if current_start is None: + current_start = i + elif current_start is not None: + # End of contiguous sequence + ranges.append([current_start, i]) + current_start = None + + # Handle sequence ending with image tokens + if current_start is not None: + ranges.append([current_start, len(token_ids)]) + + return ranges if ranges else None + + +def build_per_block_mm_infos( + num_blocks: int, + block_size: int, + mm_hashes: list[int] | None, + image_offsets_list: list[list[int]] | None, +) -> list[dict | None] | None: + """Build per-block mm_infos list for multiple images. + + Each block that overlaps with an image's token range gets the corresponding + mm_info with that image's mm_hash. + + Args: + num_blocks: Number of blocks in the stored event. + block_size: Number of tokens per block. + mm_hashes: List of mm_hash values, one per image. + image_offsets_list: List of [start, end) token ranges, one per image. + + Returns: + List of mm_info (one per block), with None for blocks without image tokens. + Returns None if no mm_info is provided. + """ + if mm_hashes is None or image_offsets_list is None: + return None + + if not mm_hashes or not image_offsets_list: + return None + + # Initialize result with None for all blocks + result: list[dict | None] = [None] * num_blocks + + # Process each image + for mm_hash, offsets in zip(mm_hashes, image_offsets_list): + img_start, img_end = offsets + + for block_idx in range(num_blocks): + block_start = block_idx * block_size + block_end = block_start + block_size + + # Check if this block overlaps with this image's token range + if block_end > img_start and block_start < img_end: + if result[block_idx] is None: + result[block_idx] = {"mm_objects": []} + # Add this image's mm_object to the block + result[block_idx]["mm_objects"].append( + {"mm_hash": mm_hash, "offsets": [offsets]} + ) + + return result + + +def parse_stored_blocks( + blocks_data: list[dict], block_size: int, partial_hashes: set[int] +) -> tuple[list[dict], list[int]]: + """Parse stored blocks from TRTLLM event data. 
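+    Blocks shorter than block_size are treated as partial: their hashes are
+    added to partial_hashes (so later removals can be filtered out) and
+    parsing stops without publishing them.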
+ + Returns: + Tuple of (blocks list, all token_ids) + """ + blocks = [] + all_token_ids = [] + + for block in blocks_data: + tokens = block["tokens"] + num_tokens = len(tokens) + block_hash = block["block_hash"] + + if num_tokens == block_size: + token_ids = [int(t["token_id"]) for t in tokens] + blocks.append( + { + "block_hash": block_hash, + "token_ids": token_ids, + "num_tokens": num_tokens, + } + ) + all_token_ids.extend(token_ids) + elif num_tokens < block_size: + # Partial block - track but don't publish + partial_hashes.add(block_hash) + break + else: + logger.error(f"Block too large: {num_tokens} > {block_size}") + break + + return blocks, all_token_ids + + +# ----------------------------------------------------------------------------- +# TRT-LLM Worker +# ----------------------------------------------------------------------------- + + +class TrtllmWorker: + """Manages a single TensorRT-LLM worker with event/metrics publishing.""" + + def __init__( + self, + worker_id: int, + model: str, + block_size: int, + kv_events_port: int, + metrics_port: int, + ): + self.worker_id = worker_id + self.model = model + self.block_size = block_size + + self.llm: Optional[LLM] = None + self.metrics_publisher: Optional[MetricsPublisher] = None + self.kv_events_publisher: Optional[KvEventsPublisher] = None + + self.background_tasks: list[asyncio.Task] = [] + self.max_window_size: int | None = None + self.processing_initial_events = True + self.kv_events_started = False + + self._initialize(kv_events_port, metrics_port) + + def _initialize(self, kv_events_port: int, metrics_port: int): + """Initialize TensorRT-LLM engine and publishers.""" + logger.info(f"Worker {self.worker_id}: Initializing") + + self.llm = LLM( + model=self.model, + kv_cache_config=KvCacheConfig( + enable_block_reuse=True, + event_buffer_max_size=DEFAULT_KV_EVENT_BUFFER_MAX_SIZE, + ), + ) + + self.metrics_publisher = MetricsPublisher(metrics_port) + self.kv_events_publisher = KvEventsPublisher(kv_events_port, self.block_size) + + logger.info(f"Worker {self.worker_id}: Initialized") + + # ------------------------------------------------------------------------- + # Background Tasks + # ------------------------------------------------------------------------- + + async def start_background_tasks(self): + """Start metrics publishing task.""" + self.background_tasks.append(asyncio.create_task(self._metrics_loop())) + + def _start_kv_events_task(self): + """Lazily start KV events task on first request.""" + if self.kv_events_started: + return + self.kv_events_started = True + logger.info(f"Worker {self.worker_id}: Starting KV events monitoring") + self.background_tasks.append(asyncio.create_task(self._kv_events_loop())) + + async def _metrics_loop(self): + """Continuously publish worker metrics.""" + await asyncio.sleep(1) + + try: + async for stat in self.llm.get_stats_async(timeout=5): + if not isinstance(stat, dict): + continue + + num_waiting = ( + stat["numQueuedRequests"] + + stat["inflightBatchingStats"]["numPausedRequests"] + ) + kv_stats = stat["kvCacheStats"] + usage = ( + kv_stats["allocTotalBlocks"] / kv_stats["maxNumBlocks"] + if kv_stats["maxNumBlocks"] > 0 + else 0.0 + ) + + self.metrics_publisher.publish(num_waiting, usage) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Worker {self.worker_id} metrics error: {e}") + + async def _kv_events_loop(self): + """Continuously process and publish KV cache events.""" + await asyncio.sleep(2) + + try: + events = 
self.llm.get_kv_cache_events_async(timeout=None) + logger.info(f"Worker {self.worker_id}: KV events iterator obtained") + + async for event in events: + self._process_kv_event(event) + + except asyncio.CancelledError: + pass + except RuntimeError as e: + if "IterationResult is not properly instantiated" in str(e): + logger.warning(f"Worker {self.worker_id}: KV events not available") + else: + logger.error(f"Worker {self.worker_id} KV events error: {e}") + except Exception as e: + logger.error(f"Worker {self.worker_id} KV events error: {e}") + + logger.warning(f"Worker {self.worker_id}: KV events loop exited") + + def _process_kv_event(self, event: dict): + """Process a single KV cache event.""" + if not isinstance(event, dict): + return + if "event_id" not in event or "data" not in event: + return + + data = event["data"] + event_type = data.get("type") + + if self._should_drop_event(event): + return + + if event_type == "stored": + self._handle_stored_event(data) + elif event_type == "removed": + self._handle_removed_event(data) + elif event_type == "created" and self.processing_initial_events: + self._update_window_size(event) + + def _should_drop_event(self, event: dict) -> bool: + """Check if event should be dropped (non-global attention).""" + if self.processing_initial_events: + return False + window_size = event.get("window_size") + if window_size is None: + return False + return window_size != self.max_window_size + + def _update_window_size(self, event: dict): + """Update max window size from created events.""" + window_size = event.get("window_size") + if window_size and ( + self.max_window_size is None or window_size > self.max_window_size + ): + self.max_window_size = window_size + + def _handle_stored_event(self, data: dict): + """Handle a stored block event.""" + self.processing_initial_events = False + + blocks, all_token_ids = parse_stored_blocks( + data["blocks"], + self.block_size, + self.kv_events_publisher.partial_block_hashes, + ) + + if not blocks: + return + + parent_hash = data.get("parent_hash") + mm_hashes, image_offsets_list = extract_mm_info(data["blocks"], all_token_ids) + + block_hashes = [b["block_hash"] for b in blocks] + + # Build per-block mm_infos (only blocks with image tokens get mm_info) + block_mm_infos = build_per_block_mm_infos( + len(blocks), self.block_size, mm_hashes, image_offsets_list + ) + + # Debug dump + dump_worker_kv_event( + self.worker_id, + {"type": "stored", "blocks": len(blocks), "mm_hashes": mm_hashes}, + all_token_ids, + ) + + self.kv_events_publisher.publish_stored( + block_hashes, all_token_ids, parent_hash, block_mm_infos + ) + + def _handle_removed_event(self, data: dict): + """Handle a removed block event.""" + self.processing_initial_events = False + + block_hashes = data.get("block_hashes", []) + self.kv_events_publisher.publish_removed(block_hashes) + + # ------------------------------------------------------------------------- + # Generation + # ------------------------------------------------------------------------- + + async def generate( + self, + prompt_input, # list[int] (tokens) or dict (MM input) + sampling_params: dict, + ) -> AsyncGenerator[dict, None]: + """Generate tokens for a request.""" + from tensorrt_llm.llmapi.llm import SamplingParams + + # Start KV events on first request + self._start_kv_events_task() + + trtllm_params = SamplingParams( + max_tokens=sampling_params.get("max_tokens", 100), + temperature=sampling_params.get("temperature", 1.0), + top_p=sampling_params.get("top_p", 1.0), + top_k=max(0, 
sampling_params.get("top_k", 0)), + ) + + outputs = self.llm.generate_async( + prompt_input, sampling_params=trtllm_params, streaming=False + ) + + async for output in outputs: + yield self._format_output(output) + + def _format_output(self, request_output) -> dict: + """Format TRTLLM output to standard response dict.""" + if not hasattr(request_output, "outputs") or not request_output.outputs: + return {"text": "", "text_diff": "", "token_ids": [], "finish_reason": None} + + completion = request_output.outputs[0] + text = getattr(completion, "text_diff", None) or getattr(completion, "text", "") + + return { + "text": text, + "text_diff": getattr(completion, "text_diff", text), + "token_ids": getattr(completion, "token_ids", []), + "finish_reason": getattr(completion, "finish_reason", None), + } + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + def shutdown(self): + """Shutdown worker and cleanup resources.""" + logger.info(f"Worker {self.worker_id}: Shutting down") + + for task in self.background_tasks: + task.cancel() + + if self.llm: + self.llm.shutdown() + if self.metrics_publisher: + self.metrics_publisher.close() + if self.kv_events_publisher: + self.kv_events_publisher.close() + + +# ----------------------------------------------------------------------------- +# Worker Manager +# ----------------------------------------------------------------------------- + + +class TrtllmWorkers: + """Manages multiple TensorRT-LLM workers. + + Warning: Creating multiple workers in the same process causes them to share + the same GPU(s). + """ + + def __init__( + self, + model: str = "Qwen/Qwen2-VL-2B-Instruct", + block_size: int = 32, + base_kv_events_port: int = 5557, + base_metrics_port: int = 5657, + num_workers: int = 1, + ): + self.workers = [] + + if num_workers > 1: + logger.warning( + f"Creating {num_workers} workers in the same process. " + "All workers will share the same GPU(s). For multi-GPU isolation, " + "start each worker in a separate process with CUDA_VISIBLE_DEVICES set." 
+            )
+
+        logger.info(f"Initializing {num_workers} workers for {model}")
+
+        for i in range(num_workers):
+            self.workers.append(
+                TrtllmWorker(
+                    worker_id=i,
+                    model=model,
+                    block_size=block_size,
+                    kv_events_port=base_kv_events_port + i,
+                    metrics_port=base_metrics_port + i,
+                )
+            )
+
+        logger.info(f"All {num_workers} workers initialized")
+
+    async def start_all(self):
+        """Start background tasks for all workers."""
+        for worker in self.workers:
+            await worker.start_background_tasks()
+
+    async def direct(
+        self, prompt_input, worker_id: int, sampling_params: dict
+    ) -> AsyncGenerator[dict, None]:
+        """Send request to a specific worker."""
+        async for output in self.workers[worker_id].generate(
+            prompt_input, sampling_params
+        ):
+            yield output
+
+    def shutdown_all(self):
+        """Shutdown all workers."""
+        logger.info("Shutting down all workers")
+        for worker in self.workers:
+            worker.shutdown()
diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs
index da2ba53f3a..da08f063f9 100644
--- a/lib/bindings/c/src/lib.rs
+++ b/lib/bindings/c/src/lib.rs
@@ -170,10 +170,12 @@ fn kv_event_create_stored_block_from_parts(
     let tokens_hash = compute_block_hash_for_seq(
         unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
         kv_block_size,
+        None,
     )[0];
     KvCacheStoredBlockData {
         block_hash: ExternalSequenceBlockHash(block_hash),
         tokens_hash,
+        mm_extra_info: None,
     }
 }
 static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs
index f4e08db1a8..79ea8de899 100644
--- a/lib/bindings/python/rust/llm/kv.rs
+++ b/lib/bindings/python/rust/llm/kv.rs
@@ -27,12 +27,35 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
 use serde_json::json;
 
 #[pyfunction]
-pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
+#[pyo3(signature = (tokens, kv_block_size, block_mm_infos=None))]
+pub fn compute_block_hash_for_seq_py(
+    _py: Python,
+    tokens: Vec<u32>,
+    kv_block_size: usize,
+    block_mm_infos: Option<Bound<'_, PyAny>>,
+) -> PyResult<Vec<u64>> {
     if kv_block_size == 0 {
-        return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
+        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+            "kv_block_size cannot be 0",
+        ));
     }
-    let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32);
+
+    // Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
+    let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
+        .as_ref()
+        .map(|infos_py| {
+            depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
+                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                    "Failed to convert block_mm_infos: {}",
+                    e
+                ))
+            })
+        })
+        .transpose()?;
+
+    let hashes =
+        compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos_rust.as_deref());
+
     Ok(hashes.into_iter().map(|h| h.0).collect())
 }
@@ -280,7 +303,7 @@ impl KvEventPublisher {
     }
 
     #[allow(clippy::too_many_arguments)]
-    #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))]
+    #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
     fn publish_stored(
         &mut self,
         py: Python,
@@ -290,12 +313,26 @@ impl KvEventPublisher {
         block_hashes: Vec<i64>,
         lora_id: u64,
         parent_hash: Option<i64>,
+        block_mm_infos: Option<Bound<'_, PyAny>>,
     ) -> PyResult<()> {
         let kv_block_size = self.kv_block_size as u32;
         let dp_rank = self.dp_rank;
         let warning_count = self.warning_count.clone();
         let inner = self.inner.clone();
 
+        // Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
+        let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
+            .as_ref()
+            .map(|infos_py| {
+                depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
+                    PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                        "Failed to convert block_mm_infos: {}",
+                        e
+                    ))
+                })
+            })
+            .transpose()?;
+
         py.allow_threads(|| {
             let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
             let event = KvCacheEvent {
@@ -309,6 +346,7 @@ impl KvEventPublisher {
                     &block_hashes_u64,
                     lora_id,
                     &warning_count,
+                    mm_infos_rust.as_deref(),
                 ),
             }),
             dp_rank,
diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi
index 3f87a3127c..de7109f2e4 100644
--- a/lib/bindings/python/src/dynamo/_core.pyi
+++ b/lib/bindings/python/src/dynamo/_core.pyi
@@ -232,16 +232,42 @@ class Client:
     ...
 
-def compute_block_hash_for_seq_py(tokens: List[int], kv_block_size: int) -> List[int]:
+def compute_block_hash_for_seq_py(
+    tokens: List[int],
+    kv_block_size: int,
+    block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None
+) -> List[int]:
     """
-    Compute block hashes for a sequence of tokens
+    Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
+
+    When block_mm_infos is provided, the mm_hashes are included in the hash computation
+    to ensure that blocks with identical tokens but different multimodal objects produce
+    different hashes.
 
     Args:
         tokens: List of token IDs
-        kv_block_size: Size of each KV cache block
+        kv_block_size: Size of each block in tokens
+        block_mm_infos: Optional per-block multimodal metadata. Each element corresponds to a block
+            and should be None or a dict with structure:
+            {
+                "mm_objects": [
+                    {
+                        "mm_hash": int,  # Hash of the MM object
+                    }
+                ]
+            }
 
     Returns:
-        List of block hashes as integers
+        List of block hashes (one per block)
+
+    Example:
+        >>> tokens = [1, 2, 3, 4] * 8  # 32 tokens = 1 block
+        >>> mm_info = {
+        ...     "mm_objects": [{
+        ...         "mm_hash": 0xDEADBEEF,
+        ...     }]
+        ... }
+        >>> hashes = compute_block_hash_for_seq_py(tokens, 32, [mm_info])
     """
     ...
diff --git a/lib/bindings/python/tests/test_mm_kv_router.py b/lib/bindings/python/tests/test_mm_kv_router.py
new file mode 100644
index 0000000000..618b9ee9eb
--- /dev/null
+++ b/lib/bindings/python/tests/test_mm_kv_router.py
@@ -0,0 +1,438 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Tests for Multimodal KV Router functionality.
+
+These tests verify that the KV router correctly handles multimodal content (images, videos)
+by distinguishing between requests with identical token sequences but different MM objects.
+
+Key Concepts:
+- block_hash: External hash used to identify blocks uniquely (includes MM info)
+- tokens_hash: Local hash based only on token content
+- mm_hash: Hash of the multimodal object (image, video, etc.)
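+- With MM info present, block_hash covers both the token bytes and the sorted
+  mm_hashes of the objects in that block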
+ +Test Strategy: +- Use RadixTree directly to avoid NATS/etcd infrastructure dependencies +- Simulate multiple workers caching same tokens with different MM content +- Verify that routing distinguishes between different MM objects +""" + +import json +from typing import Any + +import pytest + +from dynamo.llm import RadixTree, compute_block_hash_for_seq_py + +pytestmark = pytest.mark.pre_merge + +# Constants for testing +DEFAULT_BLOCK_SIZE = 32 +MM_HASH_1 = 0xDEADBEEF +MM_HASH_2 = 0xCAFEBABE +MM_HASH_3 = 0xFEEDFACE + + +def make_mm_info(mm_hash: int, offsets: list[list[int]] | None = None) -> dict: + """Create a block's MM extra info structure.""" + if offsets is None: + offsets = [[0, 10]] + return {"mm_objects": [{"mm_hash": mm_hash, "offsets": offsets}]} + + +def make_store_event( + event_id: int, + blocks: list[dict], + parent_hash: int | None = None, +) -> bytes: + """Create a JSON-encoded store event for RadixTree.""" + event = { + "event_id": event_id, + "data": { + "stored": { + "parent_hash": parent_hash, + "blocks": blocks, + } + }, + } + return json.dumps(event).encode("utf-8") + + +def make_block( + block_hash: int, + tokens_hash: int | None = None, + mm_info: dict | None = None, +) -> dict: + """Create a block structure for store events.""" + block: dict[str, Any] = { + "block_hash": block_hash, + "tokens_hash": tokens_hash if tokens_hash is not None else block_hash, + } + if mm_info is not None: + block["mm_extra_info"] = mm_info + return block + + +# ============================================================================= +# RadixTree MM Routing Tests +# ============================================================================= + + +# # @pytest.mark.timeout(5) +def test_radix_tree_mm_routing_basic(): + """Test RadixTree correctly distinguishes blocks with same tokens but different MM content.""" + radix_tree = RadixTree() + + # Worker 0: Store block with MM Object 1 + worker_0, block_hash_w0 = 0, 1000 + event_w0 = make_store_event( + event_id=1, + blocks=[make_block(block_hash_w0, mm_info=make_mm_info(MM_HASH_1))], + ) + radix_tree.apply_event(worker_0, event_w0) + + # Worker 1: Store block with DIFFERENT MM Object (same tokens) + worker_1, block_hash_w1 = 1, 2000 + event_w1 = make_store_event( + event_id=2, + blocks=[make_block(block_hash_w1, mm_info=make_mm_info(MM_HASH_2))], + ) + radix_tree.apply_event(worker_1, event_w1) + + # Verify both blocks are stored + all_blocks = radix_tree.dump_tree_as_events() + assert len(all_blocks) == 2 + + # Query for worker 0's block + scores_w0 = radix_tree.find_matches([block_hash_w0]) + assert (worker_0, 0) in scores_w0.scores + assert scores_w0.scores[(worker_0, 0)] == 1 + + # Query for worker 1's block + scores_w1 = radix_tree.find_matches([block_hash_w1]) + assert (worker_1, 0) in scores_w1.scores + assert scores_w1.scores[(worker_1, 0)] == 1 + + # Query with non-existent hash should return no matches + scores_none = radix_tree.find_matches([9999]) + assert len(scores_none.scores) == 0 + + +# @pytest.mark.timeout(5) +def test_radix_tree_mm_block_chaining(): + """Test block chaining with parent_hash for multi-block sequences with MM content.""" + radix_tree = RadixTree() + + worker_id = 0 + parent_hash = 1000 + child_hash = 2000 + + # Store parent block + parent_event = make_store_event( + event_id=1, + blocks=[make_block(parent_hash, mm_info=make_mm_info(MM_HASH_1))], + ) + radix_tree.apply_event(worker_id, parent_event) + + # Store child block that references parent + child_event = make_store_event( + event_id=2, + 
blocks=[make_block(child_hash, mm_info=make_mm_info(MM_HASH_1))], + parent_hash=parent_hash, + ) + radix_tree.apply_event(worker_id, child_event) + + # Verify chain exists + all_blocks = radix_tree.dump_tree_as_events() + assert len(all_blocks) == 2 + + # Query with both hashes should match the chain + scores = radix_tree.find_matches([parent_hash, child_hash]) + assert (worker_id, 0) in scores.scores + assert scores.scores[(worker_id, 0)] == 2 + + +# @pytest.mark.timeout(5) +def test_radix_tree_worker_removal(): + """Test worker removal clears all its blocks.""" + radix_tree = RadixTree() + + worker_0, worker_1 = 0, 1 + + # Add blocks for both workers + radix_tree.apply_event( + worker_0, + make_store_event(1, [make_block(1000, mm_info=make_mm_info(MM_HASH_1))]), + ) + radix_tree.apply_event( + worker_1, + make_store_event(2, [make_block(2000, mm_info=make_mm_info(MM_HASH_2))]), + ) + + assert len(radix_tree.dump_tree_as_events()) == 2 + + # Remove worker 0 + radix_tree.remove_worker(worker_0) + + # Only worker 1's block should remain + remaining = radix_tree.dump_tree_as_events() + assert len(remaining) == 1 + + scores = radix_tree.find_matches([2000]) + assert (worker_1, 0) in scores.scores + + +# @pytest.mark.timeout(5) +def test_radix_tree_clear_all_blocks(): + """Test clearing all blocks for a specific worker.""" + radix_tree = RadixTree() + + worker_id = 0 + + # Add multiple blocks + radix_tree.apply_event( + worker_id, + make_store_event(1, [make_block(1000), make_block(2000)]), + ) + + assert len(radix_tree.dump_tree_as_events()) == 2 + + # Clear all blocks for worker + radix_tree.clear_all_blocks(worker_id) + + assert len(radix_tree.dump_tree_as_events()) == 0 + + +# ============================================================================= +# Block Hash Computation Tests +# ============================================================================= + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_computation_basic(): + """Test that same tokens with different MM content produce different hashes.""" + tokens = [100] * DEFAULT_BLOCK_SIZE + + # Without MM info + hashes_no_mm = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE) + assert len(hashes_no_mm) == 1 + + # With MM info 1 + hashes_mm1 = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] + ) + assert len(hashes_mm1) == 1 + + # With MM info 2 + hashes_mm2 = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)] + ) + assert len(hashes_mm2) == 1 + + # All three should be different + assert hashes_no_mm != hashes_mm1 + assert hashes_no_mm != hashes_mm2 + assert hashes_mm1 != hashes_mm2 + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_determinism(): + """Test that hash computation is deterministic.""" + tokens = [100] * DEFAULT_BLOCK_SIZE + mm_info = [make_mm_info(MM_HASH_1)] + + hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info) + hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info) + + assert hash1 == hash2 + + +# @pytest.mark.timeout(5) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +def test_mm_block_hash_multiple_blocks(block_size: int): + """Test hash computation for sequences spanning multiple blocks.""" + num_blocks = 3 + # Use different tokens per block to get unique hashes + tokens = [] + for i in range(num_blocks): + tokens.extend([100 + i] * block_size) + + # One MM info per block + mm_infos = [make_mm_info(MM_HASH_1) for _ in range(num_blocks)] + + hashes = 
compute_block_hash_for_seq_py(tokens, block_size, mm_infos) + + assert len(hashes) == num_blocks + # Each block should have a unique hash (due to different tokens) + assert len(set(hashes)) == num_blocks + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_partial_block(): + """Test hash computation when tokens don't fill complete blocks.""" + # 1.5 blocks worth of tokens + tokens = [100] * (DEFAULT_BLOCK_SIZE + DEFAULT_BLOCK_SIZE // 2) + + # MM info for each block + mm_infos = [make_mm_info(MM_HASH_1), make_mm_info(MM_HASH_2)] + + hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_infos) + + # Only complete blocks get hashes - partial blocks are not hashed + assert len(hashes) == 1 + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_none_mm_info(): + """Test that None MM info is handled correctly.""" + tokens = [100] * DEFAULT_BLOCK_SIZE + + # Pass None for some blocks' MM info + mm_infos = [None] + + hashes_with_none = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, mm_infos + ) + hashes_without = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE) + + # Both should produce the same result + assert hashes_with_none == hashes_without + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_different_offsets(): + """Test that same mm_hash with different offsets produces same hash.""" + tokens = [100] * DEFAULT_BLOCK_SIZE + + # Same MM hash, different offsets + mm_info_1 = make_mm_info(MM_HASH_1, offsets=[[0, 10]]) + mm_info_2 = make_mm_info(MM_HASH_1, offsets=[[5, 15]]) + + hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1]) + hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2]) + + # Currently offsets are not included in hash computation - just mm_hash + # This behavior may change - update test if needed + assert hash1 == hash2 + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_multiple_mm_objects(): + """Test hash with multiple MM objects in a single block.""" + tokens = [100] * DEFAULT_BLOCK_SIZE + + # Multiple MM objects in one block + mm_info = { + "mm_objects": [ + {"mm_hash": MM_HASH_1, "offsets": [[0, 5]]}, + {"mm_hash": MM_HASH_2, "offsets": [[10, 15]]}, + ] + } + + hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info]) + + assert len(hashes) == 1 + + # Compare with single MM object + single_mm_hashes = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] + ) + + # Should be different due to additional MM object + assert hashes != single_mm_hashes + + +# @pytest.mark.timeout(5) +def test_mm_block_hash_error_zero_block_size(): + """Test that zero block size raises an error.""" + tokens = [100] * 32 + + with pytest.raises(ValueError, match="kv_block_size cannot be 0"): + compute_block_hash_for_seq_py(tokens, 0) + + +# ============================================================================= +# Integration Tests: RadixTree + Hash Computation +# ============================================================================= + + +# @pytest.mark.timeout(5) +def test_integration_mm_hash_to_routing(): + """Test end-to-end: compute hash -> store in tree -> query matches correctly.""" + radix_tree = RadixTree() + tokens = [100] * DEFAULT_BLOCK_SIZE + + # Compute hashes for two different MM contents + hash_mm1 = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)] + )[0] + hash_mm2 = compute_block_hash_for_seq_py( + tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)] + )[0] + + # Store each on 
different workers
+    worker_0, worker_1 = 0, 1
+
+    radix_tree.apply_event(
+        worker_0,
+        make_store_event(1, [make_block(hash_mm1, mm_info=make_mm_info(MM_HASH_1))]),
+    )
+    radix_tree.apply_event(
+        worker_1,
+        make_store_event(2, [make_block(hash_mm2, mm_info=make_mm_info(MM_HASH_2))]),
+    )
+
+    # Query with MM1's hash should match worker 0
+    scores_mm1 = radix_tree.find_matches([hash_mm1])
+    assert (worker_0, 0) in scores_mm1.scores
+    assert (worker_1, 0) not in scores_mm1.scores
+
+    # Query with MM2's hash should match worker 1
+    scores_mm2 = radix_tree.find_matches([hash_mm2])
+    assert (worker_1, 0) in scores_mm2.scores
+    assert (worker_0, 0) not in scores_mm2.scores
+
+
+# @pytest.mark.timeout(5)
+@pytest.mark.parametrize("num_workers", [2, 3, 5])
+def test_integration_multiple_workers_same_tokens(num_workers: int):
+    """Test routing with multiple workers caching same tokens but different MM content."""
+    radix_tree = RadixTree()
+    tokens = [100] * DEFAULT_BLOCK_SIZE
+
+    # Each worker has unique MM content
+    mm_hashes = [0x1000 + i for i in range(num_workers)]
+
+    # Store blocks for each worker
+    for worker_id, mm_hash in enumerate(mm_hashes):
+        block_hash = compute_block_hash_for_seq_py(
+            tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
+        )[0]
+
+        radix_tree.apply_event(
+            worker_id,
+            make_store_event(
+                event_id=worker_id + 1,
+                blocks=[make_block(block_hash, mm_info=make_mm_info(mm_hash))],
+            ),
+        )
+
+    # Verify all blocks stored
+    assert len(radix_tree.dump_tree_as_events()) == num_workers
+
+    # Query for each worker's block should match only that worker
+    for worker_id, mm_hash in enumerate(mm_hashes):
+        block_hash = compute_block_hash_for_seq_py(
+            tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
+        )[0]
+
+        scores = radix_tree.find_matches([block_hash])
+
+        assert (worker_id, 0) in scores.scores
+        assert scores.scores[(worker_id, 0)] == 1
+
+        # No other workers should match
+        for other_id in range(num_workers):
+            if other_id != worker_id:
+                assert (other_id, 0) not in scores.scores
diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs
index df8b7aeb30..4c358b17ec 100644
--- a/lib/llm/src/kv_router.rs
+++ b/lib/llm/src/kv_router.rs
@@ -475,7 +475,7 @@ impl KvRouter {
         let isl_tokens = tokens.len();
 
-        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
+        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
         let seq_hashes = compute_seq_hash_for_block(&block_hashes);
 
         let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
@@ -530,7 +530,7 @@ impl KvRouter {
         let isl_tokens = tokens.len();
 
         let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
-            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
+            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
             compute_seq_hash_for_block(&block_hashes)
         });
@@ -573,11 +573,11 @@ impl KvRouter {
     /// Get potential prefill and decode loads for all workers
     pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
         let isl_tokens = tokens.len();
-        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
+        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
         let overlap_scores = self.indexer.find_matches(block_hashes).await?;
 
         let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
-            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
+            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
+            compute_seq_hash_for_block(&block_hashes)
+        });
@@ -661,7 +661,10 @@
         let context_id = ctx.context().id().to_string();
         // Handle different request types
         let response = match request {
-            RouterRequest::New { tokens } => {
+            RouterRequest::New {
+                tokens,
+                request_extra_info: _,
+            } => {
                 let (best_worker, overlap_blocks) = self
                     .find_best_match(Some(&context_id), &tokens, None, true)
                     .await?;
@@ -761,7 +764,7 @@
diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs
--- a/lib/llm/src/kv_router/indexer.rs
+++ b/lib/llm/src/kv_router/indexer.rs
 //     let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
 // }
 
-/// Compute the hash for a sequence of tokens.
+/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
+///
+/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
+/// to ensure that blocks with identical tokens but different multimodal objects produce
+/// different hashes.
 ///
 /// ### Arguments
 ///
 /// * `tokens` - A vector of `u32` tokens.
+/// * `kv_block_size` - The size of each block in tokens.
+/// * `block_mm_infos` - Optional per-block multimodal metadata.
 ///
 /// ### Returns
 ///
 /// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
-pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
+pub fn compute_block_hash_for_seq(
+    tokens: &[u32],
+    kv_block_size: u32,
+    block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
+) -> Vec<LocalBlockHash> {
     tokens
-        .chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements
-        .map(|chunk| {
-            let bytes: Vec<u8> = chunk
-                .iter()
-                .flat_map(|&num| num.to_le_bytes()) // Convert each i32 to its little-endian bytes
-                .collect();
+        .chunks_exact(kv_block_size as usize)
+        .enumerate()
+        .map(|(block_idx, chunk)| {
+            let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
+
+            // Include MM hashes in the block hash computation if present
+            if let Some(mm_infos) = block_mm_infos
+                && let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
+            {
+                // The order of different multimodal hashes does not matter.
+                // Only which multimodal infos are present in a block is important.
+                // The order may differ in different code paths, so the hashes are sorted
+                // to keep the block hash stable.
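+                // Note: offsets are intentionally not part of the hash input;
+                // only the set of mm_hashes present in the block matters (see
+                // test_mm_block_hash_different_offsets in the Python tests).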
+                let mut mm_hashes: Vec<u64> = block_mm_info
+                    .mm_objects
+                    .iter()
+                    .map(|obj| obj.mm_hash)
+                    .collect();
+                mm_hashes.sort_unstable();
+
+                // Append sorted mm_hashes to the byte array
+                for mm_hash in mm_hashes {
+                    bytes.extend_from_slice(&mm_hash.to_le_bytes());
+                }
+            }
 
-            compute_block_hash(&Bytes::from(bytes)) // Convert the byte Vec to Bytes
+            compute_block_hash(&bytes)
         })
         .collect()
 }
@@ -610,6 +638,7 @@ impl RadixTree {
                     parent_hash,
                     blocks: vec![KvCacheStoredBlockData {
                         block_hash: *external_hash,
+                        mm_extra_info: None,
                         tokens_hash,
                     }],
                 }),
@@ -1076,6 +1105,7 @@ impl KvIndexer {
             blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
                 tokens_hash: *local_hash,
                 block_hash: ExternalSequenceBlockHash(*sequence_hash),
+                mm_extra_info: None,
             }).collect(),
         });
@@ -1243,7 +1273,7 @@ impl KvIndexerInterface for KvIndexer {
             tokens,
             tokens.len()
         );
-        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
+        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
         tracing::debug!("Computed sequence: {:?}", sequence);
         self.find_matches(sequence).await
     }
@@ -1296,7 +1326,7 @@ impl KvIndexerInterface for KvIndexer {
         tokens: &[u32],
         worker: WorkerWithDpRank,
     ) -> Result<(), KvRouterError> {
-        let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
+        let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
         let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
         let sequence_hashes = sequence
             .blocks()
@@ -1813,6 +1843,7 @@ impl KvIndexerSharded {
             blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
                 tokens_hash: *local_hash,
                 block_hash: ExternalSequenceBlockHash(*sequence_hash),
+                mm_extra_info: None,
             }).collect(),
         });
@@ -1973,7 +2004,7 @@ impl KvIndexerInterface for KvIndexerSharded {
         &self,
         tokens: &[u32],
     ) -> Result<OverlapScores, KvRouterError> {
-        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
+        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
         self.find_matches(sequence).await
     }
@@ -2073,7 +2104,7 @@ impl KvIndexerInterface for KvIndexerSharded {
         tokens: &[u32],
         worker: WorkerWithDpRank,
     ) -> Result<(), KvRouterError> {
-        let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
+        let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
         let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
         let sequence_hashes = sequence
             .blocks()
@@ -2111,6 +2142,7 @@ mod tests {
             .map(|i| KvCacheStoredBlockData {
                 tokens_hash: LocalBlockHash(*i),
                 block_hash: ExternalSequenceBlockHash(*i * 100),
+                mm_extra_info: None,
             })
             .collect()
     }
@@ -2714,17 +2746,17 @@ mod tests {
         setup();
 
         // create a sequence of 64 elements
         let sequence = (0..kv_block_size).collect::<Vec<u32>>();
-        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
+        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
         assert_eq!(hashes.len(), 1);
 
         // create a sequence of 65 elements
         let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
-        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
+        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
         assert_eq!(hashes.len(), 1);
 
         // create a sequence of 129 elements
         let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
-        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
+        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
         assert_eq!(hashes.len(), 2);
     }
@@ -2929,6 +2961,7 @@ mod tests {
                     parent_hash: None,
                     blocks: vec![KvCacheStoredBlockData {
                         block_hash: ExternalSequenceBlockHash(0),
+                        mm_extra_info: None,
                         tokens_hash: LocalBlockHash(13226331709069118873),
                     }],
                 }),
@@ -3392,6 +3425,7 @@ mod tests {
                 blocks: vec![KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(id * 100),
                     tokens_hash: LocalBlockHash(id * 200),
+                    mm_extra_info: None,
                 }],
             }),
             dp_rank: 0,
@@ -3567,6 +3601,7 @@ mod tests {
                 blocks: vec![KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(100),
                     tokens_hash: LocalBlockHash(200),
+                    mm_extra_info: None,
                 }],
             }),
             dp_rank: 0,
diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs
index 81a36a4fb9..057beca614 100644
--- a/lib/llm/src/kv_router/protocols.rs
+++ b/lib/llm/src/kv_router/protocols.rs
@@ -40,6 +40,8 @@ pub enum RouterRequest {
     #[serde(rename = "new")]
     New {
         tokens: Vec<u32>,
+        #[serde(default, skip_serializing_if = "Option::is_none")]
+        request_extra_info: Option<RequestExtraInfo>,
     },
     MarkPrefill,
     MarkFree,
@@ -47,7 +49,10 @@
 
 impl Default for RouterRequest {
     fn default() -> Self {
-        RouterRequest::New { tokens: vec![] }
+        RouterRequest::New {
+            tokens: vec![],
+            request_extra_info: None,
+        }
     }
 }
@@ -276,6 +281,111 @@ pub struct KvCacheStoreData {
     pub blocks: Vec<KvCacheStoredBlockData>,
 }
 
+/// Multimodal object information within a block.
+/// Offsets are relative to the block (0 to block_size-1).
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct BlockMmObjectInfo {
+    /// Hash identifying this multimodal object
+    pub mm_hash: u64,
+    /// Token offset ranges where this MM object's placeholders appear within THIS block
+    /// Each tuple is (start_offset, end_offset) relative to block start
+    pub offsets: Vec<(usize, usize)>,
+}
+
+/// Extra metadata for a block containing multimodal objects
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct BlockExtraInfo {
+    /// All multimodal objects referenced in this block
+    pub mm_objects: Vec<BlockMmObjectInfo>,
+}
+
+/// Request-level multimodal object information.
+/// Offsets are relative to the entire request token sequence.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct RequestMmObjectInfo {
+    /// Hash identifying this multimodal object
+    pub mm_hash: u64,
+    /// Token offset ranges where this MM object's placeholders appear in the ENTIRE request
+    /// Each tuple is (start_offset, end_offset) relative to request start
+    pub offsets: Vec<(usize, usize)>,
+}
+
+/// Request-level multimodal metadata
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct RequestExtraInfo {
+    /// All multimodal objects in this request
+    pub mm_objects: Vec<RequestMmObjectInfo>,
+}
+
+impl RequestExtraInfo {
+    /// Convert request-level MM info to block-level MM info for a sequence of blocks.
+    ///
+    /// This function splits request-level offsets (relative to the entire request token sequence)
+    /// into block-level offsets (relative to each block).
+    ///
+    /// # Arguments
+    /// * `block_size` - The size of each block in tokens
+    /// * `total_tokens` - Total number of tokens in the request
+    ///
+    /// # Returns
+    /// A vector of `Option<BlockExtraInfo>` where each element corresponds to a block.
+    /// `None` indicates a block with no multimodal objects.
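+    ///
+    /// # Example
+    /// With `block_size = 4` and `total_tokens = 10`, an object covering request
+    /// offsets `(2, 7)` lands in block 0 at local offsets `(2, 4)` and in block 1
+    /// at local offsets `(0, 3)`; block 2 contains no multimodal tokens and stays `None`.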
+    pub fn to_block_level(
+        &self,
+        block_size: usize,
+        total_tokens: usize,
+    ) -> Vec<Option<BlockExtraInfo>> {
+        let num_blocks = total_tokens.div_ceil(block_size);
+        let mut block_infos: Vec<Option<BlockExtraInfo>> = vec![None; num_blocks];
+
+        for req_mm_obj in &self.mm_objects {
+            for (req_start, req_end) in &req_mm_obj.offsets {
+                // Find which blocks this offset range spans
+                let start_block = req_start / block_size;
+                let end_block = (req_end.saturating_sub(1)) / block_size;
+
+                let upper_bound = end_block.min(num_blocks - 1) + 1;
+                for (block_idx, block_info_opt) in block_infos
+                    .iter_mut()
+                    .enumerate()
+                    .take(upper_bound)
+                    .skip(start_block)
+                {
+                    let block_start_global = block_idx * block_size;
+                    let block_end_global = ((block_idx + 1) * block_size).min(total_tokens);
+
+                    // Calculate the intersection of this MM object's range with this block
+                    let local_start = (*req_start).max(block_start_global) - block_start_global;
+                    let local_end = (*req_end).min(block_end_global) - block_start_global;
+
+                    if local_start < local_end {
+                        let block_info = block_info_opt
+                            .get_or_insert_with(|| BlockExtraInfo { mm_objects: vec![] });
+
+                        // Check if we already have this mm_hash in this block
+                        if let Some(existing) = block_info
+                            .mm_objects
+                            .iter_mut()
+                            .find(|obj| obj.mm_hash == req_mm_obj.mm_hash)
+                        {
+                            // Add the offset range to existing object
+                            existing.offsets.push((local_start, local_end));
+                        } else {
+                            // Create new MM object entry for this block
+                            block_info.mm_objects.push(BlockMmObjectInfo {
+                                mm_hash: req_mm_obj.mm_hash,
+                                offsets: vec![(local_start, local_end)],
+                            });
+                        }
+                    }
+                }
+            }
+        }
+
+        block_infos
+    }
+}
+
 /// Represents data for a stored block.
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct KvCacheStoredBlockData {
@@ -283,6 +393,11 @@ pub struct KvCacheStoredBlockData {
     pub block_hash: ExternalSequenceBlockHash,
     /// The hash of the tokens in the block.
     pub tokens_hash: LocalBlockHash,
+    /// Extra multimodal metadata for this block
+    /// Note: Do NOT use skip_serializing_if with bincode - it breaks deserialization
+    /// because bincode is positional and expects all fields to be present.
+    #[serde(default)]
+    pub mm_extra_info: Option<BlockExtraInfo>,
 }
 
 /// Represents the data associated with a removed cache event.
@@ -365,6 +480,7 @@ mod tests {
         blocks: vec![KvCacheStoredBlockData {
             block_hash: ExternalSequenceBlockHash(2),
             tokens_hash: LocalBlockHash(3),
+            mm_extra_info: None,
         }],
     });
diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs
index b0b98e3eec..29bf5f5d38 100644
--- a/lib/llm/src/kv_router/publisher.rs
+++ b/lib/llm/src/kv_router/publisher.rs
@@ -543,6 +543,7 @@ fn convert_event(
             token_ids,
             block_size,
             lora_id,
+            block_mm_infos,
             ..
         } => {
             let num_block_tokens = vec![block_size as u64; block_hashes.len()];
@@ -563,6 +564,7 @@ fn convert_event(
                     &block_hashes_u64,
                     lora_id.unwrap_or(0),
                     warning_count,
+                    block_mm_infos.as_deref(),
                 ),
             }),
             dp_rank,
@@ -595,18 +597,25 @@ pub fn create_stored_block_from_parts(
     block_hash: u64,
     token_ids: &[u32],
     _lora_id: u64,
+    mm_extra_info: Option<BlockExtraInfo>,
 ) -> KvCacheStoredBlockData {
-    let tokens_hash = compute_block_hash_for_seq(token_ids, kv_block_size)[0];
+    // Compute tokens_hash including MM info if present
+    let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
+    let tokens_hash =
+        compute_block_hash_for_seq(token_ids, kv_block_size, block_mm_infos.as_deref())[0];
+
     tracing::trace!(
-        "Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}",
+        "Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}",
         block_hash,
         tokens_hash.0,
         token_ids,
-        kv_block_size
+        kv_block_size,
+        mm_extra_info
     );
     KvCacheStoredBlockData {
         block_hash: ExternalSequenceBlockHash::from(block_hash),
         tokens_hash,
+        mm_extra_info,
     }
 }
@@ -617,11 +626,14 @@ pub fn create_stored_blocks(
     block_hashes: &[u64],
     lora_id: u64,
     warning_count: &Arc<AtomicU32>,
+    block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
 ) -> Vec<KvCacheStoredBlockData> {
     let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
     let mut token_offset: usize = 0;
 
-    for (num_tokens_it, block_hash_it) in num_block_tokens.iter().zip(block_hashes.iter()) {
+    for (block_idx, (num_tokens_it, block_hash_it)) in
+        num_block_tokens.iter().zip(block_hashes.iter()).enumerate()
+    {
         if *num_tokens_it != kv_block_size as u64 {
             if warning_count.fetch_add(1, Ordering::Relaxed) < 3 {
                 tracing::warn!(
             }
         }
 
         let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
+        let mm_extra_info = block_mm_infos
+            .and_then(|infos| infos.get(block_idx))
+            .and_then(|opt| opt.clone());
+
         blocks.push(create_stored_block_from_parts(
             kv_block_size,
             *block_hash_it,
             tokens,
             lora_id,
+            mm_extra_info,
         ));
         token_offset += *num_tokens_it as usize;
     }
@@ -702,6 +719,9 @@ enum RawKvEvent {
         lora_id: Option<u64>,
         #[serde(skip_serializing_if = "Option::is_none")]
         medium: Option<String>,
+        /// Multimodal extra info for each block (length should match block_hashes)
+        #[serde(default, skip_serializing_if = "Option::is_none")]
+        block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
     },
     BlockRemoved {
         block_hashes: Vec<u64>,
@@ -747,6 +767,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
         let mut block_size: Option<u32> = None;
         let mut lora_id: Option<Option<u64>> = None;
         let mut medium: Option<Option<String>> = None;
+        let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
 
         while let Some(key) = map.next_key::<String>()? {
             match key.as_str() {
@@ -771,6 +792,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
                 "medium" => {
                     medium = Some(map.next_value()?);
                 }
+                "block_mm_infos" => {
+                    block_mm_infos = Some(map.next_value()?);
+                }
                 _ => {
                     map.next_value::<de::IgnoredAny>()?;
                 }
@@ -791,6 +815,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
                     block_size,
                     lora_id: lora_id.unwrap_or(None),
                     medium: medium.unwrap_or(None),
+                    block_mm_infos: block_mm_infos.unwrap_or(None),
                 })
             }
             Some("BlockRemoved") => {
@@ -836,6 +861,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
                     .ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?;
                 let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
                 let medium: Option<String> = seq.next_element()?.unwrap_or(None);
+                let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
+                    seq.next_element()?.unwrap_or(None);
 
                 while seq.next_element::<de::IgnoredAny>()?.is_some() {}
 
@@ -846,6 +873,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
                     block_size,
                     lora_id,
                     medium,
+                    block_mm_infos,
                 })
             }
             "BlockRemoved" => {
@@ -1088,11 +1116,12 @@ mod test_event_processing {
         let token_ids = vec![10, 20, 30, 40];
         let blk_hash = 0xdead_beef;
 
-        let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0);
+        let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0, None);
         assert_eq!(stored.block_hash.0, blk_hash);
 
-        let expected_hash = compute_block_hash_for_seq(&token_ids, 4)[0];
+        let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None)[0];
         assert_eq!(stored.tokens_hash, expected_hash);
+        assert!(stored.mm_extra_info.is_none());
     }
 
     // ---------------------------------------------------------------------
@@ -1113,6 +1142,7 @@ mod test_event_processing {
             &block_hashes,
             /*lora_id=*/ 0,
             &Arc::new(AtomicU32::new(0)),
+            None,
         );
 
         assert_eq!(blocks.len(), 2);
@@ -1136,6 +1166,7 @@ mod test_event_processing {
             &block_hashes,
             /*lora_id=*/ 0,
             &warning_count,
+            None,
         );
 
         // should early-exit as second has mismatch
@@ -1156,6 +1187,7 @@ mod test_event_processing {
             block_size: 4,
             lora_id: Some(0),
             medium: None,
+            block_mm_infos: None,
         };
 
         let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
@@ -1303,10 +1335,12 @@ mod tests_startup_helpers {
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(100),
                     tokens_hash: LocalBlockHash(200),
+                    mm_extra_info: None,
                 },
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(101),
                     tokens_hash: LocalBlockHash(201),
+                    mm_extra_info: None,
                 },
             ],
         }),
@@ -1391,6 +1425,7 @@ mod tests_startup_helpers {
             blocks: vec![KvCacheStoredBlockData {
                 block_hash: ExternalSequenceBlockHash(100),
                 tokens_hash: LocalBlockHash(200),
+                mm_extra_info: None,
             }],
         }),
         dp_rank: 0,
@@ -1471,6 +1506,7 @@ mod tests_startup_helpers {
             blocks: vec![KvCacheStoredBlockData {
                 block_hash: ExternalSequenceBlockHash(100),
                 tokens_hash: LocalBlockHash(200),
+                mm_extra_info: None,
             }],
         }),
         dp_rank: 0,
@@ -1615,6 +1651,7 @@ mod tests_startup_helpers {
             block_size: 4,
             lora_id: None,
             medium: None,
+            block_mm_infos: None,
         }];
 
         let batch = KvEventBatch {
@@ -1705,10 +1742,12 @@ mod tests_startup_helpers {
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(100),
                     tokens_hash: LocalBlockHash(200),
+                    mm_extra_info: None,
                 },
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(101),
                     tokens_hash: LocalBlockHash(201),
+                    mm_extra_info: None,
                 },
             ],
         }),
@@ -1769,10 +1808,12 @@ mod tests_startup_helpers {
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(100), // Shared prefix
                     tokens_hash: LocalBlockHash(200),
+                    mm_extra_info: None,
                 },
                 KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(102), // New block
                     tokens_hash: LocalBlockHash(202),
+                    mm_extra_info: None,
                 },
             ],
         }),
diff --git a/lib/llm/src/kv_router/recorder.rs b/lib/llm/src/kv_router/recorder.rs
index 75fc1c9eb0..cb55118687 100644
--- a/lib/llm/src/kv_router/recorder.rs
+++ b/lib/llm/src/kv_router/recorder.rs
@@ -24,6 +24,7 @@ mod tests {
             .map(|i| KvCacheStoredBlockData {
                 tokens_hash: LocalBlockHash(*i),
                 block_hash: ExternalSequenceBlockHash(*i * 100),
+                mm_extra_info: None,
             })
             .collect()
     }
diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs
index a949139475..013deda3be 100644
--- a/lib/llm/src/mocker/kv_manager.rs
+++ b/lib/llm/src/mocker/kv_manager.rs
@@ -139,6 +139,7 @@ impl KvManager {
                 .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
                     block_hash: ExternalSequenceBlockHash(global_hash),
                     tokens_hash: LocalBlockHash(*local_hash),
+                    mm_extra_info: None,
                 })
                 .collect(),
             })
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index a6844a2126..1b93f6deb2 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
 use super::timing::RequestTracker;
 use super::{OutputOptions, SamplingOptions, StopConditions};
 
-use crate::kv_router::RouterConfigOverride;
+use crate::kv_router::{RouterConfigOverride, protocols::RequestExtraInfo};
 #[cfg(feature = "media-nixl")]
 use crate::preprocessor::media::RdmaMediaDataDescriptor;
 use crate::protocols::TokenIdType;
@@ -118,6 +118,10 @@ pub struct PreprocessedRequest {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub extra_fields: Option<HashMap<String, serde_json::Value>>,
 
+    /// Multimodal request-level metadata (mm_hash and token offsets)
+    #[builder(default)]
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub request_extra_info: Option<RequestExtraInfo>,
     /// Optional request tracker for per-request metrics (shared with DeltaGenerator)
     #[builder(default)]
     #[serde(skip)]
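
To see what the new hashing behavior buys end to end, here is a minimal sketch (not part of the patch). The import paths are assumptions about how the crate re-exports `compute_block_hash_for_seq` and the protocol types; the signatures themselves come from the diff above.

```rust
// Minimal sketch, assuming these re-export paths; adjust to the real module layout.
use dynamo_llm::kv_router::indexer::compute_block_hash_for_seq;
use dynamo_llm::kv_router::protocols::{BlockExtraInfo, BlockMmObjectInfo};

fn main() {
    let block_size = 4u32;
    // Two requests with identical tokens...
    let tokens: Vec<u32> = vec![10, 20, 30, 40];

    // ...but different images whose placeholders fill the block.
    let image_a = vec![Some(BlockExtraInfo {
        mm_objects: vec![BlockMmObjectInfo {
            mm_hash: 0xAAAA,
            offsets: vec![(0, 4)],
        }],
    })];
    let image_b = vec![Some(BlockExtraInfo {
        mm_objects: vec![BlockMmObjectInfo {
            mm_hash: 0xBBBB,
            offsets: vec![(0, 4)],
        }],
    })];

    let plain = compute_block_hash_for_seq(&tokens, block_size, None);
    let with_a = compute_block_hash_for_seq(&tokens, block_size, Some(image_a.as_slice()));
    let with_b = compute_block_hash_for_seq(&tokens, block_size, Some(image_b.as_slice()));

    // Same tokens, different multimodal content: distinct block hashes, so the
    // router will not treat these blocks as overlapping cache entries.
    assert_ne!(with_a[0], with_b[0]);
    assert_ne!(plain[0], with_a[0]);
}
```

Because the per-block `mm_hashes` are sorted before being appended to the byte buffer, the resulting hash does not depend on the order in which the multimodal objects happen to be listed.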
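The `#[serde(default)]` on `mm_extra_info` (and on `block_mm_infos` in `RawKvEvent`) is what keeps older publishers compatible: an event serialized before this change still deserializes, with the new field falling back to `None`. A sketch with `serde_json`, assuming the same re-export path as above and that the hash newtypes serialize as plain integers:

```rust
// Hedged sketch: assumes ExternalSequenceBlockHash / LocalBlockHash are serde
// newtypes over u64, so they appear as bare numbers in JSON.
use dynamo_llm::kv_router::protocols::KvCacheStoredBlockData;

fn main() {
    // A stored-block payload from an older publisher: no `mm_extra_info` key.
    let old_event = r#"{"block_hash": 100, "tokens_hash": 200}"#;

    let block: KvCacheStoredBlockData =
        serde_json::from_str(old_event).expect("old payloads should still parse");

    // `#[serde(default)]` fills the missing field with None.
    assert!(block.mm_extra_info.is_none());
}
```

With bincode the field must always be written, which is exactly what the comment in the diff warns about: bincode carries no field names, so skipping an optional field would shift every value that follows it.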
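Finally, a sketch of the publisher path. The hunk for `create_stored_blocks` cuts off its leading parameters, so their order here (`kv_block_size`, `token_ids`, `num_block_tokens`, then the visible tail) is inferred from the function body and the updated tests; treat it as illustrative rather than exact.

```rust
use std::sync::Arc;
use std::sync::atomic::AtomicU32;

// Assumed re-export paths, as in the sketches above.
use dynamo_llm::kv_router::protocols::{BlockExtraInfo, BlockMmObjectInfo};
use dynamo_llm::kv_router::publisher::create_stored_blocks;

fn main() {
    let kv_block_size = 4u32;
    let token_ids: Vec<u32> = (0..8).collect(); // two full blocks
    let num_block_tokens = vec![4u64, 4];
    let block_hashes = vec![0xAAu64, 0xBB];

    // Only the second block contains an image.
    let mm_infos = vec![
        None,
        Some(BlockExtraInfo {
            mm_objects: vec![BlockMmObjectInfo {
                mm_hash: 0xCAFE,
                offsets: vec![(0, 4)],
            }],
        }),
    ];

    let blocks = create_stored_blocks(
        kv_block_size,
        &token_ids,
        &num_block_tokens,
        &block_hashes,
        0,                            // lora_id
        &Arc::new(AtomicU32::new(0)), // warning counter, as in the tests
        Some(mm_infos.as_slice()),
    );

    // Per-block MM info is threaded through to the stored-block events,
    // so only the second block carries multimodal metadata.
    assert_eq!(blocks.len(), 2);
    assert!(blocks[0].mm_extra_info.is_none());
    assert!(blocks[1].mm_extra_info.is_some());
}
```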