diff --git a/utils/llm/dspy_langfuse.py b/utils/llm/dspy_langfuse.py index af7c6e4..be0668b 100644 --- a/utils/llm/dspy_langfuse.py +++ b/utils/llm/dspy_langfuse.py @@ -4,7 +4,8 @@ from dspy.adapters import Image as dspy_Image from dspy.signatures import Signature as dspy_Signature from dspy.utils.callback import BaseCallback -from langfuse import Langfuse, LangfuseGeneration, get_client +from langfuse import Langfuse, LangfuseGeneration, LangfuseTool, get_client +from langfuse.types import TraceContext from litellm.cost_calculator import completion_cost from loguru import logger as log from pydantic import BaseModel, Field, ValidationError @@ -60,7 +61,7 @@ def __init__(self, signature: type[dspy_Signature]) -> None: self.input_field_values = contextvars.ContextVar[dict[str, Any]]( "input_field_values" ) - self.current_tool_span = contextvars.ContextVar[Any | None]("current_tool_span") + self.current_tool_span = contextvars.ContextVar[LangfuseTool | None]("current_tool_span") # Initialize Langfuse client self.langfuse: Langfuse = Langfuse() self.input_field_names = signature.input_fields.keys() @@ -137,11 +138,14 @@ def on_lm_start( # noqa parent_observation_id = get_client().get_current_observation_id() span_obj: LangfuseGeneration | None = None if trace_id: - span_obj = self.langfuse.generation( # type: ignore[attr-defined] + trace_context: TraceContext = {"trace_id": trace_id} + if parent_observation_id: + trace_context["parent_span_id"] = parent_observation_id + span_obj = self.langfuse.start_observation( + name=model_name or "unknown", + as_type="generation", input=user_input, - name=model_name, - trace_id=trace_id, - parent_observation_id=parent_observation_id, + trace_context=trace_context, metadata={ "model": model_name, "temperature": temperature, @@ -347,14 +351,13 @@ def on_lm_end( # noqa # --- Finalize Span --- if span: - end_args: dict[str, Any] = { - "output": completion_content, - "model": model_name, - "level": level, - "status_message": status_message, - } - # Langfuse client's `end` method handles None for these specific optional parameters. - span.end(**end_args) + span.update( + output=completion_content, + model=model_name, + level=level, + status_message=status_message, + ) + span.end() self.current_span.set(None) if level == "DEFAULT" and completion_content is not None: @@ -396,11 +399,14 @@ def on_tool_start( # noqa if trace_id: # Create a span for the tool call - tool_span = self.langfuse.span( # type: ignore[attr-defined] + trace_context: TraceContext = {"trace_id": trace_id} + if parent_observation_id: + trace_context["parent_span_id"] = parent_observation_id + tool_span = self.langfuse.start_observation( name=f"tool:{tool_name}", - trace_id=trace_id, - parent_observation_id=parent_observation_id, + as_type="tool", input=tool_args, + trace_context=trace_context, metadata={ "tool_name": tool_name, "tool_type": "function", @@ -437,11 +443,12 @@ def on_tool_end( # noqa except (TypeError, AttributeError, RecursionError) as e: output_value = {"serialization_error": str(e), "raw": str(outputs)} - tool_span.end( + tool_span.update( output=output_value, level=level, status_message=status_message, ) + tool_span.end() self.current_tool_span.set(None) log.debug(f"Tool call ended with output: {str(output_value)[:100]}...")