Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions utils/llm/dspy_langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from dspy.adapters import Image as dspy_Image
from dspy.signatures import Signature as dspy_Signature
from dspy.utils.callback import BaseCallback
from langfuse import Langfuse, LangfuseGeneration, get_client
from langfuse import Langfuse, LangfuseGeneration, LangfuseTool, get_client
from langfuse.types import TraceContext
from litellm.cost_calculator import completion_cost
from loguru import logger as log
from pydantic import BaseModel, Field, ValidationError
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(self, signature: type[dspy_Signature]) -> None:
self.input_field_values = contextvars.ContextVar[dict[str, Any]](
"input_field_values"
)
self.current_tool_span = contextvars.ContextVar[Any | None]("current_tool_span")
self.current_tool_span = contextvars.ContextVar[LangfuseTool | None]("current_tool_span")
# Initialize Langfuse client
self.langfuse: Langfuse = Langfuse()
self.input_field_names = signature.input_fields.keys()
Expand Down Expand Up @@ -137,11 +138,14 @@ def on_lm_start( # noqa
parent_observation_id = get_client().get_current_observation_id()
span_obj: LangfuseGeneration | None = None
if trace_id:
span_obj = self.langfuse.generation( # type: ignore[attr-defined]
trace_context: TraceContext = {"trace_id": trace_id}
if parent_observation_id:
trace_context["parent_span_id"] = parent_observation_id
span_obj = self.langfuse.start_observation(
name=model_name or "unknown",
as_type="generation",
input=user_input,
name=model_name,
trace_id=trace_id,
parent_observation_id=parent_observation_id,
trace_context=trace_context,
metadata={
"model": model_name,
"temperature": temperature,
Expand Down Expand Up @@ -347,14 +351,13 @@ def on_lm_end( # noqa

# --- Finalize Span ---
if span:
end_args: dict[str, Any] = {
"output": completion_content,
"model": model_name,
"level": level,
"status_message": status_message,
}
# Langfuse client's `end` method handles None for these specific optional parameters.
span.end(**end_args)
span.update(
output=completion_content,
model=model_name,
level=level,
status_message=status_message,
)
span.end()
self.current_span.set(None)

if level == "DEFAULT" and completion_content is not None:
Expand Down Expand Up @@ -396,11 +399,14 @@ def on_tool_start( # noqa

if trace_id:
# Create a span for the tool call
tool_span = self.langfuse.span( # type: ignore[attr-defined]
trace_context: TraceContext = {"trace_id": trace_id}
if parent_observation_id:
trace_context["parent_span_id"] = parent_observation_id
tool_span = self.langfuse.start_observation(
name=f"tool:{tool_name}",
trace_id=trace_id,
parent_observation_id=parent_observation_id,
as_type="tool",
input=tool_args,
trace_context=trace_context,
metadata={
"tool_name": tool_name,
"tool_type": "function",
Expand Down Expand Up @@ -437,11 +443,12 @@ def on_tool_end( # noqa
except (TypeError, AttributeError, RecursionError) as e:
output_value = {"serialization_error": str(e), "raw": str(outputs)}

tool_span.end(
tool_span.update(
output=output_value,
level=level,
status_message=status_message,
)
tool_span.end()
self.current_tool_span.set(None)

log.debug(f"Tool call ended with output: {str(output_value)[:100]}...")
Loading