2 changes: 1 addition & 1 deletion nemoguardrails/integrations/langchain/message_utils.py
@@ -263,7 +263,7 @@ def create_tool_message(
status: Optional[str] = None,
) -> ToolMessage:
"""Create a ToolMessage with optional fields."""
kwargs = {"tool_call_id": tool_call_id}
kwargs: Dict[str, Any] = {"tool_call_id": tool_call_id}
if name is not None:
kwargs["name"] = name
if additional_kwargs is not None:
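For context on the one-line fix above: without the explicit annotation, a type checker infers the dict's value type from the initializer, so later assignments of non-string values are rejected. A minimal sketch of the failure mode (the helper function is hypothetical):

```python
from typing import Any, Dict

def build_kwargs(tool_call_id: str) -> Dict[str, Any]:
    kwargs = {"tool_call_id": tool_call_id}  # mypy infers Dict[str, str]
    # kwargs["artifact"] = {"rows": 3}       # would fail: value is not a str

    kwargs_ok: Dict[str, Any] = {"tool_call_id": tool_call_id}
    kwargs_ok["artifact"] = {"rows": 3}      # fine: values are typed as Any
    return kwargs_ok
```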
163 changes: 115 additions & 48 deletions nemoguardrails/integrations/langchain/runnable_rails.py
@@ -16,9 +16,22 @@
from __future__ import annotations

import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Union,
cast,
)

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.utils import Input, Output, gather_with_concurrency
@@ -33,7 +33,7 @@
message_to_dict,
)
from nemoguardrails.integrations.langchain.utils import async_wrap
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse

logger = logging.getLogger(__name__)

@@ -62,7 +75,7 @@ class RunnableRails(Runnable[Input, Output]):
def __init__(
self,
config: RailsConfig,
llm: Optional[BaseLanguageModel] = None,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
tools: Optional[List[Tool]] = None,
passthrough: bool = True,
runnable: Optional[Runnable] = None,
@@ -72,7 +85,7 @@ def __init__(
input_blocked_message: str = "I cannot process this request.",
output_blocked_message: str = "I cannot provide this response.",
) -> None:
self.llm = llm
self.llm: Optional[Union[BaseLLM, BaseChatModel, Runnable[Any, Any]]] = llm
self.passthrough = passthrough
self.passthrough_runnable = runnable
self.passthrough_user_input_key = input_key
@@ -103,23 +116,30 @@ def __init__(
self.passthrough = False

for tool in tools:
self.rails.register_action(tool, tool.name)
# Tool is callable at runtime via __call__, cast for type checker
self.rails.register_action(cast(Callable[..., Any], tool), tool.name)

# If we have a passthrough Runnable, we need to register a passthrough fn
# that will call it
if self.passthrough_runnable:
self._init_passthrough_fn()

def _init_passthrough_fn(self):
def _init_passthrough_fn(self) -> None:
"""Initialize the passthrough function for the LLM rails instance."""
# Capture the runnable in the closure (we know it's not None when this is called)
passthrough_runnable = self.passthrough_runnable
if not passthrough_runnable:
raise RuntimeError("Passthrough function could not be initialized")

async def passthrough_fn(context: dict, events: List[dict]):
async def passthrough_fn(context: dict, events: List[dict]) -> tuple[str, Any]:
# First, we fetch the input from the context
_input = context.get("passthrough_input")
if hasattr(self.passthrough_runnable, "ainvoke"):
_output = await self.passthrough_runnable.ainvoke(_input, self.config, **self.kwargs)
if _input is None:
raise ValueError("passthrough_input not found in context")
if hasattr(passthrough_runnable, "ainvoke"):
_output = await passthrough_runnable.ainvoke(_input, self.config, **self.kwargs)
else:
async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
async_wrapped_invoke = async_wrap(passthrough_runnable.invoke)
_output = await async_wrapped_invoke(_input, self.config, **self.kwargs)

# If the output is a string, we consider it to be the output text
@@ -132,9 +152,13 @@ async def passthrough_fn(context: dict, events: List[dict]):

return text, _output

self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
# The passthrough_fn type in LLMGenerationActions is declared as Awaitable[str]
# but actually handles tuple returns (see generation.py lines 487-491)
self.rails.llm_generation_actions.passthrough_fn = passthrough_fn # type: ignore[assignment]

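The local-capture pattern in `_init_passthrough_fn` above (assigning `self.passthrough_runnable` to a local before the `None` check) is the standard way to make an `Optional` attribute usable inside a closure. A minimal sketch of why it is needed, with a hypothetical class:

```python
from typing import Callable, Optional

class Service:
    handler: Optional[Callable[[], str]] = None

    def make_caller(self) -> Callable[[], str]:
        # Narrowing `self.handler` would not carry into the closure: mypy
        # assumes the attribute could be reassigned before the lambda runs.
        handler = self.handler  # capture into a local first
        if handler is None:
            raise RuntimeError("handler not set")
        return lambda: handler()  # OK: the local is provably non-None
```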
def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["RunnableRails", Runnable[Any, Any]]:
def __or__( # type: ignore[override]
self, other: Union[BaseLanguageModel, Runnable[Any, Any]]
) -> Union["RunnableRails[Input, Output]", Runnable[Any, Any]]:
"""Chain this runnable with another, returning a new runnable.

This method handles two different cases:
@@ -145,7 +169,7 @@ def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["

This ensures associativity in complex chains.
"""
if isinstance(other, BaseLanguageModel):
if isinstance(other, (BaseLLM, BaseChatModel)):
# Case 1: Set the LLM for this RunnableRails
self.llm = other
self.rails.update_llm(other)
@@ -155,14 +179,11 @@ def __or__(self, other: Union[BaseLanguageModel, Runnable[Any, Any]]) -> Union["
# Case 2: Check if this is a RunnableBinding that wraps a BaseLanguageModel
# This happens when you call llm.bind_tools([...]) - the result is a RunnableBinding
# that wraps the original LLM but is no longer a BaseLanguageModel instance
if (
hasattr(other, "bound")
and hasattr(other.bound, "__class__")
and isinstance(other.bound, BaseLanguageModel)
):
# This is an LLM with tools bound to it - treat it as an LLM, not passthrough
bound = getattr(other, "bound", None)
if bound is not None and isinstance(bound, (BaseLLM, BaseChatModel)):
# This is an LLM with tools bound to it - store the binding, update rails with unwrapped LLM
self.llm = other
self.rails.update_llm(other)
self.rails.update_llm(bound)
Comment on lines 185 to +186
Contributor

**style:** Check that tool bindings are preserved when using `RunnableBinding`

When an LLM has tools bound via `.bind_tools()`, the result is a `RunnableBinding` that wraps the LLM. You store the full binding in `self.llm` but pass only the unwrapped `bound` to `rails.update_llm()`. Verify that tool calls still work correctly in test scenarios with bound tools.

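One way to address the reviewer's question is a regression test along these lines — a sketch only, with a hypothetical tool, model name, and config path:

```python
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # assumes the OpenAI provider is installed

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

@tool
def get_weather(city: str) -> str:
    """Return the weather for a city."""
    return f"Sunny in {city}"

config = RailsConfig.from_path("./config")  # hypothetical rails config directory
llm = ChatOpenAI(model="gpt-4o-mini")

# Piping a RunnableBinding into RunnableRails should hit the new branch:
# self.llm keeps the binding, rails.update_llm() receives the unwrapped model.
chain = RunnableRails(config) | llm.bind_tools([get_weather])

result = chain.invoke("What is the weather in Paris?")
# If the binding is preserved end to end, the response should carry tool calls.
assert getattr(result, "tool_calls", None), "tool calls lost through the wrapper"
```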
return self

if self.passthrough_runnable is None:
Expand All @@ -187,21 +208,28 @@ def OutputType(self) -> Any:
"""The type of the output of this runnable as a type annotation."""
return Any

def get_name(self, suffix: str = "") -> str:
def get_name(
self,
suffix: Optional[str] = None,
*,
name: Optional[str] = None,
) -> str:
"""Get the name of this runnable."""
name = "RunnableRails"
base_name = name if name else "RunnableRails"
if suffix:
name += suffix
return name
base_name += suffix
return base_name

def _extract_text_from_input(self, _input) -> str:
def _extract_text_from_input(self, _input: Any) -> str:
"""Extract text content from various input types for passthrough mode."""
if isinstance(_input, str):
return _input
elif is_base_message(_input):
return _input.content
content = _input.content
return str(content) if content else ""
elif isinstance(_input, dict) and self.passthrough_user_input_key in _input:
return _input.get(self.passthrough_user_input_key)
value = _input.get(self.passthrough_user_input_key)
return str(value) if value is not None else ""
else:
return str(_input)

@@ -329,7 +357,9 @@ def _extract_content_from_result(self, result: Any) -> str:

def _get_bot_message(self, result: Any, context: Dict[str, Any]) -> str:
"""Extract the bot message from context or result."""
return context.get("bot_message", result.get("content") if isinstance(result, dict) else result)
default_value = result.get("content") if isinstance(result, dict) else result
bot_message = context.get("bot_message", default_value)
return str(bot_message) if bot_message is not None else ""

def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> Any:
"""Format output for passthrough mode."""
@@ -607,7 +637,7 @@ def _convert_messages_to_rails_format(self, messages) -> List[dict]:
rails_messages.append({"role": "user", "content": str(msg)})
return rails_messages

def _extract_output_content(self, output: Output) -> str:
def _extract_output_content(self, output: Any) -> str:
"""Extract content from output for rails checking."""
if isinstance(output, str):
return output
Expand Down Expand Up @@ -637,8 +667,18 @@ def _full_rails_invoke(

# Generate response from rails
res = self.rails.generate(messages=input_messages, options=GenerationOptions(output_vars=True))
context = res.output_data
result = res.response

if isinstance(res, GenerationResponse):
context = res.output_data or {}
result = res.response
tool_calls = res.tool_calls
llm_metadata = res.llm_metadata
else:
# For duck-typed objects (including mocks in tests)
context = getattr(res, "output_data", None) or {}
result = getattr(res, "response", res)
tool_calls = getattr(res, "tool_calls", None)
llm_metadata = getattr(res, "llm_metadata", None)

# If more than one message is returned, we only take the first one.
# This can happen for advanced use cases, e.g., when the LLM could predict
Expand All @@ -647,7 +687,7 @@ def _full_rails_invoke(
result = result[0]

# Format and return the output based on input type
return self._format_output(input, result, context, res.tool_calls, res.llm_metadata)
return self._format_output(input, result, context, tool_calls, llm_metadata)

async def ainvoke(
self,
@@ -701,11 +741,23 @@ async def _full_rails_ainvoke(

# Generate response from rails asynchronously
res = await self.rails.generate_async(messages=input_messages, options=GenerationOptions(output_vars=True))
context = res.output_data
result = res.response

# With options specified, we get a GenerationResponse
# Also handle mock objects for testing
if isinstance(res, GenerationResponse):
context = res.output_data or {}
result = res.response
tool_calls = res.tool_calls
llm_metadata = res.llm_metadata
else:
# For duck-typed objects (including mocks in tests)
context = getattr(res, "output_data", None) or {}
result = getattr(res, "response", res)
tool_calls = getattr(res, "tool_calls", None)
llm_metadata = getattr(res, "llm_metadata", None)

# Format and return the output based on input type
return self._format_output(input, result, context, res.tool_calls, res.llm_metadata)
return self._format_output(input, result, context, tool_calls, llm_metadata)

def stream(
self,
@@ -752,11 +804,12 @@ async def astream(

input_messages = self._transform_input_to_rails_format(input)

original_streaming = getattr(self.rails.llm, "streaming", False)
llm = self.rails.llm
original_streaming = getattr(llm, "streaming", False) if llm else False
streaming_enabled = False

if hasattr(self.rails.llm, "streaming") and not original_streaming:
self.rails.llm.streaming = True
if llm is not None and hasattr(llm, "streaming") and not original_streaming:
setattr(llm, "streaming", True)
streaming_enabled = True

try:
Expand All @@ -772,8 +825,8 @@ async def astream(
formatted_chunk = self._format_streaming_chunk(input, chunk)
yield formatted_chunk
finally:
if streaming_enabled and hasattr(self.rails.llm, "streaming"):
self.rails.llm.streaming = original_streaming
if streaming_enabled and llm is not None and hasattr(llm, "streaming"):
setattr(llm, "streaming", original_streaming)

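For reference, consuming the stream looks like this (a sketch; `rails` is assumed to be an already-configured instance):

```python
import asyncio

from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

async def consume(rails: RunnableRails) -> None:
    # astream() flips the underlying LLM's `streaming` flag on for the call
    # and restores the original value in the finally block above.
    async for chunk in rails.astream("Tell me about NeMo Guardrails"):
        print(chunk, end="", flush=True)

# asyncio.run(consume(rails))
```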
def _format_streaming_chunk(self, input: Any, chunk) -> Any:
"""Format a streaming chunk based on the input type.
@@ -825,52 +878,66 @@ def _format_streaming_chunk(self, input: Any, chunk) -> Any:
def batch(
self,
inputs: List[Input],
config: Optional[RunnableConfig] = None,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Batch inputs and process them synchronously."""
# Process inputs sequentially to maintain state consistency
return [self.invoke(input, config, **kwargs) for input in inputs]
# Handle both single config and list of configs
if isinstance(config, list):
return [self.invoke(inp, cfg, **kwargs) for inp, cfg in zip(inputs, config)]
return [self.invoke(inp, config, **kwargs) for inp in inputs]

async def abatch(
self,
inputs: List[Input],
config: Optional[RunnableConfig] = None,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Batch inputs and process them asynchronously.

Concurrency is controlled via config['max_concurrency'] following LangChain best practices.
"""
max_concurrency = None
if isinstance(config, list):
# Use first config for max_concurrency setting
if config and "max_concurrency" in config[0]:
max_concurrency = config[0]["max_concurrency"]
# When config is a list, pair each input with its config
return await gather_with_concurrency(
max_concurrency,
*[self.ainvoke(input_item, cfg, **kwargs) for input_item, cfg in zip(inputs, config)],
)

if config and "max_concurrency" in config:
max_concurrency = config["max_concurrency"]

return await gather_with_concurrency(
max_concurrency,
*[self.ainvoke(input_item, config, **kwargs) for input_item in inputs],
)
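A usage sketch for the concurrency knob (hypothetical inputs; `rails` is an already-configured `RunnableRails`):

```python
import asyncio

async def run_batch(rails) -> None:
    # max_concurrency is read from the RunnableConfig, per LangChain
    # convention; here at most 4 ainvoke() calls are in flight at once.
    outputs = await rails.abatch(
        ["hi", "summarize this", "tell me a joke", "bye"],
        config={"max_concurrency": 4},
    )
    print(outputs)

# asyncio.run(run_batch(rails))
```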

def transform(
def transform( # type: ignore[override]
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
"""Transform the input.

This is just an alias for invoke.
This is just an alias for invoke. Note: This intentionally differs from
the parent class signature which expects an Iterator.
"""
return self.invoke(input, config, **kwargs)

async def atransform(
async def atransform( # type: ignore[override]
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
"""Transform the input asynchronously.

This is just an alias for ainvoke.
This is just an alias for ainvoke. Note: This intentionally differs from
the parent class signature which expects an AsyncIterator.
"""
return await self.ainvoke(input, config, **kwargs)
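The `getattr`-based fallback used in `_full_rails_invoke` and `_full_rails_ainvoke` above boils down to this pattern, shown standalone (a sketch; `FakeResponse` is a hypothetical test double):

```python
from typing import Any, Dict, Tuple

class FakeResponse:
    """Stand-in for a mock that only defines `response`."""
    response = "Hello!"

def extract(res: Any) -> Tuple[Dict[str, Any], Any, Any, Any]:
    # Missing attributes fall back to defaults instead of raising, so the
    # same code path handles GenerationResponse objects and mocks alike.
    context = getattr(res, "output_data", None) or {}
    result = getattr(res, "response", res)
    tool_calls = getattr(res, "tool_calls", None)
    llm_metadata = getattr(res, "llm_metadata", None)
    return context, result, tool_calls, llm_metadata

print(extract(FakeResponse()))  # ({}, 'Hello!', None, None)
```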
2 changes: 1 addition & 1 deletion nemoguardrails/tracing/adapters/filesystem.py
@@ -60,7 +60,7 @@ def transform(self, interaction_log: "InteractionLog"):

async def transform_async(self, interaction_log: "InteractionLog"):
try:
import aiofiles
import aiofiles # type: ignore[import-not-found]
except ImportError:
raise ImportError(
"aiofiles is required for async file writing. Please install it using `pip install aiofiles`"
1 change: 1 addition & 0 deletions pyproject.toml
@@ -164,6 +164,7 @@ include = [
"nemoguardrails/server/**",
"tests/test_callbacks.py",
"nemoguardrails/benchmark/**",
"nemoguardrails/integrations/**",
]
exclude = [
"nemoguardrails/llm/providers/trtllm/**",