diff --git a/nemoguardrails/integrations/langchain/message_utils.py b/nemoguardrails/integrations/langchain/message_utils.py
index e0adea0ba..2686ea10d 100644
--- a/nemoguardrails/integrations/langchain/message_utils.py
+++ b/nemoguardrails/integrations/langchain/message_utils.py
@@ -263,7 +263,7 @@ def create_tool_message(
     status: Optional[str] = None,
 ) -> ToolMessage:
     """Create a ToolMessage with optional fields."""
-    kwargs = {"tool_call_id": tool_call_id}
+    kwargs: Dict[str, Any] = {"tool_call_id": tool_call_id}
     if name is not None:
         kwargs["name"] = name
     if additional_kwargs is not None:
diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py
index 21da3eac4..5efa396ff 100644
--- a/nemoguardrails/integrations/langchain/runnable_rails.py
+++ b/nemoguardrails/integrations/langchain/runnable_rails.py
@@ -16,9 +16,22 @@
 from __future__ import annotations

 import logging
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+    cast,
+)

 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.messages import AIMessage
 from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
 from langchain_core.runnables import Runnable, RunnableConfig
 from langchain_core.runnables.utils import Input, Output, gather_with_concurrency
@@ -33,7 +46,7 @@
     message_to_dict,
 )
 from nemoguardrails.integrations.langchain.utils import async_wrap
-from nemoguardrails.rails.llm.options import GenerationOptions
+from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse

 logger = logging.getLogger(__name__)
@@ -62,7 +75,7 @@ class RunnableRails(Runnable[Input, Output]):
     def __init__(
         self,
         config: RailsConfig,
-        llm: Optional[BaseLanguageModel] = None,
+        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
         tools: Optional[List[Tool]] = None,
         passthrough: bool = True,
         runnable: Optional[Runnable] = None,
@@ -72,7 +85,7 @@ def __init__(
         input_blocked_message: str = "I cannot process this request.",
         output_blocked_message: str = "I cannot provide this response.",
     ) -> None:
-        self.llm = llm
+        self.llm: Optional[Union[BaseLLM, BaseChatModel, Runnable[Any, Any]]] = llm
         self.passthrough = passthrough
         self.passthrough_runnable = runnable
         self.passthrough_user_input_key = input_key
@@ -103,23 +116,30 @@ def __init__(
             self.passthrough = False

             for tool in tools:
-                self.rails.register_action(tool, tool.name)
+                # Tool is callable at runtime via __call__, cast for type checker
+                self.rails.register_action(cast(Callable[..., Any], tool), tool.name)

         # If we have a passthrough Runnable, we need to register a passthrough fn
         # that will call it
         if self.passthrough_runnable:
             self._init_passthrough_fn()

-    def _init_passthrough_fn(self):
+    def _init_passthrough_fn(self) -> None:
         """Initialize the passthrough function for the LLM rails instance."""
+        # Capture the runnable in the closure (we know it's not None when this is called)
+        passthrough_runnable = self.passthrough_runnable
+        if not passthrough_runnable:
+            raise RuntimeError("Passthrough function could not be initialized")

-        async def passthrough_fn(context: dict, events: List[dict]):
+        async def passthrough_fn(context: dict, events: List[dict]) -> tuple[str, Any]:
             # First, we fetch the input from the context
            _input = context.get("passthrough_input")

-            if hasattr(self.passthrough_runnable, "ainvoke"):
-                _output = await self.passthrough_runnable.ainvoke(_input, self.config, **self.kwargs)
+            if _input is None:
+                raise ValueError("passthrough_input not found in context")
+            if hasattr(passthrough_runnable, "ainvoke"):
+                _output = await passthrough_runnable.ainvoke(_input, self.config, **self.kwargs)
             else:
-                async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
+                async_wrapped_invoke = async_wrap(passthrough_runnable.invoke)
                 _output = await async_wrapped_invoke(_input, self.config, **self.kwargs)

             # If the output is a string, we consider it to be the output text
@@ -132,9 +152,13 @@ async def passthrough_fn(context: dict, events: List[dict]):
             return text, _output

-        self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
+        # The passthrough_fn type in LLMGenerationActions is declared as Awaitable[str]
+        # but actually handles tuple returns (see generation.py lines 487-491)
+        self.rails.llm_generation_actions.passthrough_fn = passthrough_fn  # type: ignore[assignment]

-    def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["RunnableRails", Runnable[Any, Any]]:
+    def __or__(  # type: ignore[override]
+        self, other: Union[BaseLanguageModel, Runnable[Any, Any]]
+    ) -> Union["RunnableRails[Input, Output]", Runnable[Any, Any]]:
         """Chain this runnable with another, returning a new runnable.

         This method handles two different cases:
@@ -145,7 +169,7 @@ def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["
         This ensures associativity in complex chains.
         """
-        if isinstance(other, BaseLanguageModel):
+        if isinstance(other, (BaseLLM, BaseChatModel)):
             # Case 1: Set the LLM for this RunnableRails
             self.llm = other
             self.rails.update_llm(other)
@@ -155,14 +179,11 @@ def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["
             return self

         # Case 2: Check if this is a RunnableBinding that wraps a BaseLanguageModel
         # This happens when you call llm.bind_tools([...]) - the result is a RunnableBinding
         # that wraps the original LLM but is no longer a BaseLanguageModel instance
-        if (
-            hasattr(other, "bound")
-            and hasattr(other.bound, "__class__")
-            and isinstance(other.bound, BaseLanguageModel)
-        ):
-            # This is an LLM with tools bound to it - treat it as an LLM, not passthrough
+        bound = getattr(other, "bound", None)
+        if bound is not None and isinstance(bound, (BaseLLM, BaseChatModel)):
+            # This is an LLM with tools bound to it - store the binding, update rails with unwrapped LLM
             self.llm = other
-            self.rails.update_llm(other)
+            self.rails.update_llm(bound)
             return self

         if self.passthrough_runnable is None:
@@ -187,21 +208,28 @@ def OutputType(self) -> Any:
         """The type of the output of this runnable as a type annotation."""
         return Any

-    def get_name(self, suffix: str = "") -> str:
+    def get_name(
+        self,
+        suffix: Optional[str] = None,
+        *,
+        name: Optional[str] = None,
+    ) -> str:
         """Get the name of this runnable."""
-        name = "RunnableRails"
+        base_name = name if name else "RunnableRails"
         if suffix:
-            name += suffix
-        return name
+            base_name += suffix
+        return base_name

-    def _extract_text_from_input(self, _input) -> str:
+    def _extract_text_from_input(self, _input: Any) -> str:
         """Extract text content from various input types for passthrough mode."""
         if isinstance(_input, str):
             return _input
         elif is_base_message(_input):
-            return _input.content
+            content = _input.content
+            return str(content) if content else ""
        elif isinstance(_input, dict) and self.passthrough_user_input_key in _input:
-            return _input.get(self.passthrough_user_input_key)
+            value = _input.get(self.passthrough_user_input_key)
+            return str(value) if value is not None else ""
         else:
             return str(_input)
@@ -329,7 +357,9 @@ def _extract_content_from_result(self, result: Any) -> str:

     def _get_bot_message(self, result: Any, context: Dict[str, Any]) -> str:
         """Extract the bot message from context or result."""
-        return context.get("bot_message", result.get("content") if isinstance(result, dict) else result)
+        default_value = result.get("content") if isinstance(result, dict) else result
+        bot_message = context.get("bot_message", default_value)
+        return str(bot_message) if bot_message is not None else ""

     def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> Any:
         """Format output for passthrough mode."""
@@ -607,7 +637,7 @@ def _convert_messages_to_rails_format(self, messages) -> List[dict]:
             rails_messages.append({"role": "user", "content": str(msg)})
         return rails_messages

-    def _extract_output_content(self, output: Output) -> str:
+    def _extract_output_content(self, output: Any) -> str:
         """Extract content from output for rails checking."""
         if isinstance(output, str):
             return output
@@ -637,8 +667,18 @@ def _full_rails_invoke(

         # Generate response from rails
         res = self.rails.generate(messages=input_messages, options=GenerationOptions(output_vars=True))
-        context = res.output_data
-        result = res.response
+
+        if isinstance(res, GenerationResponse):
+            context = res.output_data or {}
+            result = res.response
+            tool_calls = res.tool_calls
+            llm_metadata = res.llm_metadata
+        else:
+            # For duck-typed objects (including mocks in tests)
+            context = getattr(res, "output_data", None) or {}
+            result = getattr(res, "response", res)
+            tool_calls = getattr(res, "tool_calls", None)
+            llm_metadata = getattr(res, "llm_metadata", None)

         # If more than one message is returned, we only take the first one.
        # This can happen for advanced use cases, e.g., when the LLM could predict
@@ -647,7 +687,7 @@ def _full_rails_invoke(
             result = result[0]

         # Format and return the output based in input type
-        return self._format_output(input, result, context, res.tool_calls, res.llm_metadata)
+        return self._format_output(input, result, context, tool_calls, llm_metadata)

     async def ainvoke(
         self,
@@ -701,11 +741,23 @@ async def _full_rails_ainvoke(

         # Generate response from rails asynchronously
         res = await self.rails.generate_async(messages=input_messages, options=GenerationOptions(output_vars=True))
-        context = res.output_data
-        result = res.response
+
+        # With options specified, we get a GenerationResponse
+        # Also handle mock objects for testing
+        if isinstance(res, GenerationResponse):
+            context = res.output_data or {}
+            result = res.response
+            tool_calls = res.tool_calls
+            llm_metadata = res.llm_metadata
+        else:
+            # For duck-typed objects (including mocks in tests)
+            context = getattr(res, "output_data", None) or {}
+            result = getattr(res, "response", res)
+            tool_calls = getattr(res, "tool_calls", None)
+            llm_metadata = getattr(res, "llm_metadata", None)

         # Format and return the output based on input type
-        return self._format_output(input, result, context, res.tool_calls, res.llm_metadata)
+        return self._format_output(input, result, context, tool_calls, llm_metadata)

     def stream(
         self,
@@ -752,11 +804,12 @@ async def astream(

         input_messages = self._transform_input_to_rails_format(input)

-        original_streaming = getattr(self.rails.llm, "streaming", False)
+        llm = self.rails.llm
+        original_streaming = getattr(llm, "streaming", False) if llm else False
         streaming_enabled = False

-        if hasattr(self.rails.llm, "streaming") and not original_streaming:
-            self.rails.llm.streaming = True
+        if llm is not None and hasattr(llm, "streaming") and not original_streaming:
+            setattr(llm, "streaming", True)
             streaming_enabled = True

         try:
@@ -772,8 +825,8 @@
                 formatted_chunk = self._format_streaming_chunk(input, chunk)
                 yield formatted_chunk
         finally:
-            if streaming_enabled and hasattr(self.rails.llm, "streaming"):
-                self.rails.llm.streaming = original_streaming
+            if streaming_enabled and llm is not None and hasattr(llm, "streaming"):
+                setattr(llm, "streaming", original_streaming)

     def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         """Format a streaming chunk based on the input type.
@@ -825,17 +878,20 @@ def _format_streaming_chunk(self, input: Any, chunk) -> Any:
     def batch(
         self,
         inputs: List[Input],
-        config: Optional[RunnableConfig] = None,
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """Batch inputs and process them synchronously."""
         # Process inputs sequentially to maintain state consistency
-        return [self.invoke(input, config, **kwargs) for input in inputs]
+        # Handle both single config and list of configs
+        if isinstance(config, list):
+            return [self.invoke(inp, cfg, **kwargs) for inp, cfg in zip(inputs, config)]
+        return [self.invoke(inp, config, **kwargs) for inp in inputs]

     async def abatch(
         self,
         inputs: List[Input],
-        config: Optional[RunnableConfig] = None,
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """Batch inputs and process them asynchronously.
@@ -843,15 +899,24 @@
         Concurrency is controlled via config['max_concurrency'] following LangChain best practices.
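+        When a list of configs is provided, each input is paired with its corresponding config by position.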
""" max_concurrency = None + if isinstance(config, list): + # Use first config for max_concurrency setting + if config and "max_concurrency" in config[0]: + max_concurrency = config[0]["max_concurrency"] + # When config is a list, pair each input with its config + return await gather_with_concurrency( + max_concurrency, + *[self.ainvoke(input_item, cfg, **kwargs) for input_item, cfg in zip(inputs, config)], + ) + if config and "max_concurrency" in config: max_concurrency = config["max_concurrency"] - return await gather_with_concurrency( max_concurrency, *[self.ainvoke(input_item, config, **kwargs) for input_item in inputs], ) - def transform( + def transform( # type: ignore[override] self, input: Input, config: Optional[RunnableConfig] = None, @@ -859,11 +924,12 @@ def transform( ) -> Output: """Transform the input. - This is just an alias for invoke. + This is just an alias for invoke. Note: This intentionally differs from + the parent class signature which expects an Iterator. """ return self.invoke(input, config, **kwargs) - async def atransform( + async def atransform( # type: ignore[override] self, input: Input, config: Optional[RunnableConfig] = None, @@ -871,6 +937,7 @@ async def atransform( ) -> Output: """Transform the input asynchronously. - This is just an alias for ainvoke. + This is just an alias for ainvoke. Note: This intentionally differs from + the parent class signature which expects an AsyncIterator. """ return await self.ainvoke(input, config, **kwargs) diff --git a/nemoguardrails/tracing/adapters/filesystem.py b/nemoguardrails/tracing/adapters/filesystem.py index 6c4ecfed3..d7ed93386 100644 --- a/nemoguardrails/tracing/adapters/filesystem.py +++ b/nemoguardrails/tracing/adapters/filesystem.py @@ -60,7 +60,7 @@ def transform(self, interaction_log: "InteractionLog"): async def transform_async(self, interaction_log: "InteractionLog"): try: - import aiofiles + import aiofiles # type: ignore[import-not-found] except ImportError: raise ImportError( "aiofiles is required for async file writing. Please install it using `pip install aiofiles`" diff --git a/pyproject.toml b/pyproject.toml index ef6bf9218..ab62496cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ include = [ "nemoguardrails/server/**", "tests/test_callbacks.py", "nemoguardrails/benchmark/**", + "nemoguardrails/integrations/**", ] exclude = [ "nemoguardrails/llm/providers/trtllm/**",