diff --git a/components/src/dynamo/common/utils/input_params.py b/components/src/dynamo/common/utils/input_params.py
index 7201101306..896da904ba 100644
--- a/components/src/dynamo/common/utils/input_params.py
+++ b/components/src/dynamo/common/utils/input_params.py
@@ -12,7 +12,6 @@ def get_input_param(self, request: dict, use_tokenizer: bool):
         """

         if use_tokenizer:
-            print(f"Request: {request}")
             if self.tokenizer is None:
                 raise ValueError("Tokenizer is not available")

@@ -21,10 +20,9 @@ def get_input_param(self, request: dict, use_tokenizer: bool):
                     request["messages"], tokenize=False, add_generation_prompt=True
                 )
             elif "prompt" in request:
-                return request["prompt"]
+                return self.tokenizer.encode(request["prompt"])
             elif "text" in request:
-                return request["text"]
+                return self.tokenizer.encode(request["text"])
             else:
                 raise ValueError("No input parameter found in request")

-        return request.get("token_ids")
diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py
index f8d17c7369..6167a1786b 100644
--- a/components/src/dynamo/vllm/handlers.py
+++ b/components/src/dynamo/vllm/handlers.py
@@ -758,7 +758,6 @@ async def generate_tokens(
             logger.debug(
                 f"Starting token generation for request {request_id} (no LoRA)"
             )
-
             gen = self.engine_client.generate(
                 prompt,
                 sampling_params,
@@ -947,12 +946,15 @@ async def _generate_token_mode(self, request, context, request_id):
     async def _generate_text_mode(self, request, context, request_id):
         """Generate text using OpenAI-compatible format (text-in-text-out)."""
         # Get text input using InputParamManager
-        input_text = self.input_param_manager.get_input_param(
+        input_data = self.input_param_manager.get_input_param(
             request, use_tokenizer=True
         )

         # Build prompt for vLLM
-        prompt = TextPrompt(prompt=input_text)
+        if isinstance(input_data, list):
+            prompt = TokensPrompt(prompt_token_ids=input_data)
+        else:
+            prompt = TextPrompt(prompt=input_data)

         # Build sampling params from OpenAI-style request
         sampling_params = build_sampling_params_openai(
@@ -1050,14 +1052,9 @@ async def generate(self, request, context):
         request_id = context.id()
         logger.debug(f"Prefill Request ID: {request_id}")

-        if self.use_vllm_tokenizer:
-            # Text-in-text-out mode: use InputParamManager
-            async for chunk in self._generate_text_mode(request, context, request_id):
-                yield chunk
-        else:
-            # Token-in-token-out mode: internal protocol format
-            async for chunk in self._generate_token_mode(request, context, request_id):
-                yield chunk
+        # Token-in-token-out mode: internal protocol format
+        async for chunk in self._generate_token_mode(request, context, request_id):
+            yield chunk

     async def _generate_token_mode(self, request, context, request_id):
         """Generate prefill using internal protocol format (token-in-token-out)."""
@@ -1164,77 +1161,3 @@ async def _generate_token_mode(self, request, context, request_id):
             raise GeneratorExit(
                 "Prefill engine was shut down during token generation"
             ) from None
-
-    async def _generate_text_mode(self, request, context, request_id):
-        """Generate prefill using OpenAI-compatible format (text-in-text-out)."""
-        # Get text input using InputParamManager
-        input_text = self.input_param_manager.get_input_param(
-            request, use_tokenizer=True
-        )
-
-        # Build prompt for vLLM
-        prompt = TextPrompt(prompt=input_text)
-
-        # Build sampling params from OpenAI-style request
-        sampling_params = build_sampling_params_openai(
-            request, self.default_sampling_params
-        )
-        sampling_params.detokenize = False  # Prefill doesn't need detokenization
-
-        # Configure for prefill-only mode with remote decode
-        if sampling_params.extra_args is None:
-            sampling_params.extra_args = {}
-        sampling_params.extra_args["kv_transfer_params"] = {
-            "do_remote_decode": True,
-        }
-        sampling_params_defaults = {
-            "do_remote_prefill": False,
-            "remote_engine_id": None,
-            "remote_block_ids": None,
-            "remote_host": None,
-            "remote_port": None,
-        }
-        # Add only missing keys
-        for k, v in sampling_params_defaults.items():
-            sampling_params.extra_args["kv_transfer_params"].setdefault(k, v)
-        # Override for prefill: only generate 1 token
-        sampling_params.max_tokens = 1
-        sampling_params.min_tokens = 1
-
-        dp_rank = request.get("dp_rank", None)
-
-        async with self._abort_monitor(context, request_id, is_prefill=True):
-            try:
-                gen = self.engine_client.generate(
-                    prompt, sampling_params, request_id, data_parallel_rank=dp_rank
-                )
-            except EngineDeadError as e:
-                logger.error(f"vLLM EngineDeadError: {e}")
-                logger.warning("Initiating Dynamo Runtime shutdown.")
-                self.runtime.shutdown()
-                os._exit(1)
-
-            try:
-                async for res in gen:
-                    logger.debug(f"kv transfer params: {res.kv_transfer_params}")
-
-                    token_ids = res.outputs[0].token_ids if res.outputs else []
-
-                    output: Dict[str, Any] = {
-                        "token_ids": list(token_ids),
-                        "disaggregated_params": (
-                            {"kv_transfer_params": res.kv_transfer_params}
-                            if res.kv_transfer_params
-                            else None
-                        ),
-                        "completion_usage": BaseWorkerHandler._build_completion_usage(
-                            request_output=res
-                        ),
-                    }
-
-                    yield output
-            except asyncio.CancelledError:
-                # raise the error because we cannot migrate prefill requests
-                raise GeneratorExit(
-                    "Prefill engine was shut down during token generation"
-                ) from None
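The behavioural change above is that get_input_param() now returns token IDs (via tokenizer.encode) for "prompt"/"text" requests, while "messages" requests still come back as a chat-templated string, so _generate_text_mode branches on the return type when building the vLLM prompt. Below is a minimal sketch of that contract, assuming a Hugging Face-style tokenizer and vLLM's TextPrompt/TokensPrompt input types; the build_prompt helper, the model name, and the import path are illustrative, not part of this patch.

from transformers import AutoTokenizer
from vllm.inputs import TextPrompt, TokensPrompt  # import path may vary by vLLM version

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat model

def build_prompt(request: dict):
    """Mirror of the handler's dispatch on get_input_param()'s return type."""
    if "messages" in request:
        # Chat requests: still a plain string (chat template applied, not tokenized).
        input_data = tokenizer.apply_chat_template(
            request["messages"], tokenize=False, add_generation_prompt=True
        )
    else:
        # "prompt"/"text" requests: now a list of token IDs.
        input_data = tokenizer.encode(request.get("prompt") or request["text"])

    if isinstance(input_data, list):
        return TokensPrompt(prompt_token_ids=input_data)
    return TextPrompt(prompt=input_data)

print(build_prompt({"prompt": "Hello"}))                                   # TokensPrompt
print(build_prompt({"messages": [{"role": "user", "content": "Hello"}]}))  # TextPrompt
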
diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs
index 77fc780f6c..4b764e0782 100644
--- a/lib/llm/src/model_card.rs
+++ b/lib/llm/src/model_card.rs
@@ -162,6 +162,11 @@ impl GenerationConfig {
     }
 }

+/// Check if the model directory only has config files for a Mistral-format model.
+fn is_exclusively_mistral_model(directory: &Path) -> bool {
+    !directory.join("config.json").exists() && directory.join("params.json").exists()
+}
+
 #[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
 pub struct ModelDeploymentCard {
     /// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
@@ -352,7 +357,9 @@ impl ModelDeploymentCard {
                 .with_context(|| p.display().to_string())
             }
             None => {
-                anyhow::bail!("Blank ModelDeploymentCard does not have a tokenizer");
+                anyhow::bail!(
+                    "Blank ModelDeploymentCard does not have a tokenizer. Is this a Mistral-format model? If so, the `--use--tokenizer` flag in the engine command is required."
+                );
             }
         }
     }
+ ); } } } @@ -497,8 +504,23 @@ impl ModelDeploymentCard { // If neither of those are present let the engine default it .unwrap_or(0); + let is_mistral_model = is_exclusively_mistral_model(local_path); + + let (model_info, tokenizer, gen_config, prompt_formatter) = if !is_mistral_model { + ( + Some(ModelInfoType::from_disk(local_path)?), + Some(TokenizerKind::from_disk(local_path)?), + GenerationConfig::from_disk(local_path).ok(), + PromptFormatterArtifact::from_disk(local_path)?, + ) + } else { + (None, None, None, None) + }; + // Load chat template - either custom or from repo - let chat_template_file = if let Some(template_path) = custom_template_path { + let chat_template_file = if is_mistral_model { + None + } else if let Some(template_path) = custom_template_path { if !template_path.exists() { anyhow::bail!( "Custom template file does not exist: {}", @@ -525,10 +547,10 @@ impl ModelDeploymentCard { Ok(Self { slug: Slug::from_string(&display_name), display_name, - model_info: Some(ModelInfoType::from_disk(local_path)?), - tokenizer: Some(TokenizerKind::from_disk(local_path)?), - gen_config: GenerationConfig::from_disk(local_path).ok(), // optional - prompt_formatter: PromptFormatterArtifact::from_disk(local_path)?, + model_info, + tokenizer, + gen_config, + prompt_formatter, chat_template_file, prompt_context: None, // TODO - auto-detect prompt context context_length,