diff --git a/src/openlayer/lib/tracing/__init__.py b/src/openlayer/lib/tracing/__init__.py index c63a5b11..e10c90dc 100644 --- a/src/openlayer/lib/tracing/__init__.py +++ b/src/openlayer/lib/tracing/__init__.py @@ -9,6 +9,7 @@ log_attachment, log_context, log_output, + log_question, trace, trace_async, update_current_step, @@ -23,6 +24,7 @@ "update_current_step", "log_context", "log_output", + "log_question", "configure", "get_current_trace", "get_current_step", diff --git a/src/openlayer/lib/tracing/tracer.py b/src/openlayer/lib/tracing/tracer.py index 8d7b8e31..1a4c87f3 100644 --- a/src/openlayer/lib/tracing/tracer.py +++ b/src/openlayer/lib/tracing/tracer.py @@ -236,6 +236,7 @@ def _get_client() -> Optional[Openlayer]: _current_step = contextvars.ContextVar("current_step") _current_trace = contextvars.ContextVar("current_trace") _rag_context = contextvars.ContextVar("rag_context") +_rag_question = contextvars.ContextVar("rag_question") # ----------------------------- Offline Buffer Implementation ----------------------------- # @@ -459,6 +460,11 @@ def get_rag_context() -> Optional[Dict[str, Any]]: return _rag_context.get(None) +def get_rag_question() -> Optional[str]: + """Returns the current question.""" + return _rag_question.get(None) + + @contextmanager def create_step( name: str, @@ -515,6 +521,7 @@ def trace( *step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, + question_kwarg: Optional[str] = None, guardrails: Optional[List[Any]] = None, on_flush_failure: Optional[OnFlushFailureCallback] = None, **step_kwargs, @@ -605,6 +612,7 @@ def __next__(self): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) self._trace_initialized = True @@ -699,6 +707,7 @@ def wrapper(*func_args, **func_kwargs): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) # Apply input guardrails @@ -785,6 +794,7 @@ def wrapper(*func_args, **func_kwargs): context_kwarg=context_kwarg, output=output, guardrail_metadata=guardrail_metadata, + question_kwarg=question_kwarg, ) if exception is not None: @@ -800,6 +810,7 @@ def trace_async( *step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, + question_kwarg: Optional[str] = None, guardrails: Optional[List[Any]] = None, on_flush_failure: Optional[OnFlushFailureCallback] = None, **step_kwargs, @@ -873,6 +884,7 @@ async def __anext__(self): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) self._trace_initialized = True @@ -935,6 +947,7 @@ async def async_function_wrapper(*func_args, **func_kwargs): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) # Process inputs through guardrails @@ -990,6 +1003,7 @@ async def async_function_wrapper(*func_args, **func_kwargs): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ), ) ) @@ -1010,6 +1024,7 @@ async def async_function_wrapper(*func_args, **func_kwargs): context_kwarg=context_kwarg, output=output, guardrail_metadata=guardrail_metadata, + question_kwarg=question_kwarg, ) return output @@ -1035,6 +1050,7 @@ def sync_wrapper(*func_args, **func_kwargs): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) # Process inputs through guardrails @@ -1087,6 +1103,7 @@ def sync_wrapper(*func_args, **func_kwargs): func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ), ) guardrail_metadata.update(output_metadata) @@ -1106,6 +1123,7 @@ def sync_wrapper(*func_args, **func_kwargs): context_kwarg=context_kwarg, output=output, guardrail_metadata=guardrail_metadata, + question_kwarg=question_kwarg, ) if exception is not None: @@ -1147,6 +1165,18 @@ def log_context(context: List[str]) -> None: logger.warning("No current step found to log context.") +def log_question(question: str) -> None: + """Logs the question to the current step of the trace. + + The `question` parameter should be the user query string for RAG use cases.""" + current_step = get_current_step() + if current_step: + _rag_question.set(question) + current_step.log(metadata={"_question": question}) + else: + logger.warning("No current step found to log question.") + + def log_attachment( data: Union[bytes, str, Path, Any], name: Optional[str] = None, @@ -1630,6 +1660,8 @@ def _upload_and_publish_trace( config.update({"ground_truth_column_name": "groundTruth"}) if "context" in trace_data: config.update({"context_column_name": "context"}) + if "_question" in trace_data: + config.update({"question_column_name": "_question"}) if prompt is not None: config.update({"prompt": prompt}) @@ -1729,6 +1761,7 @@ def _process_wrapper_inputs_and_outputs( context_kwarg: Optional[str], output: Any, guardrail_metadata: Optional[Dict[str, Any]] = None, + question_kwarg: Optional[str] = None, ) -> None: """Extract function inputs and finalize step logging - common pattern across wrappers.""" @@ -1737,6 +1770,7 @@ def _process_wrapper_inputs_and_outputs( func_args=func_args, func_kwargs=func_kwargs, context_kwarg=context_kwarg, + question_kwarg=question_kwarg, ) _finalize_step_logging( step=step, @@ -1752,6 +1786,7 @@ def _extract_function_inputs( func_args: tuple, func_kwargs: dict, context_kwarg: Optional[str] = None, + question_kwarg: Optional[str] = None, ) -> dict: """Extract and clean function inputs for logging.""" bound = func_signature.bind(*func_args, **func_kwargs) @@ -1770,6 +1805,16 @@ def _extract_function_inputs( context_kwarg, ) + # Handle question kwarg if specified + if question_kwarg: + if question_kwarg in inputs: + log_question(inputs.get(question_kwarg)) + else: + logger.warning( + "Question kwarg `%s` not found in inputs of the current function.", + question_kwarg, + ) + return inputs @@ -1955,6 +2000,10 @@ def post_process_trace( if context: trace_data["context"] = context + question = get_rag_question() + if question: + trace_data["_question"] = question + return trace_data, input_variable_names