76 changes: 52 additions & 24 deletions extract_thinker/llm.py
@@ -6,22 +6,35 @@
 from extract_thinker.llm_engine import LLMEngine
 from extract_thinker.utils import add_classification_structure, extract_thinking_json
 
-# Add these constants at the top of the file, after the imports
-DYNAMIC_PROMPT_TEMPLATE = """Please provide your thinking process within <think> tags, followed by your JSON output.
-
-JSON structure:
-{prompt}
-
-OUTPUT example:
-<think>
-Your step-by-step reasoning and analysis goes here...
-</think>
-
-##JSON OUTPUT
-{{
-...
-}}
-"""
+# Helper to build the dynamic prompt used when `is_dynamic=True`.
+# We expose it as a standalone function so that callers (or subclasses)
+# can supply their own variants if needed.
+
+def build_dynamic_prompt(structure: str, *, think_tag: str = "think") -> str:
+    """Return the dynamic prompt used for classification style requests.
+
+    Args:
+        structure: The JSON structure/fields to be returned by the model.
+        think_tag: The XML-style tag that should wrap the model's chain-of-thought.
+
+    This helper allows downstream users to customise the surrounding text
+    (for example, changing the tag name or adding extra instructions) rather
+    than editing a hard-coded string inside *llm.py*.
+    """
+
+    return (
+        f"Please provide your thinking process within <{think_tag}> tags, "
+        "followed by your JSON output.\n\n"
+        "JSON structure:\n"
+        f"{structure}\n\n"
+        "OUTPUT example:\n"
+        f"<{think_tag}>\n"
+        "Your step-by-step reasoning and analysis goes here...\n"
+        f"</{think_tag}>\n\n"
+        "##JSON OUTPUT\n"
+        "{\n ...\n}"  # placeholder keeps JSON fence out of model context
+    )
 
 class LLM:
     TIMEOUT = 3000  # Timeout in milliseconds
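
A short usage sketch of the new helper (not part of the PR): the structure string below is a made-up placeholder for whatever add_classification_structure(response_model) would return, and only the function signature shown in the hunk above is assumed.

# Usage sketch only -- illustrative values, not library code.
from extract_thinker.llm import build_dynamic_prompt

structure = '{\n    "document_type": "string",\n    "confidence": "number"\n}'

default_prompt = build_dynamic_prompt(structure)                        # wraps reasoning in <think> tags
custom_prompt = build_dynamic_prompt(structure, think_tag="reasoning")  # swap the tag name if a model emits a different one

print(custom_prompt)
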
@@ -34,6 +47,11 @@ class LLM:
     MIN_THINKING_BUDGET = 1200  # Minimum thinking budget
     DEFAULT_OUTPUT_TOKENS = 32000
 
+    # A single default completion-token limit that is accepted by the vast
+    # majority of models. If a model supports more (or you need fewer), pass
+    # `token_limit=` when instantiating `LLM` to override this value.
+    DEFAULT_MAX_COMPLETION_TOKENS = 8000
+
     def __init__(
         self,
         model: str,
@@ -194,7 +212,7 @@ def request(
         working_messages = messages.copy()
         if self.is_dynamic and response_model:
             structure = add_classification_structure(response_model)
-            prompt = DYNAMIC_PROMPT_TEMPLATE.format(prompt=structure)
+            prompt = build_dynamic_prompt(structure)
             working_messages.append({
                 "role": "system",
                 "content": prompt
@@ -219,11 +237,11 @@ def request(
 
     def _request_with_router(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
         """Handle request using router with or without thinking parameter"""
-        max_tokens = self.DEFAULT_OUTPUT_TOKENS
+        max_tokens = self._get_model_max_tokens()
         if self.token_limit is not None:
-            max_tokens = self.token_limit
+            max_tokens = min(self.token_limit, max_tokens)
         elif self.is_thinking:
-            max_tokens = self.thinking_token_limit
+            max_tokens = min(self.thinking_token_limit, max_tokens) if self.thinking_token_limit else max_tokens
 
         params = {
             "model": self.model,
@@ -248,11 +266,11 @@ def _request_with_router(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
 
     def _request_direct(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
         """Handle direct request with or without thinking parameter"""
-        max_tokens = self.DEFAULT_OUTPUT_TOKENS
+        max_tokens = self._get_model_max_tokens()
         if self.token_limit is not None:
-            max_tokens = self.token_limit
+            max_tokens = min(self.token_limit, max_tokens)
         elif self.is_thinking:
-            max_tokens = self.thinking_token_limit
+            max_tokens = min(self.thinking_token_limit, max_tokens) if self.thinking_token_limit else max_tokens
 
         base_params = {
             "model": self.model,
@@ -293,11 +311,11 @@ def raw_completion(self, messages: List[Dict[str, str]]) -> str:
         except Exception as e:
             raise ValueError(f"Failed to extract from source: {str(e)}")
 
-        max_tokens = self.DEFAULT_OUTPUT_TOKENS
+        max_tokens = self._get_model_max_tokens()
         if self.token_limit is not None:
-            max_tokens = self.token_limit
+            max_tokens = min(self.token_limit, max_tokens)
         elif self.is_thinking:
-            max_tokens = self.thinking_token_limit
+            max_tokens = min(self.thinking_token_limit, max_tokens) if self.thinking_token_limit else max_tokens
 
         params = {
             "model": self.model,
@@ -325,4 +343,14 @@
 
     def set_timeout(self, timeout_ms: int) -> None:
         """Set the timeout value for LLM requests in milliseconds."""
-        self.TIMEOUT = timeout_ms
+        self.TIMEOUT = timeout_ms
+
+    def _get_model_max_tokens(self) -> int:
+        """Return the default maximum completion-token limit.
+
+        This constant (DEFAULT_MAX_COMPLETION_TOKENS) is meant to work for ~99%
+        of models. If you need a different value, supply `token_limit=` when
+        creating the `LLM` instance.
+        """
+        return self.DEFAULT_MAX_COMPLETION_TOKENS
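
The same clamping expression is repeated in _request_with_router, _request_direct, and raw_completion above. Below is a minimal standalone sketch of that behaviour; resolve_max_tokens is a hypothetical name introduced only for this illustration (the library keeps the logic inline in each method).

from typing import Optional

DEFAULT_MAX_COMPLETION_TOKENS = 8000  # mirrors LLM.DEFAULT_MAX_COMPLETION_TOKENS

def resolve_max_tokens(token_limit: Optional[int],
                       is_thinking: bool,
                       thinking_token_limit: Optional[int]) -> int:
    # Start from the default cap returned by _get_model_max_tokens().
    max_tokens = DEFAULT_MAX_COMPLETION_TOKENS
    if token_limit is not None:
        # An explicit token_limit can lower the cap but no longer raise it.
        max_tokens = min(token_limit, max_tokens)
    elif is_thinking:
        # A thinking budget is clamped the same way when one is set.
        max_tokens = min(thinking_token_limit, max_tokens) if thinking_token_limit else max_tokens
    return max_tokens

assert resolve_max_tokens(None, False, None) == 8000    # default cap applies
assert resolve_max_tokens(2000, False, None) == 2000    # explicit lower limit wins
assert resolve_max_tokens(50000, False, None) == 8000   # limit above the cap is clamped
assert resolve_max_tokens(None, True, 4000) == 4000     # thinking budget respected
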
10 changes: 5 additions & 5 deletions tests/test_llm_backends.py
@@ -3,7 +3,7 @@
 
 def test_litellm_backend():
     """Test default LiteLLM backend"""
-    llm = LLM("gpt-4", backend=llm_engine.LITELLM)
+    llm = LLM("gpt-4o", backend=llm_engine.LITELLM)
     assert llm.backend == llm_engine.LITELLM
     assert llm.client is not None
     assert llm.agent is None
@@ -12,7 +12,7 @@ def test_pydanticai_backend():
     """Test PydanticAI backend if available"""
     try:
         import pydantic_ai
-        llm = LLM("gpt-4", backend=llm_engine.PYDANTIC_AI)
+        llm = LLM("gpt-4o", backend=llm_engine.PYDANTIC_AI)
         assert llm.backend == llm_engine.PYDANTIC_AI
         assert llm.client is None
         assert llm.agent is not None
@@ -22,13 +22,13 @@ def test_pydanticai_backend():
 def test_invalid_backend():
     """Test invalid backend type raises error"""
     with pytest.raises(TypeError):
-        LLM("gpt-4", backend="invalid")  # Should be LLMBackend enum
+        LLM("gpt-4o", backend="invalid")  # Should be LLMBackend enum
 
 def test_router_with_pydanticai():
     """Test router not supported with PydanticAI"""
     from litellm import Router
-    router = Router(model_list=[{"model_name": "gpt-4"}])
+    router = Router(model_list=[{"model_name": "gpt-4o"}])
 
-    llm = LLM("gpt-4", backend=llm_engine.PYDANTIC_AI)
+    llm = LLM("gpt-4o", backend=llm_engine.PYDANTIC_AI)
     with pytest.raises(ValueError, match="Router is only supported with LITELLM backend"):
         llm.load_router(router)