diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index c7d390aaf..039b1888b 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -887,9 +887,22 @@ async def generate_bot_message(
# of course, it does not work when passed as context in `run_output_rails_in_streaming`
# streaming_handler is set when stream_async method is used
+ has_streaming_handler = streaming_handler is not None
+ has_output_streaming = bool(
+ self.config.rails.output.streaming
+ and self.config.rails.output.streaming.enabled
+ )
+ log.info(
+ f"generate_bot_message: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
+ )
# if streaming_handler and len(self.config.rails.output.flows) > 0:
- if streaming_handler and self.config.rails.output.streaming.enabled:
+ if streaming_handler and has_output_streaming:
+ log.info("Setting skip_output_rails = True")
context_updates["skip_output_rails"] = True
+ else:
+ log.info(
+ f"NOT setting skip_output_rails: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
+ )
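+ # skip_output_rails is propagated to the runtime via the context_updates
+ # attached to the ActionResult returned below.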
if bot_intent in self.config.bot_messages:
# Choose a message randomly from self.config.bot_messages[bot_message]
@@ -970,7 +983,9 @@ async def generate_bot_message(
new_event_dict("BotMessage", text=text)
)
- return ActionResult(events=output_events)
+ return ActionResult(
+ events=output_events, context_updates=context_updates
+ )
else:
if streaming_handler:
await streaming_handler.push_chunk(
@@ -987,7 +1002,9 @@ async def generate_bot_message(
)
output_events.append(bot_message_event)
- return ActionResult(events=output_events)
+ return ActionResult(
+ events=output_events, context_updates=context_updates
+ )
# If we are in passthrough mode, we just use the input for prompting
if self.config.passthrough:
diff --git a/nemoguardrails/colang/v1_0/runtime/flows.py b/nemoguardrails/colang/v1_0/runtime/flows.py
index 9f4f67791..64e81dcc6 100644
--- a/nemoguardrails/colang/v1_0/runtime/flows.py
+++ b/nemoguardrails/colang/v1_0/runtime/flows.py
@@ -15,6 +15,7 @@
"""A simplified modeling of the CoFlows engine."""
+import logging
import uuid
from dataclasses import dataclass, field
from enum import Enum
@@ -25,6 +26,8 @@
from nemoguardrails.colang.v1_0.runtime.sliding import slide
from nemoguardrails.utils import new_event_dict, new_uuid
+log = logging.getLogger(__name__)
+
@dataclass
class FlowConfig:
@@ -356,7 +359,15 @@ def compute_next_state(state: State, event: dict) -> State:
if event["type"] == "ContextUpdate":
# TODO: add support to also remove keys from the context.
# maybe with a special context key e.g. "__remove__": ["key1", "key2"]
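+ # Debug logging to trace how skip_output_rails propagates via ContextUpdate events.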
+ if "skip_output_rails" in event["data"]:
+ log.info(
+ f"ContextUpdate setting skip_output_rails={event['data']['skip_output_rails']}"
+ )
state.context.update(event["data"])
+ if "skip_output_rails" in state.context:
+ log.info(
+ f"After update, context skip_output_rails={state.context.get('skip_output_rails')}"
+ )
state.context_updates = {}
state.next_step = None
return state
@@ -415,6 +426,12 @@ def compute_next_state(state: State, event: dict) -> State:
_record_next_step(new_state, flow_state, flow_config, priority_modifier=0.9)
continue
+ # Debug logging for BotMessage event and skip_output_rails
+ if event["type"] == "BotMessage":
+ log.info(
+ f"BotMessage event processing for flow '{flow_config.id}', skip_output_rails in context: {flow_state.context.get('skip_output_rails', 'NOT SET')}"
+ )
+
# If we're at a branching point, we look at all individual heads.
matching_head = None
diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py
index ffb9c3e64..3a8859dda 100644
--- a/nemoguardrails/colang/v1_0/runtime/runtime.py
+++ b/nemoguardrails/colang/v1_0/runtime/runtime.py
@@ -523,6 +523,9 @@ async def _run_output_rails_in_parallel_streaming(
flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict}
events: The events list for context
"""
+ # Compute context from events so actions can access bot_message
+ context = compute_context(events)
+
tasks = []
async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
@@ -532,8 +535,11 @@ async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
action_name = action_info["action_name"]
params = action_info["params"]
+ # Merge context into params so actions have access to bot_message
+ params_with_context = {**params, "context": context}
+
result_tuple = await self.action_dispatcher.execute_action(
- action_name, params
+ action_name, params_with_context
)
result, status = result_tuple
@@ -731,10 +737,19 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
return_events = []
context_updates = {}
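+ # Debug logging to verify that generate_bot_message returns an ActionResult
+ # carrying skip_output_rails in its context_updates.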
+ if action_name == "generate_bot_message":
+ log.info(
+ f"DEBUG: generate_bot_message returned, isinstance(ActionResult)={isinstance(result, ActionResult)}"
+ )
+
if isinstance(result, ActionResult):
return_value = result.return_value
return_events = result.events
context_updates.update(result.context_updates)
+ if action_name == "generate_bot_message":
+ log.info(
+ f"generate_bot_message ActionResult: context_updates={context_updates}, skip_output_rails={'skip_output_rails' in context_updates}"
+ )
# If we have an action result key, we also record the update.
if action_result_key:
diff --git a/nemoguardrails/rails/llm/config_loader.py b/nemoguardrails/rails/llm/config_loader.py
new file mode 100644
index 000000000..9c1d9f136
--- /dev/null
+++ b/nemoguardrails/rails/llm/config_loader.py
@@ -0,0 +1,215 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config loader for loading default flows, library content, and config.py modules."""
+
+import importlib.util
+import logging
+import os
+from typing import Any, List
+
+from nemoguardrails.colang import parse_colang_file
+from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id
+from nemoguardrails.rails.llm.config import RailsConfig
+
+log = logging.getLogger(__name__)
+
+
+class ConfigLoader:
+ """Enriches RailsConfig with default flows, library content, and config.py modules."""
+
+ @staticmethod
+ def load_config(config: RailsConfig) -> List[Any]:
+ """Enrich the config with default flows and library content.
+
+ Args:
+ config: The rails configuration to enrich.
+
+ Returns:
+ List of config modules loaded from config.py files.
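+
+ Typical usage (as done in `LLMRails.__init__`): iterate the returned modules
+ and call each module's optional `init` hook with the LLMRails instance.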
+ """
+ # Load default flows for Colang 1.0
+ if config.colang_version == "1.0":
+ ConfigLoader._load_default_flows(config)
+ ConfigLoader._load_library_content(config)
+
+ # Mark rail flows as system flows
+ ConfigLoader._mark_rail_flows_as_system(config)
+
+ # Load and execute config.py modules
+ config_modules = ConfigLoader._load_config_modules(config)
+
+ # Validate config
+ ConfigLoader._validate_config(config)
+
+ return config_modules
+
+ @staticmethod
+ def _load_default_flows(config: RailsConfig):
+ """Load default LLM flows for Colang 1.0.
+
+ Args:
+ config: The rails configuration.
+ """
+ current_folder = os.path.dirname(__file__)
+ default_flows_file = "llm_flows.co"
+ default_flows_path = os.path.join(current_folder, default_flows_file)
+
+ with open(default_flows_path, "r") as f:
+ default_flows_content = f.read()
+ default_flows = parse_colang_file(
+ default_flows_file, default_flows_content
+ )["flows"]
+
+ # Mark all default flows as system flows
+ for flow_config in default_flows:
+ flow_config["is_system_flow"] = True
+
+ # Add default flows to config
+ config.flows.extend(default_flows)
+ log.debug(f"Loaded {len(default_flows)} default flows")
+
+ @staticmethod
+ def _load_library_content(config: RailsConfig):
+ """Load content from the components library.
+
+ Args:
+ config: The rails configuration.
+ """
+ library_path = os.path.join(os.path.dirname(__file__), "../../library")
+ loaded_files = 0
+
+ for root, dirs, files in os.walk(library_path):
+ for file in files:
+ full_path = os.path.join(root, file)
+ if file.endswith(".co"):
+ log.debug(f"Loading library file: {full_path}")
+ with open(full_path, "r", encoding="utf-8") as f:
+ content = parse_colang_file(
+ file, content=f.read(), version=config.colang_version
+ )
+ if not content:
+ continue
+
+ # Mark all library flows as system flows
+ for flow_config in content["flows"]:
+ flow_config["is_system_flow"] = True
+
+ # Load all the flows
+ config.flows.extend(content["flows"])
+
+ # Load bot messages if not overwritten
+ for message_id, utterances in content.get(
+ "bot_messages", {}
+ ).items():
+ if message_id not in config.bot_messages:
+ config.bot_messages[message_id] = utterances
+
+ loaded_files += 1
+
+ log.debug(f"Loaded {loaded_files} library files")
+
+ @staticmethod
+ def _mark_rail_flows_as_system(config: RailsConfig):
+ """Mark all flows used in rails as system flows.
+
+ Args:
+ config: The rails configuration.
+ """
+ rail_flow_ids = (
+ config.rails.input.flows
+ + config.rails.output.flows
+ + config.rails.retrieval.flows
+ )
+
+ for flow_config in config.flows:
+ if flow_config.get("id") in rail_flow_ids:
+ flow_config["is_system_flow"] = True
+ # Mark them as subflows by default to simplify syntax
+ flow_config["is_subflow"] = True
+
+ @staticmethod
+ def _load_config_modules(config: RailsConfig) -> List[Any]:
+ """Load and execute config.py modules.
+
+ Args:
+ config: The rails configuration.
+
+ Returns:
+ List of loaded config modules.
+ """
+ config_modules = []
+ paths = list(
+ config.imported_paths.values() if config.imported_paths else []
+ ) + [config.config_path]
+
+ for _path in paths:
+ if _path:
+ filepath = os.path.join(_path, "config.py")
+ if os.path.exists(filepath):
+ filename = os.path.basename(filepath)
+ spec = importlib.util.spec_from_file_location(filename, filepath)
+ if spec and spec.loader:
+ config_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(config_module)
+ config_modules.append(config_module)
+ log.debug(f"Loaded config module from: {filepath}")
+
+ return config_modules
+
+ @staticmethod
+ def _validate_config(config: RailsConfig):
+ """Run validation checks on the config.
+
+ Args:
+ config: The rails configuration to validate.
+
+ Raises:
+ ValueError: If validation fails.
+ """
+ if config.colang_version == "1.0":
+ existing_flows_names = set([flow.get("id") for flow in config.flows])
+ else:
+ existing_flows_names = set([flow.get("name") for flow in config.flows])
+
+ # Validate input rail flows
+ for flow_name in config.rails.input.flows:
+ flow_name = _normalize_flow_id(flow_name)
+ if flow_name not in existing_flows_names:
+ raise ValueError(
+ f"The provided input rail flow `{flow_name}` does not exist"
+ )
+
+ # Validate output rail flows
+ for flow_name in config.rails.output.flows:
+ flow_name = _normalize_flow_id(flow_name)
+ if flow_name not in existing_flows_names:
+ raise ValueError(
+ f"The provided output rail flow `{flow_name}` does not exist"
+ )
+
+ # Validate retrieval rail flows
+ for flow_name in config.rails.retrieval.flows:
+ if flow_name not in existing_flows_names:
+ raise ValueError(
+ f"The provided retrieval rail flow `{flow_name}` does not exist"
+ )
+
+ # Check for conflicting modes
+ if config.passthrough and config.rails.dialog.single_call.enabled:
+ raise ValueError(
+ "The passthrough mode and the single call dialog rails mode can't be used at the same time. "
+ "The single call mode needs to use an altered prompt when prompting the LLM."
+ )
diff --git a/nemoguardrails/rails/llm/event_translator.py b/nemoguardrails/rails/llm/event_translator.py
new file mode 100644
index 000000000..d67f8fbe0
--- /dev/null
+++ b/nemoguardrails/rails/llm/event_translator.py
@@ -0,0 +1,246 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Event translator for converting between messages and events."""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from nemoguardrails.colang.v2_x.runtime.flows import Action, State
+from nemoguardrails.rails.llm.config import RailsConfig
+from nemoguardrails.rails.llm.utils import get_history_cache_key
+from nemoguardrails.utils import new_event_dict, new_uuid
+
+log = logging.getLogger(__name__)
+
+
+class EventTranslator:
+ """Translates between messages and Colang events."""
+
+ def __init__(self, config: RailsConfig):
+ """Initialize the EventTranslator.
+
+ Args:
+ config: The rails configuration.
+ """
+ self.config = config
+ self.events_history_cache: Dict[str, List[dict]] = {}
+
+ def messages_to_events(
+ self, messages: List[Dict[str, Any]], state: Optional[Any] = None
+ ) -> List[dict]:
+ """Convert messages to events.
+
+ Tries to find a prefix of messages for which we have already a list of events
+ in the cache. For the rest, they are converted as is.
+
+ Args:
+ messages: The list of messages.
+ state: Optional state object (used for Colang 2.x).
+
+ Returns:
+ A list of events.
+ """
+ events = []
+
+ if self.config.colang_version == "1.0":
+ events = self._messages_to_events_v1(messages)
+ else:
+ events = self._messages_to_events_v2(messages, state)
+
+ return events
+
+ def _messages_to_events_v1(self, messages: List[dict]) -> List[dict]:
+ """Convert messages to events for Colang 1.0.
+
+ Args:
+ messages: The list of messages.
+
+ Returns:
+ A list of events.
+ """
+ events = []
+
+ # Try to find the longest prefix of messages for which we have a cache
+ p = len(messages) - 1
+ while p > 0:
+ cache_key = get_history_cache_key(messages[0:p])
+ if cache_key in self.events_history_cache:
+ events = self.events_history_cache[cache_key].copy()
+ break
+ p -= 1
+
+ # For the rest of the messages, transform them directly into events
+ for idx in range(p, len(messages)):
+ msg = messages[idx]
+ if msg["role"] == "user":
+ events.append(
+ {
+ "type": "UtteranceUserActionFinished",
+ "final_transcript": msg["content"],
+ }
+ )
+
+ # If it's not the last message, also add the `UserMessage` event
+ if idx != len(messages) - 1:
+ events.append(
+ {
+ "type": "UserMessage",
+ "text": msg["content"],
+ }
+ )
+
+ elif msg["role"] == "assistant":
+ if msg.get("tool_calls"):
+ events.append(
+ {"type": "BotToolCalls", "tool_calls": msg["tool_calls"]}
+ )
+ else:
+ action_uid = new_uuid()
+ start_event = new_event_dict(
+ "StartUtteranceBotAction",
+ script=msg["content"],
+ action_uid=action_uid,
+ )
+ finished_event = new_event_dict(
+ "UtteranceBotActionFinished",
+ final_script=msg["content"],
+ is_success=True,
+ action_uid=action_uid,
+ )
+ events.extend([start_event, finished_event])
+
+ elif msg["role"] == "context":
+ events.append({"type": "ContextUpdate", "data": msg["content"]})
+
+ elif msg["role"] == "event":
+ events.append(msg["event"])
+
+ elif msg["role"] == "system":
+ events.append({"type": "SystemMessage", "content": msg["content"]})
+
+ elif msg["role"] == "tool":
+ # For the last tool message, create grouped tool event or synthetic UserMessage
+ if idx == len(messages) - 1:
+ # Find the original user message for response generation
+ user_message = None
+ for prev_msg in reversed(messages[:idx]):
+ if prev_msg["role"] == "user":
+ user_message = prev_msg["content"]
+ break
+
+ if user_message:
+ # If tool input rails are configured, group all tool messages
+ if self.config.rails.tool_input.flows:
+ tool_messages = []
+ for tool_idx in range(len(messages)):
+ if messages[tool_idx]["role"] == "tool":
+ tool_messages.append(
+ {
+ "content": messages[tool_idx]["content"],
+ "name": messages[tool_idx].get(
+ "name", "unknown"
+ ),
+ "tool_call_id": messages[tool_idx].get(
+ "tool_call_id", ""
+ ),
+ }
+ )
+
+ events.append(
+ {
+ "type": "UserToolMessages",
+ "tool_messages": tool_messages,
+ }
+ )
+ else:
+ events.append({"type": "UserMessage", "text": user_message})
+
+ return events
+
+ def _messages_to_events_v2(
+ self, messages: List[Dict[str, Any]], state: Optional[Any]
+ ) -> List[dict]:
+ """Convert messages to events for Colang 2.x.
+
+ Args:
+ messages: The list of messages.
+ state: The state object.
+
+ Returns:
+ A list of events.
+ """
+ events = []
+
+ for idx in range(len(messages)):
+ msg = messages[idx]
+ if msg["role"] == "user":
+ events.append(
+ {
+ "type": "UtteranceUserActionFinished",
+ "final_transcript": msg["content"],
+ }
+ )
+
+ elif msg["role"] == "assistant":
+ raise ValueError(
+ "Providing `assistant` messages as input is not supported for Colang 2.0 configurations."
+ )
+
+ elif msg["role"] == "context":
+ events.append({"type": "ContextUpdate", "data": msg["content"]})
+
+ elif msg["role"] == "event":
+ events.append(msg["event"])
+
+ elif msg["role"] == "system":
+ events.append({"type": "SystemMessage", "content": msg["content"]})
+
+ elif msg["role"] == "tool":
+ if state is None:
+ raise ValueError(
+ "State object is required for tool messages in Colang 2.0"
+ )
+ action_uid = msg["tool_call_id"]
+ return_value = msg["content"]
+ action: Action = state.actions[action_uid]
+ events.append(
+ new_event_dict(
+ f"{action.name}Finished",
+ action_uid=action_uid,
+ action_name=action.name,
+ status="success",
+ is_success=True,
+ return_value=return_value,
+ events=[],
+ )
+ )
+
+ return events
+
+ def cache_events(self, messages: List[dict], events: List[dict]):
+ """Cache events for a sequence of messages.
+
+ Args:
+ messages: The list of messages.
+ events: The corresponding events.
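+
+ LLMRails calls this after a generation completes (Colang 1.0 only, and only
+ when no explicit state object is used), keyed on the message history plus
+ the newly generated assistant message.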
+ """
+ if self.config.colang_version == "1.0":
+ cache_key = get_history_cache_key(messages)
+ self.events_history_cache[cache_key] = events
+
+ def clear_cache(self):
+ """Clear the events history cache."""
+ self.events_history_cache.clear()
diff --git a/nemoguardrails/rails/llm/kb_builder.py b/nemoguardrails/rails/llm/kb_builder.py
new file mode 100644
index 000000000..201f990e3
--- /dev/null
+++ b/nemoguardrails/rails/llm/kb_builder.py
@@ -0,0 +1,59 @@
+"""Knowledge Base builder component."""
+
+import logging
+from typing import Callable, Optional
+
+from nemoguardrails.embeddings.index import EmbeddingsIndex
+from nemoguardrails.kb.kb import KnowledgeBase
+from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
+
+log = logging.getLogger(__name__)
+
+
+class KnowledgeBaseBuilder:
+ """Builder for initializing and managing knowledge bases."""
+
+ def __init__(
+ self,
+ config: RailsConfig,
+ get_embeddings_search_provider_instance: Callable[
+ [Optional[EmbeddingSearchProvider]], EmbeddingsIndex
+ ],
+ ):
+ """Initialize the KnowledgeBaseBuilder.
+
+ Args:
+ config: The rails configuration.
+ get_embeddings_search_provider_instance: Function to get embeddings search provider.
+ """
+ self.config = config
+ self.get_embeddings_search_provider_instance = (
+ get_embeddings_search_provider_instance
+ )
+ self.kb: Optional[KnowledgeBase] = None
+
+ async def build(self) -> Optional[KnowledgeBase]:
+ """Build the knowledge base from configuration.
+
+ Returns:
+ The initialized KnowledgeBase or None if no docs configured.
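+
+ LLMRails runs this in a separate thread at startup and registers the
+ resulting instance as the `kb` action parameter.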
+ """
+ if not self.config.docs:
+ self.kb = None
+ return None
+
+ documents = [doc.content for doc in self.config.docs]
+ self.kb = KnowledgeBase(
+ documents=documents,
+ config=self.config.knowledge_base,
+ get_embedding_search_provider_instance=self.get_embeddings_search_provider_instance,
+ )
+ self.kb.init()
+ await self.kb.build()
+
+ log.info("Knowledge base initialized with %d documents", len(documents))
+ return self.kb
+
+ def get_kb(self) -> Optional[KnowledgeBase]:
+ """Get the knowledge base instance."""
+ return self.kb
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 187300aa2..d8b06a091 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -13,16 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""LLM Rails entry point."""
+"""LLM Rails entry point - Refactored to use modular components."""
import asyncio
-import importlib.util
import json
import logging
-import os
-import re
import threading
-import time
from functools import partial
from typing import (
Any,
@@ -34,7 +30,6 @@
Tuple,
Type,
Union,
- cast,
)
from langchain_core.language_models import BaseChatModel
@@ -42,42 +37,15 @@
from typing_extensions import Self
from nemoguardrails.actions.llm.generation import LLMGenerationActions
-from nemoguardrails.actions.llm.utils import (
- extract_bot_thinking_from_events,
- extract_tool_calls_from_events,
- get_and_clear_response_metadata_contextvar,
- get_colang_history,
-)
from nemoguardrails.actions.output_mapping import is_output_blocked
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
-from nemoguardrails.colang import parse_colang_file
-from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context
-from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0
-from nemoguardrails.colang.v2_x.runtime.flows import Action, State
-from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
-from nemoguardrails.colang.v2_x.runtime.serialization import (
- json_to_state,
- state_to_json,
-)
-from nemoguardrails.context import (
- explain_info_var,
- generation_options_var,
- llm_stats_var,
- raw_llm_request,
- streaming_handler_var,
-)
+from nemoguardrails.colang.v1_0.runtime.runtime import Runtime
+from nemoguardrails.colang.v2_x.runtime.flows import State
+from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
from nemoguardrails.embeddings.index import EmbeddingsIndex
from nemoguardrails.embeddings.providers import register_embedding_provider
from nemoguardrails.embeddings.providers.base import EmbeddingModel
-from nemoguardrails.kb.kb import KnowledgeBase
-from nemoguardrails.llm.cache import CacheInterface, LFUCache
-from nemoguardrails.llm.models.initializer import (
- ModelInitializationError,
- init_llm_model,
-)
from nemoguardrails.logging.explain import ExplainInfo
-from nemoguardrails.logging.processing_log import compute_generation_log
-from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.verbose import set_verbose
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
@@ -86,30 +54,35 @@
OutputRailsStreamingConfig,
RailsConfig,
)
+from nemoguardrails.rails.llm.config_loader import ConfigLoader
+from nemoguardrails.rails.llm.event_translator import EventTranslator
+from nemoguardrails.rails.llm.kb_builder import KnowledgeBaseBuilder
+from nemoguardrails.rails.llm.model_factory import ModelFactory
from nemoguardrails.rails.llm.options import (
- GenerationLog,
GenerationOptions,
+ GenerationRailsOptions,
GenerationResponse,
)
-from nemoguardrails.rails.llm.utils import (
- get_action_details_from_flow_id,
- get_history_cache_key,
-)
+from nemoguardrails.rails.llm.response_assembler import ResponseAssembler
+from nemoguardrails.rails.llm.runtime_orchestrator import RuntimeOrchestrator
+from nemoguardrails.rails.llm.utils import get_action_details_from_flow_id
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
-from nemoguardrails.utils import (
- extract_error_json,
- get_or_create_event_loop,
- new_event_dict,
- new_uuid,
-)
+from nemoguardrails.utils import get_or_create_event_loop
log = logging.getLogger(__name__)
-process_events_semaphore = asyncio.Semaphore(1)
-
class LLMRails:
- """Rails based on a given configuration."""
+ """Rails based on a given configuration.
+
+ Refactored to use modular components:
+ - ConfigLoader: Loads and enriches configuration
+ - ModelFactory: Manages LLM instantiation
+ - KnowledgeBaseBuilder: Builds knowledge base
+ - EventTranslator: Converts messages to/from events
+ - RuntimeOrchestrator: Manages Colang runtime
+ - ResponseAssembler: Assembles responses
+ """
config: RailsConfig
llm: Optional[Union[BaseLLM, BaseChatModel]]
@@ -127,11 +100,11 @@ def __init__(
config: A rails configuration.
llm: An optional LLM engine to use. If provided, this will be used as the main LLM
and will take precedence over any main LLM specified in the config.
- verbose: Whether the logging should be verbose or not.
+ verbose: Whether the logging should be verbose.
"""
self.config = config
- self.llm = llm
self.verbose = verbose
+ self.explain_info: Optional[ExplainInfo] = None
if self.verbose:
set_verbose(True, llm_calls=True)
@@ -140,116 +113,24 @@ def __init__(
# an index of them.
self.embedding_search_providers = {}
- # The default embeddings model is using FastEmbed
+ # The default embeddings model is using FastEmbed
self.default_embedding_model = "all-MiniLM-L6-v2"
self.default_embedding_engine = "FastEmbed"
self.default_embedding_params = {}
- # We keep a cache of the events history associated with a sequence of user messages.
- # TODO: when we update the interface to allow to return a "state object", this
- # should be removed
- self.events_history_cache = {}
-
- # Weather the main LLM supports streaming
- self.main_llm_supports_streaming = False
-
- # We also load the default flows from the `default_flows.yml` file in the current folder.
- # But only for version 1.0.
- # TODO: decide on the default flows for 2.x.
- if config.colang_version == "1.0":
- # We also load the default flows from the `llm_flows.co` file in the current folder.
- current_folder = os.path.dirname(__file__)
- default_flows_file = "llm_flows.co"
- default_flows_path = os.path.join(current_folder, default_flows_file)
- with open(default_flows_path, "r") as f:
- default_flows_content = f.read()
- default_flows = parse_colang_file(
- default_flows_file, default_flows_content
- )["flows"]
-
- # We mark all the default flows as system flows.
- for flow_config in default_flows:
- flow_config["is_system_flow"] = True
-
- # We add the default flows to the config.
- self.config.flows.extend(default_flows)
-
- # We also need to load the content from the components library.
- library_path = os.path.join(os.path.dirname(__file__), "../../library")
- for root, dirs, files in os.walk(library_path):
- for file in files:
- # Extract the full path for the file
- full_path = os.path.join(root, file)
- if file.endswith(".co"):
- log.debug(f"Loading file: {full_path}")
- with open(full_path, "r", encoding="utf-8") as f:
- content = parse_colang_file(
- file, content=f.read(), version=config.colang_version
- )
- if not content:
- continue
-
- # We mark all the flows coming from the guardrails library as system flows.
- for flow_config in content["flows"]:
- flow_config["is_system_flow"] = True
-
- # We load all the flows
- self.config.flows.extend(content["flows"])
-
- # And all the messages as well, if they have not been overwritten
- for message_id, utterances in content.get(
- "bot_messages", {}
- ).items():
- if message_id not in self.config.bot_messages:
- self.config.bot_messages[message_id] = utterances
-
- # Last but not least, we mark all the flows that are used in any of the rails
- # as system flows (so they don't end up in the prompt).
-
- rail_flow_ids = (
- config.rails.input.flows
- + config.rails.output.flows
- + config.rails.retrieval.flows
- )
+ # Load config with default flows and library content
+ config_modules = ConfigLoader.load_config(config)
- for flow_config in self.config.flows:
- if flow_config.get("id") in rail_flow_ids:
- flow_config["is_system_flow"] = True
-
- # We also mark them as subflows by default, to simplify the syntax
- flow_config["is_subflow"] = True
-
- # We check if the configuration or any of the imported ones have config.py modules.
- config_modules = []
- for _path in list(
- self.config.imported_paths.values() if self.config.imported_paths else []
- ) + [self.config.config_path]:
- if _path:
- filepath = os.path.join(_path, "config.py")
- if os.path.exists(filepath):
- filename = os.path.basename(filepath)
- spec = importlib.util.spec_from_file_location(filename, filepath)
- if spec and spec.loader:
- config_module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(config_module)
- config_modules.append(config_module)
-
- # First, we initialize the runtime.
- if config.colang_version == "1.0":
- self.runtime = RuntimeV1_0(config=config, verbose=verbose)
- elif config.colang_version == "2.x":
- self.runtime = RuntimeV2_x(config=config, verbose=verbose)
- else:
- raise ValueError(f"Unsupported colang version: {config.colang_version}.")
+ # Initialize RuntimeOrchestrator
+ self.runtime_orchestrator = RuntimeOrchestrator(config=config, verbose=verbose)
+ self.runtime = self.runtime_orchestrator.runtime
- # If we have a config_modules with an `init` function, we call it.
- # We need to call this here because the `init` might register additional
- # LLM providers.
+ # Execute config.py init functions
for config_module in config_modules:
if hasattr(config_module, "init"):
config_module.init(self)
- # If we have a customized embedding model, we'll use it.
+ # Update embedding model if specified in config
for model in self.config.models:
if model.type == "embeddings":
self.default_embedding_model = model.model
@@ -257,8 +138,7 @@ def __init__(
self.default_embedding_params = model.parameters or {}
break
- # InteractionLogAdapters used for tracing
- # We ensure that it is used after config.py is loaded
+ # Initialize tracing adapters
if config.tracing:
from nemoguardrails.tracing import create_log_adapters
@@ -266,13 +146,24 @@ def __init__(
else:
self._log_adapters = None
- # We run some additional checks on the config
- self._validate_config()
+ # Initialize ModelFactory
+ self.model_factory = ModelFactory(config=config, injected_llm=llm)
+ self.action_param_registry: Dict[str, Any] = {}
- # Next, we initialize the LLM engines (main engine and action engines if specified).
- self._init_llms()
+ # Initialize models and register action parameters
+ self.model_factory.initialize_models(self.action_param_registry)
+ for param_name, param_value in self.action_param_registry.items():
+ self.runtime.register_action_param(param_name, param_value)
- # Next, we initialize the LLM Generate actions and register them.
+ # Store main LLM reference for backwards compatibility
+ self.llm = self.model_factory.get_main_llm()
+ self.main_llm_supports_streaming = self.model_factory.supports_streaming()
+
+ # Expose specialized LLMs as attributes for convenient access
+ for llm_type, llm_instance in self.model_factory.specialized_llms.items():
+ setattr(self, f"{llm_type}_llm", llm_instance)
+
+ # Initialize LLM Generation Actions
llm_generation_actions_class = (
LLMGenerationActions
if config.colang_version == "1.0"
@@ -285,308 +176,39 @@ def __init__(
get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
verbose=verbose,
)
-
- # If there's already an action registered, we don't override.
self.runtime.register_actions(self.llm_generation_actions, override=False)
- # Next, we initialize the Knowledge Base
- # There are still some edge cases not covered by nest_asyncio.
- # Using a separate thread always for now.
+ # Initialize KnowledgeBaseBuilder
+ self.kb_builder = KnowledgeBaseBuilder(
+ config=self.config,
+ get_embeddings_search_provider_instance=self._get_embeddings_search_provider_instance,
+ )
+
+ # Build KB in a separate thread (nest_asyncio still misses some edge cases, so always use a thread for now)
loop = get_or_create_event_loop()
if True or check_sync_call_from_async_loop():
- t = threading.Thread(target=asyncio.run, args=(self._init_kb(),))
+ t = threading.Thread(target=asyncio.run, args=(self.kb_builder.build(),))
t.start()
t.join()
else:
- loop.run_until_complete(self._init_kb())
+ loop.run_until_complete(self.kb_builder.build())
- # We also register the kb as a parameter that can be passed to actions.
+ # Register KB as action parameter
+ self.kb = self.kb_builder.get_kb()
self.runtime.register_action_param("kb", self.kb)
- # Reference to the general ExplainInfo object.
- self.explain_info = None
-
- def update_llm(self, llm):
- """Replace the main LLM with the provided one.
-
- Arguments:
- llm: The new LLM that should be used.
- """
- self.llm = llm
- self.llm_generation_actions.llm = llm
- self.runtime.register_action_param("llm", llm)
-
- def _validate_config(self):
- """Runs additional validation checks on the config."""
-
- if self.config.colang_version == "1.0":
- existing_flows_names = set([flow.get("id") for flow in self.config.flows])
- else:
- existing_flows_names = set([flow.get("name") for flow in self.config.flows])
-
- for flow_name in self.config.rails.input.flows:
- # content safety check input/output flows are special as they have parameters
- flow_name = _normalize_flow_id(flow_name)
- if flow_name not in existing_flows_names:
- raise ValueError(
- f"The provided input rail flow `{flow_name}` does not exist"
- )
-
- for flow_name in self.config.rails.output.flows:
- flow_name = _normalize_flow_id(flow_name)
- if flow_name not in existing_flows_names:
- raise ValueError(
- f"The provided output rail flow `{flow_name}` does not exist"
- )
-
- for flow_name in self.config.rails.retrieval.flows:
- if flow_name not in existing_flows_names:
- raise ValueError(
- f"The provided retrieval rail flow `{flow_name}` does not exist"
- )
-
- # If both passthrough mode and single call mode are specified, we raise an exception.
- if self.config.passthrough and self.config.rails.dialog.single_call.enabled:
- raise ValueError(
- "The passthrough mode and the single call dialog rails mode can't be used at the same time. "
- "The single call mode needs to use an altered prompt when prompting the LLM. "
- )
-
- async def _init_kb(self):
- """Initializes the knowledge base."""
- self.kb = None
-
- if not self.config.docs:
- return
-
- documents = [doc.content for doc in self.config.docs]
- self.kb = KnowledgeBase(
- documents=documents,
- config=self.config.knowledge_base,
- get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
- )
- self.kb.init()
- await self.kb.build()
-
- def _prepare_model_kwargs(self, model_config):
- """
- Prepare kwargs for model initialization, including API key from environment variable.
-
- Args:
- model_config: The model configuration object
-
- Returns:
- dict: The prepared kwargs for model initialization
- """
- kwargs = model_config.parameters or {}
-
- # If the optional API Key Environment Variable is set, add it to kwargs
- if model_config.api_key_env_var:
- api_key = os.environ.get(model_config.api_key_env_var)
- if api_key:
- kwargs["api_key"] = api_key
-
- # enable streaming token usage when streaming is enabled
- # providers that don't support this parameter will simply ignore it
- if self.config.streaming:
- kwargs["stream_usage"] = True
-
- return kwargs
-
- def _configure_main_llm_streaming(
- self,
- llm: Union[BaseLLM, BaseChatModel],
- model_name: Optional[str] = None,
- provider_name: Optional[str] = None,
- ):
- """Configure streaming support for the main LLM.
-
- Args:
- llm (Union[BaseLLM, BaseChatModel]): The main LLM model instance.
- model_name (Optional[str], optional): Optional model name for logging.
- provider_name (Optional[str], optional): Optional provider name for logging.
-
- """
- if not self.config.streaming:
- return
-
- if hasattr(llm, "streaming"):
- setattr(llm, "streaming", True)
- self.main_llm_supports_streaming = True
- else:
- self.main_llm_supports_streaming = False
- if model_name and provider_name:
- log.warning(
- "Model %s from provider %s does not support streaming.",
- model_name,
- provider_name,
- )
- else:
- log.warning("Provided main LLM does not support streaming.")
-
- def _init_llms(self):
- """
- Initializes the right LLM engines based on the configuration.
- There can be multiple LLM engines and types that can be specified in the config.
- The main LLM engine is the one that will be used for all the core guardrails generations.
- Other LLM engines can be specified for use in specific actions.
-
- The reason we provide an option for decoupling the main LLM engine from the action LLM
- is to allow for flexibility in using specialized LLM engines for specific actions.
-
- Raises:
- ModelInitializationError: If any model initialization fails
- """
- # If the user supplied an already-constructed LLM via the constructor we
- # treat it as the *main* model, but **still** iterate through the
- # configuration to load any additional models (e.g. `content_safety`).
-
- if self.llm:
- # If an LLM was provided via constructor, use it as the main LLM
- # Log a warning if a main LLM is also specified in the config
- if any(model.type == "main" for model in self.config.models):
- log.warning(
- "Both an LLM was provided via constructor and a main LLM is specified in the config. "
- "The LLM provided via constructor will be used and the main LLM from config will be ignored."
- )
- self.runtime.register_action_param("llm", self.llm)
-
- self._configure_main_llm_streaming(self.llm)
- else:
- # Otherwise, initialize the main LLM from the config
- main_model = next(
- (model for model in self.config.models if model.type == "main"), None
- )
-
- if main_model and main_model.model:
- kwargs = self._prepare_model_kwargs(main_model)
- self.llm = init_llm_model(
- model_name=main_model.model,
- provider_name=main_model.engine,
- mode="chat",
- kwargs=kwargs,
- )
- self.runtime.register_action_param("llm", self.llm)
-
- self._configure_main_llm_streaming(
- self.llm,
- model_name=main_model.model,
- provider_name=main_model.engine,
- )
- else:
- log.warning(
- "No main LLM specified in the config and no LLM provided via constructor."
- )
-
- llms = dict()
-
- for llm_config in self.config.models:
- if llm_config.type in ["embeddings", "jailbreak_detection"]:
- continue
+ # Initialize EventTranslator
+ self.event_translator = EventTranslator(config=self.config)
+ # For backwards compatibility
+ self.events_history_cache = self.event_translator.events_history_cache
- # If a constructor LLM is provided, skip initializing any 'main' model from config
- if self.llm and llm_config.type == "main":
- continue
-
- try:
- model_name = llm_config.model
- if not model_name:
- raise ValueError("LLM Config model field not set")
-
- provider_name = llm_config.engine
- kwargs = self._prepare_model_kwargs(llm_config)
- mode = llm_config.mode
-
- llm_model = init_llm_model(
- model_name=model_name,
- provider_name=provider_name,
- mode=mode,
- kwargs=kwargs,
- )
-
- # Configure the model based on its type
- if llm_config.type == "main":
- # If a main LLM was already injected, skip creating another
- # one. Otherwise, create and register it.
- if not self.llm:
- self.llm = llm_model
- self.runtime.register_action_param("llm", self.llm)
- else:
- model_name = f"{llm_config.type}_llm"
- if not hasattr(self, model_name):
- setattr(self, model_name, llm_model)
- self.runtime.register_action_param(
- model_name, getattr(self, model_name)
- )
- # this is used for content safety and topic control
- llms[llm_config.type] = getattr(self, model_name)
-
- except ModelInitializationError as e:
- log.error("Failed to initialize model: %s", str(e))
- raise
- except Exception as e:
- log.error("Unexpected error initializing model: %s", str(e))
- raise
-
- self.runtime.register_action_param("llms", llms)
-
- self._initialize_model_caches()
-
- def _create_model_cache(self, model) -> LFUCache:
- """
- Create cache instance for a model based on its configuration.
-
- Args:
- model: The model configuration object
-
- Returns:
- LFUCache: The cache instance
- """
-
- if model.cache.maxsize <= 0:
- raise ValueError(
- f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. "
- "Capacity must be greater than 0. Skipping cache creation."
- )
-
- stats_logging_interval = None
- if model.cache.stats.enabled and model.cache.stats.log_interval is not None:
- stats_logging_interval = model.cache.stats.log_interval
-
- cache = LFUCache(
- maxsize=model.cache.maxsize,
- track_stats=model.cache.stats.enabled,
- stats_logging_interval=stats_logging_interval,
- )
-
- log.info(
- f"Created cache for model '{model.type}' with maxsize {model.cache.maxsize}"
- )
-
- return cache
-
- def _initialize_model_caches(self) -> None:
- """Initialize caches for configured models."""
- model_caches: Optional[Dict[str, CacheInterface]] = dict()
- for model in self.config.models:
- if model.type in ["main", "embeddings"]:
- continue
-
- if model.cache and model.cache.enabled:
- cache = self._create_model_cache(model)
- model_caches[model.type] = cache
-
- log.info(
- f"Initialized model '{model.type}' with cache %s",
- "enabled" if cache else "disabled",
- )
-
- if model_caches:
- self.runtime.register_action_param("model_caches", model_caches)
+ # Initialize ResponseAssembler
+ self.response_assembler = ResponseAssembler(config=self.config)
def _get_embeddings_search_provider_instance(
self, esp_config: Optional[EmbeddingSearchProvider] = None
) -> EmbeddingsIndex:
+ """Get an embeddings search provider instance."""
if esp_config is None:
esp_config = EmbeddingSearchProvider()
@@ -604,7 +226,6 @@ def _get_embeddings_search_provider_instance(
"embedding_parameters", self.default_embedding_params
),
cache_config=esp_config.cache,
- # We make sure we also pass additional relevant params.
**{
k: v
for k, v in esp_config.parameters.items()
@@ -625,181 +246,11 @@ def _get_embeddings_search_provider_instance(
kwargs = esp_config.parameters
return self.embedding_search_providers[esp_config.name](**kwargs)
- def _get_events_for_messages(self, messages: List[dict], state: Any):
- """Return the list of events corresponding to the provided messages.
-
- Tries to find a prefix of messages for which we have already a list of events
- in the cache. For the rest, they are converted as is.
-
- The reason this cache exists is that we want to benefit from events generated in
- previous turns, which can't be computed again because it would be expensive (e.g.,
- involving multiple LLM calls).
-
- When an explicit state object will be added, this mechanism can be removed.
-
- Args:
- messages: The list of messages.
-
- Returns:
- A list of events.
- """
- events = []
-
- if self.config.colang_version == "1.0":
- # We try to find the longest prefix of messages for which we have a cache
- # of events.
- p = len(messages) - 1
- while p > 0:
- cache_key = get_history_cache_key(messages[0:p])
- if cache_key in self.events_history_cache:
- events = self.events_history_cache[cache_key].copy()
- break
-
- p -= 1
-
- # For the rest of the messages, we transform them directly into events.
- # TODO: Move this to separate function once more types of messages are supported.
- for idx in range(p, len(messages)):
- msg = messages[idx]
- if msg["role"] == "user":
- events.append(
- {
- "type": "UtteranceUserActionFinished",
- "final_transcript": msg["content"],
- }
- )
-
- # If it's not the last message, we also need to add the `UserMessage` event
- if idx != len(messages) - 1:
- events.append(
- {
- "type": "UserMessage",
- "text": msg["content"],
- }
- )
-
- elif msg["role"] == "assistant":
- if msg.get("tool_calls"):
- events.append(
- {"type": "BotToolCalls", "tool_calls": msg["tool_calls"]}
- )
- else:
- action_uid = new_uuid()
- start_event = new_event_dict(
- "StartUtteranceBotAction",
- script=msg["content"],
- action_uid=action_uid,
- )
- finished_event = new_event_dict(
- "UtteranceBotActionFinished",
- final_script=msg["content"],
- is_success=True,
- action_uid=action_uid,
- )
- events.extend([start_event, finished_event])
- elif msg["role"] == "context":
- events.append({"type": "ContextUpdate", "data": msg["content"]})
- elif msg["role"] == "event":
- events.append(msg["event"])
- elif msg["role"] == "system":
- # Handle system messages - convert them to SystemMessage events
- events.append({"type": "SystemMessage", "content": msg["content"]})
- elif msg["role"] == "tool":
- # For the last tool message, create grouped tool event and synthetic UserMessage
- if idx == len(messages) - 1:
- # Find the original user message for response generation
- user_message = None
- for prev_msg in reversed(messages[:idx]):
- if prev_msg["role"] == "user":
- user_message = prev_msg["content"]
- break
-
- if user_message:
- # If tool input rails are configured, group all tool messages
- if self.config.rails.tool_input.flows:
- # Collect all tool messages for grouped processing
- tool_messages = []
- for tool_idx in range(len(messages)):
- if messages[tool_idx]["role"] == "tool":
- tool_messages.append(
- {
- "content": messages[tool_idx][
- "content"
- ],
- "name": messages[tool_idx].get(
- "name", "unknown"
- ),
- "tool_call_id": messages[tool_idx].get(
- "tool_call_id", ""
- ),
- }
- )
-
- events.append(
- {
- "type": "UserToolMessages",
- "tool_messages": tool_messages,
- }
- )
-
- else:
- events.append(
- {"type": "UserMessage", "text": user_message}
- )
-
- else:
- for idx in range(len(messages)):
- msg = messages[idx]
- if msg["role"] == "user":
- events.append(
- {
- "type": "UtteranceUserActionFinished",
- "final_transcript": msg["content"],
- }
- )
-
- elif msg["role"] == "assistant":
- raise ValueError(
- "Providing `assistant` messages as input is not supported for Colang 2.0 configurations."
- )
- elif msg["role"] == "context":
- events.append({"type": "ContextUpdate", "data": msg["content"]})
- elif msg["role"] == "event":
- events.append(msg["event"])
- elif msg["role"] == "system":
- # Handle system messages - convert them to SystemMessage events
- events.append({"type": "SystemMessage", "content": msg["content"]})
- elif msg["role"] == "tool":
- action_uid = msg["tool_call_id"]
- return_value = msg["content"]
- action: Action = state.actions[action_uid]
- events.append(
- new_event_dict(
- f"{action.name}Finished",
- action_uid=action_uid,
- action_name=action.name,
- status="success",
- is_success=True,
- return_value=return_value,
- events=[],
- )
- )
-
- return events
-
- @staticmethod
- def _ensure_explain_info() -> ExplainInfo:
- """Ensure that the ExplainInfo variable is present in the current context
-
- Returns:
- A ExplainInfo class containing the llm calls' statistics
- """
- explain_info = explain_info_var.get()
- if explain_info is None:
- explain_info = ExplainInfo()
- explain_info_var.set(explain_info)
-
- return explain_info
+ def update_llm(self, llm: Union[BaseLLM, BaseChatModel]):
+ """Replace the main LLM with the provided one."""
+ self.llm = llm
+ self.model_factory.update_main_llm(llm, self.action_param_registry)
+ self.llm_generation_actions.llm = llm
async def generate_async(
self,
@@ -809,85 +260,50 @@ async def generate_async(
state: Optional[Union[dict, State]] = None,
streaming_handler: Optional[StreamingHandler] = None,
) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]:
- """Generate a completion or a next message.
-
- The format for messages is the following:
+ """Generate a completion or next message.
- ```python
- [
- {"role": "context", "content": {"user_name": "John"}},
- {"role": "user", "content": "Hello! How are you?"},
- {"role": "assistant", "content": "I am fine, thank you!"},
- {"role": "event", "event": {"type": "UserSilent"}},
- ...
- ]
- ```
-
- Args:
- prompt: The prompt to be used for completion.
- messages: The history of messages to be used to generate the next message.
- options: Options specific for the generation.
- state: The state object that should be used as the starting point.
- streaming_handler: If specified, and the config supports streaming, the
- provided handler will be used for streaming.
-
- Returns:
- The completion (when a prompt is provided) or the next message.
-
- System messages are not yet supported."""
- # convert options to gen_options of type GenerationOptions
- gen_options: Optional[GenerationOptions] = None
+ Delegates message-to-event translation, runtime execution, and response assembly to the modular components.
+ """
+ import time
+
+ from nemoguardrails.actions.llm.utils import get_colang_history
+ from nemoguardrails.context import (
+ explain_info_var,
+ generation_options_var,
+ llm_stats_var,
+ raw_llm_request,
+ streaming_handler_var,
+ )
+ from nemoguardrails.logging.stats import LLMStats
+ from nemoguardrails.utils import extract_error_json
+ # Input validation
if prompt is None and messages is None:
raise ValueError("Either prompt or messages must be provided.")
-
if prompt is not None and messages is not None:
raise ValueError("Only one of prompt or messages can be provided.")
+ # Convert prompt to messages
if prompt is not None:
- # Currently, we transform the prompt request into a single turn conversation
messages = [{"role": "user", "content": prompt}]
- # If a state object is specified, then we switch to "generation options" mode.
- # This is because we want the output to be a GenerationResponse which will contain
- # the output state.
+ # Deserialize state if needed
if state is not None:
- # We deserialize the state if needed.
if isinstance(state, dict) and state.get("version", "1.0") == "2.x":
state = json_to_state(state["state"])
- if options is None:
- gen_options = GenerationOptions()
- elif isinstance(options, dict):
- gen_options = GenerationOptions(**options)
- else:
- gen_options = options
- else:
- # We allow options to be specified both as a dict and as an object.
- if options and isinstance(options, dict):
- gen_options = GenerationOptions(**options)
- elif isinstance(options, GenerationOptions):
- gen_options = options
- elif options is None:
- gen_options = None
- else:
- raise TypeError("options must be a dict or GenerationOptions")
-
- # Save the generation options in the current async context.
- # At this point, gen_options is either None or GenerationOptions
+ # Process options
+ gen_options = self._process_options(options, state)
generation_options_var.set(gen_options)
if streaming_handler:
streaming_handler_var.set(streaming_handler)
- # Initialize the object with additional explanation information.
- # We allow this to also be set externally. This is useful when multiple parallel
- # requests are made.
+ # Initialize explain info
self.explain_info = self._ensure_explain_info()
-
raw_llm_request.set(messages)
- # If we have generation options, we also add them to the context
+ # Inject generation options into messages
if gen_options:
messages = [
{
@@ -896,157 +312,72 @@ async def generate_async(
}
] + (messages or [])
- # If the last message is from the assistant, rather than the user, then
- # we move that to the `$bot_message` variable. This is to enable a more
- # convenient interface. (only when dialog rails are disabled)
+ # Handle bot message in context for non-dialog mode
if (
messages
and messages[-1]["role"] == "assistant"
and gen_options
and gen_options.rails.dialog is False
):
- # We already have the first message with a context update, so we use that
messages[0]["content"]["bot_message"] = messages[-1]["content"]
messages = messages[0:-1]
- # TODO: Add support to load back history of events, next to history of messages
- # This is important as without it, the LLM prediction is not as good.
-
t0 = time.time()
- # Initialize the LLM stats
+ # Initialize LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
- processing_log = []
-
- # The array of events corresponding to the provided sequence of messages.
- events = self._get_events_for_messages(messages, state) # type: ignore
- if self.config.colang_version == "1.0":
- # If we had a state object, we also need to prepend the events from the state.
- state_events = []
- if state:
- assert isinstance(state, dict)
- state_events = state["events"]
-
- new_events = []
- # Compute the new events.
- try:
- new_events = await self.runtime.generate_events(
- state_events + events, processing_log=processing_log
- )
- output_state = None
-
- except Exception as e:
- log.error("Error in generate_async: %s", e, exc_info=True)
- streaming_handler = streaming_handler_var.get()
- if streaming_handler:
- # Push an error chunk instead of None.
- error_message = str(e)
- error_dict = extract_error_json(error_message)
- error_payload: str = json.dumps(error_dict)
- await streaming_handler.push_chunk(error_payload)
- # push a termination signal
- await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
- # Re-raise the exact exception
- raise
+ # Translate messages to events
+ if messages is None:
+ raise ValueError("messages must be provided")
else:
- # In generation mode, by default the bot response is an instant action.
- instant_actions = ["UtteranceBotAction"]
- if self.config.rails.actions.instant_actions is not None:
- instant_actions = self.config.rails.actions.instant_actions
-
- # Cast this explicitly to avoid certain warnings
- runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime)
-
- # Compute the new events.
- # In generation mode, the processing is always blocking, i.e., it waits for
- # all local actions (sync and async).
- new_events, output_state = await runtime.process_events(
- events, state=state, instant_actions=instant_actions, blocking=True
- )
- # We also encode the output state as a JSON
- output_state = {"state": state_to_json(output_state), "version": "2.x"}
-
- # Extract and join all the messages from StartUtteranceBotAction events as the response.
- responses = []
- response_tool_calls = []
- response_events = []
- new_extra_events = []
- exception = None
+ events = self.event_translator.messages_to_events(messages, state)
- # The processing is different for Colang 1.0 and 2.0
- if self.config.colang_version == "1.0":
- for event in new_events:
- if event["type"] == "StartUtteranceBotAction":
- # Check if we need to remove a message
- if event["script"] == "(remove last message)":
- responses = responses[0:-1]
- else:
- responses.append(event["script"])
- elif event["type"].endswith("Exception"):
- exception = event
-
- else:
- for event in new_events:
- start_action_match = re.match(r"Start(.*Action)", event["type"])
-
- if start_action_match:
- action_name = start_action_match[1]
- # TODO: is there an elegant way to extract just the arguments?
- arguments = {
- k: v
- for k, v in event.items()
- if k != "type"
- and k != "uid"
- and k != "event_created_at"
- and k != "source_uid"
- and k != "action_uid"
- }
- response_tool_calls.append(
- {
- "id": event["action_uid"],
- "type": "function",
- "function": {"name": action_name, "arguments": arguments},
- }
- )
-
- elif event["type"] == "UtteranceBotActionFinished":
- responses.append(event["final_script"])
- else:
- # We just append the event
- response_events.append(event)
-
- if exception:
- new_message: dict = {"role": "exception", "content": exception}
-
- else:
- # Ensure all items in responses are strings
- responses = [
- str(response) if not isinstance(response, str) else response
- for response in responses
- ]
- new_message: dict = {"role": "assistant", "content": "\n".join(responses)}
- if response_tool_calls:
- new_message["tool_calls"] = response_tool_calls
- if response_events:
- new_message["events"] = response_events
+ # Generate new events using runtime orchestrator
+ try:
+ log.info(
+ f"DEBUG: Calling generate_events with state containing skip_output_rails: {state.get('skip_output_rails') if isinstance(state, dict) else 'state is not dict'}"
+ )
+ (
+ new_events,
+ output_state,
+ processing_log,
+ ) = await self.runtime_orchestrator.generate_events(
+ events=events, state=state
+ )
+ except Exception as e:
+ log.error("Error in generate_async: %s", e, exc_info=True)
+ streaming_handler = streaming_handler_var.get()
+ if streaming_handler:
+ error_message = str(e)
+ error_dict = extract_error_json(error_message)
+ error_payload = json.dumps(error_dict)
+ await streaming_handler.push_chunk(error_payload)
+ await streaming_handler.push_chunk(END_OF_STREAM)
+ raise
+ # Cache events for Colang 1.0
if self.config.colang_version == "1.0":
- events.extend(new_events)
- events.extend(new_extra_events)
+ responses, _, _, _ = self.response_assembler._extract_from_events(
+ new_events
+ )
+ responses = [str(r) if not isinstance(r, str) else r for r in responses]
+ new_message = {"role": "assistant", "content": "\n".join(responses)}
- # If a state object is not used, then we use the implicit caching
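+            # When no state object is used, fall back to the implicit events-history cache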
if state is None:
- # Save the new events in the history and update the cache
- cache_key = get_history_cache_key((messages) + [new_message]) # type: ignore
- self.events_history_cache[cache_key] = events
+ all_events = events + new_events
+ self.event_translator.cache_events(
+ (messages or []) + [new_message], all_events
+ )
else:
- output_state = {"events": events}
+ output_state = {"events": events + new_events}
- # If logging is enabled, we log the conversation
- # TODO: add support for logging flag
- self.explain_info.colang_history = get_colang_history(events)
+ # Log conversation history
+ all_events = (
+ events + new_events if self.config.colang_version == "1.0" else new_events
+ )
+ self.explain_info.colang_history = get_colang_history(all_events)
if self.verbose:
log.info(
f"Conversation history so far: \n{self.explain_info.colang_history}"
@@ -1058,192 +389,160 @@ async def generate_async(
% (total_time, llm_stats)
)
- # If there is a streaming handler, we make sure we close it now
+ # Close streaming handler
streaming_handler = streaming_handler_var.get()
if streaming_handler:
- # print("Closing the stream handler explicitly")
- await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
-
- # IF tracing is enabled we need to set GenerationLog attrs
- original_log_options = None
- if self.config.tracing.enabled:
- if gen_options is None:
- gen_options = GenerationOptions()
- else:
- # create a copy of the gen_options to avoid modifying the original
- gen_options = gen_options.model_copy(deep=True)
- original_log_options = gen_options.log.model_copy(deep=True)
+ await streaming_handler.push_chunk(END_OF_STREAM)
- # enable log options
- # it is aggressive, but these are required for tracing
- if (
- not gen_options.log.activated_rails
- or not gen_options.log.llm_calls
- or not gen_options.log.internal_events
- ):
- gen_options.log.activated_rails = True
- gen_options.log.llm_calls = True
- gen_options.log.internal_events = True
-
- tool_calls = extract_tool_calls_from_events(new_events)
- llm_metadata = get_and_clear_response_metadata_contextvar()
- reasoning_content = extract_bot_thinking_from_events(new_events)
- # If we have generation options, we prepare a GenerationResponse instance.
+ # Assemble response
if gen_options:
- # If a prompt was used, we only need to return the content of the message.
- if prompt:
- res = GenerationResponse(response=new_message["content"])
- else:
- res = GenerationResponse(response=[new_message])
-
- if reasoning_content:
- res.reasoning_content = reasoning_content
+ res = self.response_assembler.assemble_response(
+ new_events=new_events,
+ all_events=all_events,
+ output_state=output_state,
+ processing_log=processing_log,
+ gen_options=gen_options,
+ prompt=prompt,
+ )
- if tool_calls:
- res.tool_calls = tool_calls
+ # Handle tracing if enabled
+ if self.config.tracing.enabled and messages is not None:
+ await self._handle_tracing(
+ messages, res, gen_options, processing_log, all_events
+ )
- if llm_metadata:
- res.llm_metadata = llm_metadata
+ return res
+ else:
+ simple_res = self.response_assembler.assemble_simple_response(
+ new_events=new_events, prompt=prompt
+ )
- if self.config.colang_version == "1.0":
- # If output variables are specified, we extract their values
- if gen_options and gen_options.output_vars:
- context = compute_context(events)
- output_vars = gen_options.output_vars
- if isinstance(output_vars, list):
- # If we have only a selection of keys, we filter to only that.
- res.output_data = {k: context.get(k) for k in output_vars}
- else:
- # Otherwise, we return the full context
- res.output_data = context
+ # Handle tracing if enabled (even when options is None)
+ if self.config.tracing.enabled and messages is not None:
+ # Convert simple response to GenerationResponse for tracing
+ if isinstance(simple_res, dict):
+ trace_res = GenerationResponse(response=[simple_res], log=None)
+ else:
+ trace_res = GenerationResponse(response=simple_res, log=None)
+ await self._handle_tracing(
+ messages, trace_res, None, processing_log, all_events
+ )
+ # Return GenerationResponse when tracing is enabled
+ return trace_res
- _log = compute_generation_log(processing_log)
+ return simple_res
- # Include information about activated rails and LLM calls if requested
- log_options = gen_options.log if gen_options else None
- if log_options and (
- log_options.activated_rails or log_options.llm_calls
- ):
- res.log = GenerationLog()
-
- # We always include the stats
- res.log.stats = _log.stats
-
- if log_options.activated_rails:
- res.log.activated_rails = _log.activated_rails
-
- if log_options.llm_calls:
- res.log.llm_calls = []
- for activated_rail in _log.activated_rails:
- for executed_action in activated_rail.executed_actions:
- res.log.llm_calls.extend(executed_action.llm_calls)
-
- # Include internal events if requested
- if log_options and log_options.internal_events:
- if res.log is None:
- res.log = GenerationLog()
-
- res.log.internal_events = new_events
-
- # Include the Colang history if requested
- if log_options and log_options.colang_history:
- if res.log is None:
- res.log = GenerationLog()
-
- res.log.colang_history = get_colang_history(events)
-
- # Include the raw llm output if requested
- if gen_options and gen_options.llm_output:
- # Currently, we include the output from the generation LLM calls.
- for activated_rail in _log.activated_rails:
- if activated_rail.type == "generation":
- for executed_action in activated_rail.executed_actions:
- for llm_call in executed_action.llm_calls:
- res.llm_output = llm_call.raw_response
+ def _process_options(
+ self,
+ options: Optional[Union[dict, GenerationOptions]],
+ state: Optional[Any],
+ ) -> Optional[GenerationOptions]:
+ """Process and normalize generation options."""
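+        # For example (sketch): a plain dict like {"rails": {"output": False}} is
+        # normalized to GenerationOptions(rails=GenerationRailsOptions(output=False)),
+        # while an existing GenerationOptions instance is passed through unchanged.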
+ if state is not None:
+ if options is None:
+ return GenerationOptions()
+ elif isinstance(options, dict):
+ return GenerationOptions(**options)
else:
- if gen_options and gen_options.output_vars:
- raise ValueError(
- "The `output_vars` option is not supported for Colang 2.0 configurations."
- )
-
- log_options = gen_options.log if gen_options else None
- if log_options and (
- log_options.activated_rails
- or log_options.llm_calls
- or log_options.internal_events
- or log_options.colang_history
- ):
- raise ValueError(
- "The `log` option is not supported for Colang 2.0 configurations."
- )
-
- if gen_options and gen_options.llm_output:
- raise ValueError(
- "The `llm_output` option is not supported for Colang 2.0 configurations."
- )
-
- # Include the state
- if state is not None:
- res.state = output_state
+ return options
+ else:
+ if options and isinstance(options, dict):
+ return GenerationOptions(**options)
+ elif isinstance(options, GenerationOptions):
+ return options
+ elif options is None:
+ return None
+ else:
+ raise TypeError("options must be a dict or GenerationOptions")
- if self.config.tracing.enabled:
- # TODO: move it to the top once resolved circular dependency of eval
- # lazy import to avoid circular dependency
- from nemoguardrails.tracing import Tracer
+ @staticmethod
+ def _ensure_explain_info() -> ExplainInfo:
+ """Ensure ExplainInfo variable is present in context."""
+ from nemoguardrails.context import explain_info_var
- span_format = getattr(
- self.config.tracing, "span_format", "opentelemetry"
- )
- enable_content_capture = getattr(
- self.config.tracing, "enable_content_capture", False
- )
- # Create a Tracer instance with instantiated adapters and span configuration
- tracer = Tracer(
- input=messages,
- response=res,
- adapters=self._log_adapters,
- span_format=span_format,
- enable_content_capture=enable_content_capture,
- )
- await tracer.export_async()
+ explain_info = explain_info_var.get()
+ if explain_info is None:
+ explain_info = ExplainInfo()
+ explain_info_var.set(explain_info)
+ return explain_info
- # respect original log specification, if tracing added information to the output
- if original_log_options:
- if not any(
- (
- original_log_options.internal_events,
- original_log_options.activated_rails,
- original_log_options.llm_calls,
- original_log_options.colang_history,
- )
- ):
- res.log = None
- else:
- # Ensure res.log exists before setting attributes
- if res.log is not None:
- if not original_log_options.internal_events:
- res.log.internal_events = []
- if not original_log_options.activated_rails:
- res.log.activated_rails = []
- if not original_log_options.llm_calls:
- res.log.llm_calls = []
+ async def _handle_tracing(
+ self,
+ messages: List[dict],
+ res: GenerationResponse,
+ gen_options: Optional[GenerationOptions],
+ processing_log: List[dict],
+ all_events: List[dict],
+ ):
+ """Handle tracing export."""
+ from nemoguardrails.actions.llm.utils import get_colang_history
+ from nemoguardrails.logging.processing_log import compute_generation_log
+ from nemoguardrails.rails.llm.options import GenerationLog
+ from nemoguardrails.tracing import Tracer
+
+ span_format = getattr(self.config.tracing, "span_format", "opentelemetry")
+ enable_content_capture = getattr(
+ self.config.tracing, "enable_content_capture", False
+ )
- return res
+ # If response.log is None but tracing is enabled, create a temporary log for tracing
+        # without attaching it to the response (to avoid mutating the user's response)
+ if res.log is None:
+ # Create a log from processing_log for tracing purposes
+ _log = compute_generation_log(processing_log)
+ temp_log = GenerationLog()
+ temp_log.stats = _log.stats
+ temp_log.activated_rails = _log.activated_rails or []
+ # Collect llm_calls from activated_rails
+ temp_log.llm_calls = []
+ for activated_rail in _log.activated_rails or []:
+ for executed_action in activated_rail.executed_actions:
+ temp_log.llm_calls.extend(executed_action.llm_calls)
+ # Include internal events and colang history for comprehensive tracing
+ temp_log.internal_events = all_events
+ temp_log.colang_history = get_colang_history(all_events)
+ # Create a temporary response with the log for tracing
+ temp_response = GenerationResponse(response=res.response, log=temp_log)
+ tracer = Tracer(
+ input=messages,
+ response=temp_response,
+ adapters=self._log_adapters,
+ span_format=span_format,
+ enable_content_capture=enable_content_capture,
+ )
else:
- # If a prompt is used, we only return the content of the message.
+ tracer = Tracer(
+ input=messages,
+ response=res,
+ adapters=self._log_adapters,
+ span_format=span_format,
+ enable_content_capture=enable_content_capture,
+ )
+ await tracer.export_async()
- if reasoning_content:
- thinking_trace = f"{reasoning_content}\n"
- new_message["content"] = thinking_trace + new_message["content"]
+ def generate(
+ self,
+ prompt: Optional[str] = None,
+ messages: Optional[List[dict]] = None,
+ options: Optional[Union[dict, GenerationOptions]] = None,
+ state: Optional[dict] = None,
+ ):
+ """Synchronous version of generate_async."""
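+        # Example (illustrative):
+        #   rails.generate(messages=[{"role": "user", "content": "Hi"}])
+        # returns the assembled simple response when no options are given, or a
+        # GenerationResponse when generation options (or tracing) are used.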
+ if check_sync_call_from_async_loop():
+ raise RuntimeError(
+ "You are using the sync `generate` inside async code. "
+ "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
+ )
- if prompt:
- return new_message["content"]
- else:
- if tool_calls:
- new_message["tool_calls"] = tool_calls
- return new_message
+ loop = get_or_create_event_loop()
+ return loop.run_until_complete(
+ self.generate_async(
+ prompt=prompt, messages=messages, options=options, state=state
+ )
+ )
def _validate_streaming_with_output_rails(self) -> None:
+ """Validate streaming configuration with output rails."""
if len(self.config.rails.output.flows) > 0 and (
not self.config.rails.output.streaming
or not self.config.rails.output.streaming.enabled
@@ -1264,10 +563,10 @@ def stream_async(
include_generation_metadata: Optional[bool] = False,
generator: Optional[AsyncIterator[str]] = None,
) -> AsyncIterator[str]:
- """Simplified interface for getting directly the streamed tokens from the LLM."""
-
+ """Simplified interface for getting streamed tokens from the LLM."""
self._validate_streaming_with_output_rails()
- # if an external generator is provided, use it directly
+
+        # if an external generator is provided, use it directly
if generator:
if (
self.config.rails.output.streaming
@@ -1288,47 +587,66 @@ def stream_async(
include_generation_metadata=include_generation_metadata
)
- # Create a properly managed task with exception handling
+ # Create properly managed task with exception handling
async def _generation_task():
try:
+ # When output rails streaming is enabled, we need to skip the normal
+ # output rails execution during the main flow, as they will be run
+ # separately in _run_output_rails_in_streaming
+ generation_options = options
+ if (
+ self.config.rails.output.streaming
+ and self.config.rails.output.streaming.enabled
+ ):
+ # Disable output rails in generation_options so the llm_flows.co check prevents them from running
+ if generation_options is None:
+ generation_options = GenerationOptions(
+ rails=GenerationRailsOptions(output=False)
+ )
+ elif isinstance(generation_options, dict):
+ generation_options = dict(generation_options)
+ if "rails" not in generation_options:
+ generation_options["rails"] = {}
+ generation_options["rails"]["output"] = False
+ else:
+ # It's a GenerationOptions object
+ generation_options = generation_options.model_copy(deep=True)
+ generation_options.rails.output = False
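+                # For example, options={"rails": {"input": True}} becomes
+                # {"rails": {"input": True, "output": False}} for this internal call;
+                # the output rails are applied later, chunk by chunk, in
+                # _run_output_rails_in_streaming.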
+
await self.generate_async(
prompt=prompt,
messages=messages,
streaming_handler=streaming_handler,
- options=options,
+ options=generation_options,
state=state,
)
except Exception as e:
- # If an exception occurs during generation, push it to the streaming handler as a json string
- # This ensures the streaming pipeline is properly terminated
log.error(f"Error in generation task: {e}", exc_info=True)
+ from nemoguardrails.utils import extract_error_json
+
error_message = str(e)
error_dict = extract_error_json(error_message)
error_payload = json.dumps(error_dict)
await streaming_handler.push_chunk(error_payload)
- await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
+ await streaming_handler.push_chunk(END_OF_STREAM)
task = asyncio.create_task(_generation_task())
- # Store task reference to prevent garbage collection and ensure proper cleanup
+ # Store task reference
if not hasattr(self, "_active_tasks"):
self._active_tasks = set()
self._active_tasks.add(task)
- # Clean up task when it's done
def task_done_callback(task):
self._active_tasks.discard(task)
task.add_done_callback(task_done_callback)
- # when we have output rails we wrap the streaming handler
- # if len(self.config.rails.output.flows) > 0:
- #
+ # Wrap with output rails if configured
if (
self.config.rails.output.streaming
and self.config.rails.output.streaming.enabled
):
- # returns an async generator
return self._run_output_rails_in_streaming(
streaming_handler=streaming_handler,
output_rails_streaming_config=self.config.rails.output.streaming,
@@ -1338,69 +656,244 @@ def task_done_callback(task):
else:
return streaming_handler
- def generate(
+ async def _run_output_rails_in_streaming(
self,
+ streaming_handler: AsyncIterator[str],
+ output_rails_streaming_config: OutputRailsStreamingConfig,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
- options: Optional[Union[dict, GenerationOptions]] = None,
- state: Optional[dict] = None,
- ):
- """Synchronous version of generate_async."""
+ stream_first: Optional[bool] = None,
+ ) -> AsyncIterator[str]:
+ """Run output rails in streaming mode."""
+ from nemoguardrails.context import explain_info_var
+ from nemoguardrails.rails.llm.buffer import ChunkBatch
+
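+        # Per buffered batch: the chunks are yielded either before or after the rails
+        # check (depending on stream_first); if a rail blocks the content, a JSON
+        # error object is yielded instead and the stream stops.
+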
+        # Ensure explain_info_var is set so that LLM calls made during streaming
+        # output rails are tracked. Reuse self.explain_info when available, otherwise
+        # create it, and always point the context variable at that same object.
+ if self.explain_info is None:
+ self.explain_info = self._ensure_explain_info()
+ # Always set the context variable to point to self.explain_info
+ explain_info_var.set(self.explain_info)
- if check_sync_call_from_async_loop():
- raise RuntimeError(
- "You are using the sync `generate` inside async code. "
- "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
+ # Get buffer strategy from config
+ buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
+
+ # Determine stream_first behavior
+ should_stream_first = (
+ stream_first
+ if stream_first is not None
+ else output_rails_streaming_config.stream_first
+ )
+
+ # Get output flows
+ output_flows = self.config.rails.output.flows or []
+ is_parallel = self.config.rails.output.parallel
+
+ # Get action details for each flow if parallel
+ flows_with_params = {}
+ if is_parallel and output_flows:
+ try:
+ flows = self.config.flows
+ for flow_id in output_flows:
+ try:
+ action_name, action_params = get_action_details_from_flow_id(
+ flow_id, flows
+ )
+ flows_with_params[flow_id] = {
+ "action_name": action_name,
+ "params": action_params,
+ }
+ except (ValueError, KeyError) as e:
+ log.warning(
+ f"Could not get action details for flow {flow_id}: {e}"
+ )
+ except Exception as e:
+ log.error(f"Error getting flow details: {e}")
+
+ # Process stream using buffer strategy
+ log.info("Starting to process stream with buffer strategy")
+ batch_count = 0
+ async for chunk_batch in buffer_strategy.process_stream(streaming_handler):
+ batch_count += 1
+ log.info(
+ f"Received chunk_batch #{batch_count} with {len(chunk_batch.user_output_chunks)} chunks"
+ )
+ # Format the processing context
+ processing_text = buffer_strategy.format_chunks(
+ chunk_batch.processing_context
)
- loop = get_or_create_event_loop()
+ # If stream_first is True, yield chunks immediately before checking rails
+ if should_stream_first:
+ for chunk in chunk_batch.user_output_chunks:
+ yield chunk
- return loop.run_until_complete(
- self.generate_async(
- prompt=prompt,
- messages=messages,
- options=options,
- state=state,
+ # Create events with bot_message for rail processing
+ events = []
+ if messages:
+ events.extend(self.event_translator.messages_to_events(messages))
+ events.append(
+ {
+ "type": "BotMessage",
+ "text": processing_text,
+ }
+ )
+ # Add context update so actions can access bot_message
+ log.info(
+ f"Streaming output rail batch #{batch_count}: processing_text = {repr(processing_text)}"
+ )
+ events.append(
+ {
+ "type": "ContextUpdate",
+ "data": {"bot_message": processing_text},
+ }
)
- )
- async def generate_events_async(
- self,
- events: List[dict],
- ) -> List[dict]:
- """Generate the next events based on the provided history.
+ # Run output rails
+ blocked = False
+ blocking_rail = None
+ internal_error = None
- The format for events is the following:
+ if is_parallel and flows_with_params:
+ # Run parallel rails (only available in Colang 1.0)
+ if hasattr(self.runtime, "_run_output_rails_in_parallel_streaming"):
+ result = await self.runtime._run_output_rails_in_parallel_streaming( # type: ignore[attr-defined]
+ flows_with_params, events
+ )
+ else:
+ raise RuntimeError(
+ "Parallel streaming output rails are only supported in Colang 1.0"
+ )
+ if result.events:
+ event = result.events[0]
+ if event.get("error_type") == "internal_error":
+ internal_error = event.get("error_message")
+ blocked = True
+ blocking_rail = event.get("flow_id")
+ elif event.get("intent") == "stop":
+ blocked = True
+ blocking_rail = event.get("flow_id")
+ else:
+ # Run sequential rails
+ for flow_id in output_flows:
+ try:
+ # Ensure explain_info_var is set before processing events
+ # so LLM calls are tracked properly
+ if self.explain_info is None:
+ self.explain_info = self._ensure_explain_info()
+ explain_info_var.set(self.explain_info)
+
+ # Create start event
+ start_event = {
+ "type": "StartOutputRail",
+ "flow_id": flow_id,
+ }
+ rail_events = events + [start_event]
- ```python
- [
- {"type": "...", ...},
- ...
- ]
- ```
+ # Process the rail
+ (
+ new_events,
+ _,
+ ) = await self.runtime_orchestrator.process_events_async(
+ rail_events, blocking=True
+ )
- Args:
- events: The history of events to be used to generate the next events.
- options: The options to be used for the generation.
+ # Check if rail blocked (look for stop intent or exception)
+ for event in new_events:
+ if (
+ event.get("type") == "BotIntent"
+ and event.get("intent") == "stop"
+ ):
+ blocked = True
+ blocking_rail = flow_id
+ break
+ elif event.get("type", "").endswith("Exception"):
+ blocked = True
+ blocking_rail = flow_id
+ break
+ # Also check for action results with output_mapping
+ elif event.get("type") == "InternalSystemActionFinished":
+ action_name = event.get("action_name")
+ return_value = event.get("return_value")
+ log.info(
+ f"Sequential rail {flow_id}: action {action_name} returned {return_value}"
+ )
+ if action_name and return_value is not None:
+ action_func = (
+ self.runtime.action_dispatcher.get_action(
+ action_name
+ )
+ )
+ if action_func:
+ is_blocked = is_output_blocked(
+ return_value, action_func
+ )
+ log.info(
+ f"Action {action_name} output_blocked={is_blocked}"
+ )
+ if is_blocked:
+ blocked = True
+ blocking_rail = flow_id
+ break
+
+ if blocked:
+ break
+ except Exception as e:
+ log.error(f"Error in sequential rail {flow_id}: {e}")
+ blocked = True
+ blocking_rail = flow_id
+ internal_error = str(e)
+ break
+
+ # If blocked, yield error and stop
+ if blocked:
+ # Create error message
+ if internal_error:
+ error_dict = {
+ "error": {
+ "message": f"Internal error in {blocking_rail} rail: {internal_error}",
+ "type": "internal_error",
+ "param": blocking_rail,
+ "code": "rail_execution_failure",
+ }
+ }
+ else:
+ error_dict = {
+ "error": {
+ "message": f"Blocked by {blocking_rail} rails.",
+ "type": "guardrails_violation",
+ "param": blocking_rail,
+ "code": "content_blocked",
+ }
+ }
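+            # The error is yielded as a plain JSON string (not SSE); the serving layer
+            # is expected to detect it, terminate the stream, and convert it into a
+            # proper error response for the client.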
+ yield json.dumps(error_dict)
+ return
- Returns:
- The newly generate event(s).
+ # If not blocked and stream_first is False, yield chunks after rails check
+ if not should_stream_first:
+ for chunk in chunk_batch.user_output_chunks:
+ yield chunk
+
+ async def generate_events_async(self, events: List[dict]) -> List[dict]:
+ """Generate the next events based on the provided history."""
+ import time
+
+ from nemoguardrails.actions.llm.utils import get_colang_history
+ from nemoguardrails.context import llm_stats_var
+ from nemoguardrails.logging.stats import LLMStats
- """
t0 = time.time()
- # Initialize the LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
- # Compute the new events.
processing_log = []
new_events = await self.runtime.generate_events(
events, processing_log=processing_log
)
- # If logging is enabled, we log the conversation
- # TODO: add support for logging flag
if self.verbose:
history = get_colang_history(events)
log.info(f"Conversation history so far: \n{history}")
@@ -1410,16 +903,12 @@ async def generate_events_async(
return new_events
- def generate_events(
- self,
- events: List[dict],
- ) -> List[dict]:
- """Synchronous version of `LLMRails.generate_events_async`."""
-
+ def generate_events(self, events: List[dict]) -> List[dict]:
+ """Synchronous version of generate_events_async."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate_events` inside async code. "
- "You should replace with `await generate_events_async(...)` or use `nest_asyncio.apply()`."
+ "You should replace with `await generate_events_async(...)`."
)
loop = get_or_create_event_loop()
@@ -1431,38 +920,10 @@ async def process_events_async(
state: Optional[dict] = None,
blocking: bool = False,
) -> Tuple[List[dict], dict]:
- """Process a sequence of events in a given state.
-
- The events will be processed one by one, in the input order.
-
- Args:
- events: A sequence of events that needs to be processed.
- state: The state that should be used as the starting point. If not provided,
- a clean state will be used.
-
- Returns:
- (output_events, output_state) Returns a sequence of output events and an output
- state.
- """
- t0 = time.time()
- llm_stats = LLMStats()
- llm_stats_var.set(llm_stats)
-
- # Compute the new events.
- # We need to protect 'process_events' to be called only once at a time
- # TODO (cschueller): Why is this?
- async with process_events_semaphore:
- output_events, output_state = await self.runtime.process_events(
- events, state, blocking
- )
-
- took = time.time() - t0
- # Small tweak, disable this when there were no events (or it was just too fast).
- if took > 0.1:
- log.info("--- :: Total processing took %.2f seconds." % took)
- log.info("--- :: Stats: %s" % llm_stats)
-
- return output_events, output_state
+ """Process a sequence of events in a given state."""
+ return await self.runtime_orchestrator.process_events_async(
+ events, state, blocking
+ )
def process_events(
self,
@@ -1470,12 +931,11 @@ def process_events(
state: Optional[dict] = None,
blocking: bool = False,
) -> Tuple[List[dict], dict]:
- """Synchronous version of `LLMRails.process_events_async`."""
-
+ """Synchronous version of process_events_async."""
if check_sync_call_from_async_loop():
raise RuntimeError(
- "You are using the sync `generate_events` inside async code. "
- "You should replace with `await generate_events_async(...)."
+ "You are using the sync `process_events` inside async code. "
+ "You should replace with `await process_events_async(...)`."
)
loop = get_or_create_event_loop()
@@ -1483,70 +943,67 @@ def process_events(
self.process_events_async(events, state, blocking)
)
+ # Registration methods
+
def register_action(self, action: Callable, name: Optional[str] = None) -> Self:
- """Register a custom action for the rails configuration."""
+ """Register a custom action."""
self.runtime.register_action(action, name)
return self
def register_action_param(self, name: str, value: Any) -> Self:
- """Registers a custom action parameter."""
+ """Register a custom action parameter."""
self.runtime.register_action_param(name, value)
return self
def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self:
- """Register a custom filter for the rails configuration."""
+ """Register a custom filter."""
self.runtime.llm_task_manager.register_filter(filter_fn, name)
return self
def register_output_parser(self, output_parser: Callable, name: str) -> Self:
- """Register a custom output parser for the rails configuration."""
+ """Register a custom output parser."""
self.runtime.llm_task_manager.register_output_parser(output_parser, name)
return self
def register_prompt_context(self, name: str, value_or_fn: Any) -> Self:
- """Register a value to be included in the prompt context.
-
- :name: The name of the variable or function that will be used.
- :value_or_fn: The value or function that will be used to generate the value.
- """
+ """Register a value to be included in the prompt context."""
self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn)
return self
def register_embedding_search_provider(
self, name: str, cls: Type[EmbeddingsIndex]
) -> Self:
- """Register a new embedding search provider.
-
- Args:
- name: The name of the embedding search provider that will be used.
- cls: The class that will be used to generate and search embedding
- """
-
+ """Register a new embedding search provider."""
self.embedding_search_providers[name] = cls
return self
def register_embedding_provider(
self, cls: Type[EmbeddingModel], name: Optional[str] = None
) -> Self:
- """Register a custom embedding provider.
-
- Args:
- model (Type[EmbeddingModel]): The embedding model class.
- name (str): The name of the embedding engine. If available in the model, it will be used.
-
- Raises:
- ValueError: If the engine name is not provided and the model does not have an engine name.
- ValueError: If the model does not have 'encode' or 'encode_async' methods.
- """
+ """Register a custom embedding provider."""
register_embedding_provider(engine_name=name, model=cls)
return self
def explain(self) -> ExplainInfo:
- """Helper function to return the latest ExplainInfo object."""
+ """Return the latest ExplainInfo object."""
if self.explain_info is None:
self.explain_info = self._ensure_explain_info()
return self.explain_info
+ def _prepare_model_kwargs(self, model_config) -> dict:
+ """Prepare kwargs for model initialization, including API key from environment.
+
+ This method is maintained for backwards compatibility.
+ It delegates to ModelFactory._prepare_model_kwargs.
+
+ Args:
+ model_config: The model configuration object.
+
+ Returns:
+ Dictionary of kwargs for model initialization.
+ """
+ return self.model_factory._prepare_model_kwargs(model_config)
+
def __getstate__(self):
return {"config": self.config}
@@ -1557,270 +1014,11 @@ def __setstate__(self, state):
config = state["config"]
self.__init__(config=config, verbose=False)
- async def _run_output_rails_in_streaming(
- self,
- streaming_handler: AsyncIterator[str],
- output_rails_streaming_config: OutputRailsStreamingConfig,
- prompt: Optional[str] = None,
- messages: Optional[List[dict]] = None,
- stream_first: Optional[bool] = None,
- ) -> AsyncIterator[str]:
- """
- 1. Buffers tokens from 'streaming_handler' via BufferStrategy.
- 2. Runs sequential (parallel for colang 2.0 in future) flows for each chunk.
- 3. Yields the chunk if not blocked, or STOP if blocked.
- """
-
- def _get_last_context_message(
- messages: Optional[List[dict]] = None,
- ) -> dict:
- if messages is None:
- return {}
-
- for message in reversed(messages):
- if message.get("role") == "context":
- return message
- return {}
-
- def _get_latest_user_message(
- messages: Optional[List[dict]] = None,
- ) -> dict:
- if messages is None:
- return {}
- for message in reversed(messages):
- if message.get("role") == "user":
- return message
- return {}
-
- def _prepare_context_for_parallel_rails(
- chunk_str: str,
- prompt: Optional[str] = None,
- messages: Optional[List[dict]] = None,
- ) -> dict:
- """Prepare context for parallel rails execution."""
- context_message = _get_last_context_message(messages)
- user_message = prompt or _get_latest_user_message(messages)
-
- context = {
- "user_message": user_message,
- "bot_message": chunk_str,
- }
-
- if context_message:
- context.update(context_message["content"])
-
- return context
-
- def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]:
- """Create events for running output rails on a chunk."""
- return [
- {"type": "ContextUpdate", "data": context},
- {"type": "BotMessage", "text": chunk_str},
- ]
-
- def _prepare_params(
- flow_id: str,
- action_name: str,
- bot_response_chunk: str,
- prompt: Optional[str] = None,
- messages: Optional[List[dict]] = None,
- action_params: Dict[str, Any] = {},
- ):
- context_message = _get_last_context_message(messages)
- user_message = prompt or _get_latest_user_message(messages)
-
- context = {
- "user_message": user_message,
- "bot_message": bot_response_chunk,
- }
-
- if context_message:
- context.update(context_message["content"])
-
- model_name = flow_id.split("$")[-1].split("=")[-1].strip('"')
-
- # we pass action params that are defined in the flow
- # caveate, e.g. prmpt_security uses bot_response=$bot_message
- # to resolve replace placeholders in action_params
- for key, value in action_params.items():
- if value == "$bot_message":
- action_params[key] = bot_response_chunk
- elif value == "$user_message":
- action_params[key] = user_message
-
- return {
- # TODO:: are there other context variables that need to be passed?
- # passing events to compute context was not successful
- # context var failed due to different context
- "context": context,
- "llm_task_manager": self.runtime.llm_task_manager,
- "config": self.config,
- "model_name": model_name,
- "llms": self.runtime.registered_action_params.get("llms", {}),
- "llm": self.runtime.registered_action_params.get(
- f"{action_name}_llm", self.llm
- ),
- **action_params,
- }
-
- buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
- output_rails_flows_id = self.config.rails.output.flows
- stream_first = stream_first or output_rails_streaming_config.stream_first
- get_action_details = partial(
- get_action_details_from_flow_id, flows=self.config.flows
- )
-
- parallel_mode = getattr(self.config.rails.output, "parallel", False)
-
- async for chunk_batch in buffer_strategy(streaming_handler):
- user_output_chunks = chunk_batch.user_output_chunks
- # format processing_context for output rails processing (needs full context)
- bot_response_chunk = buffer_strategy.format_chunks(
- chunk_batch.processing_context
- )
-
- # check if user_output_chunks is a list of individual chunks
- # or if it's a JSON string, by convention this means an error occurred and the error dict is stored as a JSON
- if not isinstance(user_output_chunks, list):
- try:
- json.loads(user_output_chunks)
- yield user_output_chunks
- return
- except (json.JSONDecodeError, TypeError):
- # if it's not JSON, treat it as empty list
- user_output_chunks = []
-
- if stream_first:
- # yield the individual chunks directly from the buffer strategy
- for chunk in user_output_chunks:
- yield chunk
-
- if parallel_mode:
- try:
- context = _prepare_context_for_parallel_rails(
- bot_response_chunk, prompt, messages
- )
- events = _create_events_for_chunk(bot_response_chunk, context)
-
- flows_with_params = {}
- for flow_id in output_rails_flows_id:
- action_name, action_params = get_action_details(flow_id)
- params = _prepare_params(
- flow_id=flow_id,
- action_name=action_name,
- bot_response_chunk=bot_response_chunk,
- prompt=prompt,
- messages=messages,
- action_params=action_params,
- )
- flows_with_params[flow_id] = {
- "action_name": action_name,
- "params": params,
- }
-
- result_tuple = await self.runtime.action_dispatcher.execute_action(
- "run_output_rails_in_parallel_streaming",
- {
- "flows_with_params": flows_with_params,
- "events": events,
- },
- )
- # ActionDispatcher.execute_action always returns (result, status)
- result, status = result_tuple
-
- if status != "success":
- log.error(
- f"Parallel rails execution failed with status: {status}"
- )
- # continue processing the chunk even if rails fail
- pass
- else:
- # if there are any stop events, content was blocked or internal error occurred
- result_events = getattr(result, "events", None)
- if result_events:
- # extract the flow info from the first stop event
- stop_event = result_events[0]
- blocked_flow = stop_event.get("flow_id", "output rails")
- error_type = stop_event.get("error_type")
-
- if error_type == "internal_error":
- error_message = stop_event.get(
- "error_message", "Unknown error"
- )
- reason = f"Internal error in {blocked_flow} rail: {error_message}"
- error_code = "rail_execution_failure"
- error_type = "internal_error"
- else:
- reason = f"Blocked by {blocked_flow} rails."
- error_code = "content_blocked"
- error_type = "guardrails_violation"
-
- error_data = {
- "error": {
- "message": reason,
- "type": error_type,
- "param": blocked_flow,
- "code": error_code,
- }
- }
- yield json.dumps(error_data)
- return
-
- except Exception as e:
- log.error(f"Error in parallel rail execution: {e}")
- # don't block the stream for rail execution errors
- # continue processing the chunk
- pass
-
- # update explain info for parallel mode
- self.explain_info = self._ensure_explain_info()
-
- else:
- for flow_id in output_rails_flows_id:
- action_name, action_params = get_action_details(flow_id)
-
- params = _prepare_params(
- flow_id=flow_id,
- action_name=action_name,
- bot_response_chunk=bot_response_chunk,
- prompt=prompt,
- messages=messages,
- action_params=action_params,
- )
-
- result = await self.runtime.action_dispatcher.execute_action(
- action_name, params
- )
- self.explain_info = self._ensure_explain_info()
-
- action_func = self.runtime.action_dispatcher.get_action(action_name)
-
- # Use the mapping to decide if the result indicates blocked content.
- if is_output_blocked(result, action_func):
- reason = f"Blocked by {flow_id} rails."
-
- # return the error as a plain JSON string (not in SSE format)
- # NOTE: When integrating with the OpenAI Python client, the server code should:
- # 1. detect this JSON error object in the stream
- # 2. terminate the stream
- # 3. format the error following OpenAI's SSE format
- # the OpenAI client will then properly raise an APIError with this error message
-
- error_data = {
- "error": {
- "message": reason,
- "type": "guardrails_violation",
- "param": flow_id,
- "code": "content_blocked",
- }
- }
-
- # return as plain JSON: the server should detect this JSON and convert it to an HTTP error
- yield json.dumps(error_data)
- return
-
- if not stream_first:
- # yield the individual chunks directly from the buffer strategy
- for chunk in user_output_chunks:
- yield chunk
+# Re-export for backwards compatibility
+__all__ = [
+ "LLMRails",
+ "get_action_details_from_flow_id",
+ "GenerationOptions",
+ "GenerationResponse",
+]
diff --git a/nemoguardrails/rails/llm/model_factory.py b/nemoguardrails/rails/llm/model_factory.py
new file mode 100644
index 000000000..48e0a851e
--- /dev/null
+++ b/nemoguardrails/rails/llm/model_factory.py
@@ -0,0 +1,261 @@
+import logging
+import os
+from typing import Any, Dict, Optional, Union
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
+
+from nemoguardrails.llm.cache import CacheInterface, LFUCache
+from nemoguardrails.llm.models.initializer import init_llm_model
+from nemoguardrails.rails.llm.config import RailsConfig
+
+log = logging.getLogger(__name__)
+
+
+class ModelFactory:
+ """Factory for initializing and configuring LLM models."""
+
+ def __init__(
+ self,
+ config: RailsConfig,
+ injected_llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+ ):
+ """Initialize the ModelFactory.
+
+ Args:
+ config: The rails configuration.
+ injected_llm: An optional LLM provided via constructor that takes precedence.
+ """
+ self.config = config
+ self.injected_llm = injected_llm
+ self.main_llm: Optional[Union[BaseLLM, BaseChatModel]] = None
+ self.specialized_llms: Dict[str, Union[BaseLLM, BaseChatModel]] = {}
+ self.model_caches: Dict[str, CacheInterface] = {}
+ self.main_llm_supports_streaming = False
+
+ def initialize_models(
+ self, action_param_registry: Dict[str, Any]
+ ) -> Dict[str, Union[BaseLLM, BaseChatModel]]:
+ """Initialize all LLM models from configuration.
+
+ Args:
+ action_param_registry: Registry to store action parameters.
+
+ Returns:
+ Dictionary of specialized LLMs (excluding main).
+ """
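+        # After this call the registry holds "llm" (when a main model is available),
+        # one "<type>_llm" entry per specialized model, an "llms" dict of all
+        # specialized models, and "model_caches" when any cache is configured.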
+ # Handle injected LLM first
+ if self.injected_llm:
+ self.main_llm = self.injected_llm
+ action_param_registry["llm"] = self.main_llm
+ self._configure_streaming(self.main_llm)
+
+ # Warn if main LLM also specified in config
+ if any(model.type == "main" for model in self.config.models):
+ log.warning(
+                    "An LLM was provided via the constructor and a main LLM is also specified in the config. "
+                    "The LLM provided via the constructor will be used and the main LLM from the config will be ignored."
+ )
+ else:
+ # Initialize main LLM from config
+ main_model = next(
+ (model for model in self.config.models if model.type == "main"), None
+ )
+
+ if main_model and main_model.model:
+ kwargs = self._prepare_model_kwargs(main_model)
+ self.main_llm = init_llm_model(
+ model_name=main_model.model,
+ provider_name=main_model.engine,
+ mode="chat",
+ kwargs=kwargs,
+ )
+ action_param_registry["llm"] = self.main_llm
+ self._configure_streaming(
+ self.main_llm,
+ model_name=main_model.model,
+ provider_name=main_model.engine,
+ )
+ else:
+ log.warning(
+ "No main LLM specified in the config and no LLM provided via constructor."
+ )
+
+ # Initialize specialized LLMs
+ for llm_config in self.config.models:
+ if llm_config.type in ["embeddings", "jailbreak_detection"]:
+ continue
+
+ # Skip main model - already initialized above
+ if llm_config.type == "main":
+ continue
+
+ model_name = llm_config.model
+ if not model_name:
+ raise ValueError(
+ f"LLM Config model field not set for {llm_config.type}"
+ )
+
+ provider_name = llm_config.engine
+ kwargs = self._prepare_model_kwargs(llm_config)
+ mode = llm_config.mode
+
+ llm_model = init_llm_model(
+ model_name=model_name,
+ provider_name=provider_name,
+ mode=mode,
+ kwargs=kwargs,
+ )
+
+ # Configure based on type
+ if llm_config.type == "main":
+ if not self.main_llm:
+ self.main_llm = llm_model
+ action_param_registry["llm"] = self.main_llm
+ else:
+ param_name = f"{llm_config.type}_llm"
+ self.specialized_llms[llm_config.type] = llm_model
+ action_param_registry[param_name] = llm_model
+
+ # Register specialized LLMs dictionary
+ action_param_registry["llms"] = self.specialized_llms
+
+ # Initialize model caches
+ self._initialize_model_caches(action_param_registry)
+
+ return self.specialized_llms
+
+ def _prepare_model_kwargs(self, model_config) -> dict:
+ """Prepare kwargs for model initialization, including API key from environment.
+
+ Args:
+ model_config: The model configuration object.
+
+ Returns:
+ Dictionary of kwargs for model initialization.
+ """
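+        # For example (hypothetical values): parameters={"temperature": 0.0} together
+        # with api_key_env_var="MY_API_KEY" and streaming enabled in the config yields
+        # {"temperature": 0.0, "api_key": <value of MY_API_KEY>, "stream_usage": True}.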
+ kwargs = model_config.parameters or {}
+
+ # Add API key from environment if specified
+ if model_config.api_key_env_var:
+ api_key = os.environ.get(model_config.api_key_env_var)
+ if api_key:
+ kwargs["api_key"] = api_key
+
+ # Enable streaming token usage when streaming is enabled
+ if self.config.streaming:
+ kwargs["stream_usage"] = True
+
+ return kwargs
+
+ def _configure_streaming(
+ self,
+ llm: Union[BaseLLM, BaseChatModel],
+ model_name: Optional[str] = None,
+ provider_name: Optional[str] = None,
+ ):
+ """Configure streaming support for the LLM.
+
+ Args:
+ llm: The LLM model instance.
+ model_name: Optional model name for logging.
+ provider_name: Optional provider name for logging.
+ """
+ if not self.config.streaming:
+ return
+
+ if hasattr(llm, "streaming"):
+ setattr(llm, "streaming", True)
+ self.main_llm_supports_streaming = True
+ else:
+ self.main_llm_supports_streaming = False
+ if model_name and provider_name:
+ log.warning(
+ "Model %s from provider %s does not support streaming.",
+ model_name,
+ provider_name,
+ )
+ else:
+ log.warning("Provided main LLM does not support streaming.")
+
+ def _create_model_cache(self, model) -> LFUCache:
+ """Create cache instance for a model based on its configuration.
+
+ Args:
+ model: The model configuration object.
+
+ Returns:
+ LFUCache: The cache instance.
+ """
+ if model.cache.maxsize <= 0:
+ raise ValueError(
+ f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. "
+                "Capacity must be greater than 0."
+ )
+
+ stats_logging_interval = None
+ if model.cache.stats.enabled and model.cache.stats.log_interval is not None:
+ stats_logging_interval = model.cache.stats.log_interval
+
+ cache = LFUCache(
+ maxsize=model.cache.maxsize,
+ track_stats=model.cache.stats.enabled,
+ stats_logging_interval=stats_logging_interval,
+ )
+
+ log.info(
+ "Created cache for model '%s' with maxsize %s",
+ model.type,
+ model.cache.maxsize,
+ )
+
+ return cache
+
+ def _initialize_model_caches(self, action_param_registry: Dict[str, Any]) -> None:
+ """Initialize caches for configured models.
+
+ Args:
+ action_param_registry: Registry to store action parameters.
+ """
+ for model in self.config.models:
+ if model.type in ["main", "embeddings"]:
+ continue
+
+ if model.cache and model.cache.enabled:
+ cache = self._create_model_cache(model)
+ self.model_caches[model.type] = cache
+
+                log.info("Initialized model '%s' with cache enabled", model.type)
+
+ if self.model_caches:
+ action_param_registry["model_caches"] = self.model_caches
+
+ def get_main_llm(self) -> Optional[Union[BaseLLM, BaseChatModel]]:
+ """Get the main LLM instance."""
+ return self.main_llm
+
+ def get_specialized_llm(
+ self, llm_type: str
+ ) -> Optional[Union[BaseLLM, BaseChatModel]]:
+ """Get a specialized LLM by type."""
+ return self.specialized_llms.get(llm_type)
+
+ def supports_streaming(self) -> bool:
+ """Check if the main LLM supports streaming."""
+ return self.main_llm_supports_streaming
+
+ def update_main_llm(
+ self, llm: Union[BaseLLM, BaseChatModel], action_param_registry: Dict[str, Any]
+ ):
+ """Update the main LLM instance.
+
+ Args:
+ llm: The new LLM instance.
+ action_param_registry: Registry to update action parameters.
+ """
+ self.main_llm = llm
+ action_param_registry["llm"] = llm
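+
+
+# Usage sketch (assuming a loaded RailsConfig named `config`):
+#
+#   factory = ModelFactory(config=config)
+#   action_params: dict = {}
+#   factory.initialize_models(action_params)  # populates "llm", "llms", "<type>_llm", ...
+#   main_llm = factory.get_main_llm()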
diff --git a/nemoguardrails/rails/llm/rails_api.py b/nemoguardrails/rails/llm/rails_api.py
new file mode 100644
index 000000000..f447fb178
--- /dev/null
+++ b/nemoguardrails/rails/llm/rails_api.py
@@ -0,0 +1,570 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Rails API - Main facade for NeMo Guardrails functionality."""
+
+import asyncio
+import logging
+import threading
+import time
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Type, Union
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
+from typing_extensions import Self
+
+from nemoguardrails.actions.llm.generation import LLMGenerationActions
+from nemoguardrails.actions.llm.utils import get_colang_history
+from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
+from nemoguardrails.colang.v1_0.runtime.runtime import Runtime
+from nemoguardrails.colang.v2_x.runtime.flows import State
+from nemoguardrails.context import (
+ explain_info_var,
+ generation_options_var,
+ llm_stats_var,
+ raw_llm_request,
+ streaming_handler_var,
+)
+from nemoguardrails.embeddings.index import EmbeddingsIndex
+from nemoguardrails.embeddings.providers.base import EmbeddingModel
+from nemoguardrails.logging.explain import ExplainInfo
+from nemoguardrails.logging.stats import LLMStats
+from nemoguardrails.logging.verbose import set_verbose
+from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
+from nemoguardrails.rails.llm.config import RailsConfig
+from nemoguardrails.rails.llm.event_translator import EventTranslator
+from nemoguardrails.rails.llm.kb_builder import KnowledgeBaseBuilder
+from nemoguardrails.rails.llm.model_factory import ModelFactory
+from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
+from nemoguardrails.rails.llm.response_assembler import ResponseAssembler
+from nemoguardrails.rails.llm.runtime_orchestrator import RuntimeOrchestrator
+from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
+from nemoguardrails.utils import get_or_create_event_loop
+from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
+
+log = logging.getLogger(__name__)
+
+
+class RailsAPI:
+ """
+ Main API facade for NeMo Guardrails.
+ """
+
+ def __init__(
+ self,
+ config: RailsConfig,
+ llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+ verbose: bool = False,
+ ):
+ """Initialize the RailsAPI.
+
+ Args:
+ config: A rails configuration.
+ llm: An optional LLM engine to use.
+ verbose: Whether the logging should be verbose.
+ """
+ self.config = config
+ self.verbose = verbose
+ self.explain_info: Optional[ExplainInfo] = None
+
+ if self.verbose:
+ set_verbose(True, llm_calls=True)
+
+ # Initialize embedding configuration
+ self.embedding_search_providers = {}
+ self.default_embedding_model = "all-MiniLM-L6-v2"
+ self.default_embedding_engine = "FastEmbed"
+ self.default_embedding_params = {}
+
+ # Initialize components
+ self._init_components(llm)
+
+ def _init_components(self, llm: Optional[Union[BaseLLM, BaseChatModel]]):
+ """Initialize all components.
+
+ Args:
+ llm: Optional LLM to inject.
+ """
+        # 1. Config loading is already done via the config parameter
+ # Additional config loading/processing could be added here
+
+ # 2. Initialize RuntimeOrchestrator first (needs to be available for other components)
+ self.runtime_orchestrator = RuntimeOrchestrator(
+ config=self.config, verbose=self.verbose
+ )
+ self.runtime: Runtime = self.runtime_orchestrator.runtime
+
+ # Registry for action parameters
+ self.action_param_registry: Dict[str, Any] = {}
+
+ # 3. Initialize ModelFactory
+ self.model_factory = ModelFactory(config=self.config, injected_llm=llm)
+
+ # Update embedding model if specified in config
+ for model in self.config.models:
+ if model.type == "embeddings":
+ self.default_embedding_model = model.model
+ self.default_embedding_engine = model.engine
+ self.default_embedding_params = model.parameters or {}
+ break
+
+ # Initialize models and populate action registry
+ self.model_factory.initialize_models(self.action_param_registry)
+
+ # Register action parameters with runtime
+ for param_name, param_value in self.action_param_registry.items():
+ self.runtime.register_action_param(param_name, param_value)
+
+ # 4. Initialize LLM Generation Actions
+ llm_generation_actions_class = (
+ LLMGenerationActions
+ if self.config.colang_version == "1.0"
+ else LLMGenerationActionsV2dotx
+ )
+ self.llm_generation_actions = llm_generation_actions_class(
+ config=self.config,
+ llm=self.model_factory.get_main_llm(),
+ llm_task_manager=self.runtime.llm_task_manager,
+ get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
+ verbose=self.verbose,
+ )
+ self.runtime.register_actions(self.llm_generation_actions, override=False)
+
+ # 5. Initialize KnowledgeBaseBuilder
+ self.kb_builder = KnowledgeBaseBuilder(
+ config=self.config,
+ get_embeddings_search_provider_instance=self._get_embeddings_search_provider_instance,
+ )
+
+ # Initialize KB (in separate thread to avoid async issues)
+ loop = get_or_create_event_loop()
+ if True or check_sync_call_from_async_loop():
+ t = threading.Thread(target=asyncio.run, args=(self.kb_builder.build(),))
+ t.start()
+ t.join()
+ else:
+ loop.run_until_complete(self.kb_builder.build())
+
+ # Register KB as action parameter
+ self.runtime.register_action_param("kb", self.kb_builder.get_kb())
+
+ # 6. Initialize EventTranslator
+ self.event_translator = EventTranslator(config=self.config)
+
+ # 7. Initialize ResponseAssembler
+ self.response_assembler = ResponseAssembler(config=self.config)
+
+ # Initialize tracing adapters if configured
+ if self.config.tracing:
+ from nemoguardrails.tracing import create_log_adapters
+
+ self._log_adapters = create_log_adapters(self.config.tracing)
+ else:
+ self._log_adapters = None
+
+ def _get_embeddings_search_provider_instance(self, esp_config=None):
+ """Get an embeddings search provider instance."""
+ from nemoguardrails.rails.llm.config import EmbeddingSearchProvider
+
+ if esp_config is None:
+ esp_config = EmbeddingSearchProvider()
+
+ if esp_config.name == "default":
+ from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
+
+ return BasicEmbeddingsIndex(
+ embedding_model=esp_config.parameters.get(
+ "embedding_model", self.default_embedding_model
+ ),
+ embedding_engine=esp_config.parameters.get(
+ "embedding_engine", self.default_embedding_engine
+ ),
+ embedding_params=esp_config.parameters.get(
+ "embedding_parameters", self.default_embedding_params
+ ),
+ cache_config=esp_config.cache,
+ **{
+ k: v
+ for k, v in esp_config.parameters.items()
+ if k
+ in [
+ "use_batching",
+ "max_batch_size",
+ "matx_batch_hold",
+ "search_threshold",
+ ]
+ and v is not None
+ },
+ )
+ else:
+ if esp_config.name not in self.embedding_search_providers:
+ raise Exception(f"Unknown embedding search provider: {esp_config.name}")
+ else:
+ kwargs = esp_config.parameters
+ return self.embedding_search_providers[esp_config.name](**kwargs)
+
+ @staticmethod
+ def _ensure_explain_info() -> ExplainInfo:
+ """Ensure that the ExplainInfo variable is present in the current context."""
+ explain_info = explain_info_var.get()
+ if explain_info is None:
+ explain_info = ExplainInfo()
+ explain_info_var.set(explain_info)
+ return explain_info
+
+ async def generate_async(
+ self,
+ prompt: Optional[str] = None,
+ messages: Optional[List[dict]] = None,
+ options: Optional[Union[dict, GenerationOptions]] = None,
+ state: Optional[Union[dict, State]] = None,
+ streaming_handler: Optional[StreamingHandler] = None,
+ ) -> Union[str, dict, GenerationResponse]:
+ """Generate a completion or next message.
+
+ Args:
+ prompt: The prompt to be used for completion.
+ messages: The history of messages.
+ options: Options specific for the generation.
+ state: The state object.
+ streaming_handler: Optional streaming handler.
+
+ Returns:
+ The completion or next message.
+ """
+ # Input validation
+ if prompt is None and messages is None:
+ raise ValueError("Either prompt or messages must be provided.")
+ if prompt is not None and messages is not None:
+ raise ValueError("Only one of prompt or messages can be provided.")
+
+ # Convert prompt to messages format
+ if prompt is not None:
+ messages = [{"role": "user", "content": prompt}]
+
+ # Handle state deserialization
+ if state is not None:
+ if isinstance(state, dict) and state.get("version", "1.0") == "2.x":
+ state = json_to_state(state["state"])
+
+ # Process options
+ gen_options = self._process_options(options, state)
+ generation_options_var.set(gen_options)
+
+ if streaming_handler:
+ streaming_handler_var.set(streaming_handler)
+
+ # Initialize explain info
+ self.explain_info = self._ensure_explain_info()
+ raw_llm_request.set(messages)
+
+ # Inject generation options into messages
+ if gen_options:
+ messages = [
+ {
+ "role": "context",
+ "content": {"generation_options": gen_options.model_dump()},
+ }
+ ] + (messages or [])
+
+ # Handle bot message in context for non-dialog mode
+ if (
+ messages
+ and messages[-1]["role"] == "assistant"
+ and gen_options
+ and gen_options.rails.dialog is False
+ ):
+ messages[0]["content"]["bot_message"] = messages[-1]["content"]
+ messages = messages[0:-1]
+
+ t0 = time.time()
+
+ # Initialize LLM stats
+ llm_stats = LLMStats()
+ llm_stats_var.set(llm_stats)
+
+ # Translate messages to events
+ if messages is None:
+ raise ValueError("messages must be provided")
+ else:
+ events = self.event_translator.messages_to_events(messages, state)
+
+ # Generate new events using runtime orchestrator
+ try:
+ (
+ new_events,
+ output_state,
+ processing_log,
+ ) = await self.runtime_orchestrator.generate_events(
+ events=events, state=state
+ )
+ except Exception as e:
+ log.error("Error in generate_async: %s", e, exc_info=True)
+ streaming_handler = streaming_handler_var.get()
+ if streaming_handler:
+ import json
+
+ from nemoguardrails.utils import extract_error_json
+
+ error_message = str(e)
+ error_dict = extract_error_json(error_message)
+ error_payload = json.dumps(error_dict)
+ await streaming_handler.push_chunk(error_payload)
+ await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
+ raise
+
+ # Update event translator cache for Colang 1.0
+ if self.config.colang_version == "1.0":
+ # Build the new message for caching
+ responses, _, _, _ = self.response_assembler._extract_from_events(
+ new_events
+ )
+ responses = [str(r) if not isinstance(r, str) else r for r in responses]
+ new_message = {"role": "assistant", "content": "\n".join(responses)}
+
+ if state is None:
+ all_events = events + new_events
+ self.event_translator.cache_events(
+ (messages or []) + [new_message], all_events
+ )
+
+ # Log conversation history
+ all_events = (
+ events + new_events if self.config.colang_version == "1.0" else new_events
+ )
+ self.explain_info.colang_history = get_colang_history(all_events)
+ if self.verbose:
+ log.info(
+ f"Conversation history so far: \n{self.explain_info.colang_history}"
+ )
+
+ total_time = time.time() - t0
+ log.info(
+ "--- :: Total processing took %.2f seconds. LLM Stats: %s"
+ % (total_time, llm_stats)
+ )
+
+ # Close streaming handler if present
+ streaming_handler = streaming_handler_var.get()
+ if streaming_handler:
+ await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
+
+ # Assemble response
+ if gen_options:
+ res = self.response_assembler.assemble_response(
+ new_events=new_events,
+ all_events=all_events,
+ output_state=output_state,
+ processing_log=processing_log,
+ gen_options=gen_options,
+ prompt=prompt,
+ )
+
+ # Handle tracing if enabled
+ if self.config.tracing.enabled:
+ await self._handle_tracing(messages, res)
+
+ return res
+ else:
+ return self.response_assembler.assemble_simple_response(
+ new_events=new_events, prompt=prompt
+ )
+
+ def _process_options(
+ self,
+ options: Optional[Union[dict, GenerationOptions]],
+ state: Optional[Any],
+ ) -> Optional[GenerationOptions]:
+ """Process and normalize generation options.
+
+ Args:
+ options: Raw options.
+ state: State object.
+
+ Returns:
+ Normalized GenerationOptions or None.
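+
+ Example (illustrative):
+ _process_options({"log": {"llm_calls": True}}, state=None)
+ # -> GenerationOptions with log.llm_calls enabled
+ _process_options(None, state={"events": []})
+ # -> GenerationOptions() with default values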
+ """
+ if state is not None:
+ if options is None:
+ return GenerationOptions()
+ elif isinstance(options, dict):
+ return GenerationOptions(**options)
+ else:
+ return options
+ else:
+ if isinstance(options, dict):
+ return GenerationOptions(**options)
+ elif isinstance(options, GenerationOptions):
+ return options
+ elif options is None:
+ return None
+ else:
+ raise TypeError("options must be a dict or GenerationOptions")
+
+ async def _handle_tracing(
+ self,
+ messages: List[dict[str, Any]],
+ res: GenerationResponse,
+ ):
+ """Handle tracing export.
+
+ Args:
+ messages: Input messages.
+ res: Generation response.
+ """
+ from nemoguardrails.tracing import Tracer
+
+ span_format = getattr(self.config.tracing, "span_format", "opentelemetry")
+ enable_content_capture = getattr(
+ self.config.tracing, "enable_content_capture", False
+ )
+
+ tracer = Tracer(
+ input=messages,
+ response=res,
+ adapters=self._log_adapters,
+ span_format=span_format,
+ enable_content_capture=enable_content_capture,
+ )
+ await tracer.export_async()
+
+ def generate(
+ self,
+ prompt: Optional[str] = None,
+ messages: Optional[List[dict]] = None,
+ options: Optional[Union[dict, GenerationOptions]] = None,
+ state: Optional[dict] = None,
+ ):
+ """Synchronous version of generate_async."""
+ if check_sync_call_from_async_loop():
+ raise RuntimeError(
+ "You are using the sync `generate` inside async code. "
+ "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
+ )
+
+ loop = get_or_create_event_loop()
+ return loop.run_until_complete(
+ self.generate_async(
+ prompt=prompt, messages=messages, options=options, state=state
+ )
+ )
+
+ def stream_async(
+ self,
+ prompt: Optional[str] = None,
+ messages: Optional[List[dict]] = None,
+ options: Optional[Union[dict, GenerationOptions]] = None,
+ state: Optional[Union[dict, State]] = None,
+ include_generation_metadata: Optional[bool] = False,
+ ) -> AsyncIterator[str]:
+ """Stream tokens from the LLM.
+
+ Note: This is a simplified stub. Full streaming implementation would
+ require additional logic from the original LLMRails.stream_async method.
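+
+ Example (illustrative, assuming `app` is an initialized LLMRails-style
+ instance and the returned StreamingHandler is consumed as an async iterator):
+ async for chunk in app.stream_async(
+ messages=[{"role": "user", "content": "Hi!"}]
+ ):
+ print(chunk, end="")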
+ """
+ # Simplified implementation - full version would include output rails streaming
+ self.explain_info = self._ensure_explain_info()
+
+ streaming_handler = StreamingHandler(
+ include_generation_metadata=include_generation_metadata
+ )
+
+ async def _generation_task():
+ try:
+ await self.generate_async(
+ prompt=prompt,
+ messages=messages,
+ streaming_handler=streaming_handler,
+ options=options,
+ state=state,
+ )
+ except Exception as e:
+ log.error(f"Error in generation task: {e}", exc_info=True)
+ import json
+
+ from nemoguardrails.utils import extract_error_json
+
+ error_message = str(e)
+ error_dict = extract_error_json(error_message)
+ error_payload = json.dumps(error_dict)
+ await streaming_handler.push_chunk(error_payload)
+ await streaming_handler.push_chunk(END_OF_STREAM)
+
+ task = asyncio.create_task(_generation_task())
+
+ if not hasattr(self, "_active_tasks"):
+ self._active_tasks = set()
+ self._active_tasks.add(task)
+
+ def task_done_callback(task):
+ self._active_tasks.discard(task)
+
+ task.add_done_callback(task_done_callback)
+
+ return streaming_handler
+
+ # Additional API methods for compatibility
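+ # Illustrative usage (each method returns `self`, so registrations can be
+ # chained, e.g. with hypothetical `my_action` and `my_db`):
+ # rails.register_action(my_action, "my_action").register_action_param("db", my_db)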
+
+ def register_action(self, action: Callable, name: Optional[str] = None) -> Self:
+ """Register a custom action."""
+ self.runtime.register_action(action, name)
+ return self
+
+ def register_action_param(self, name: str, value: Any) -> Self:
+ """Register a custom action parameter."""
+ self.runtime.register_action_param(name, value)
+ return self
+
+ def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self:
+ """Register a custom filter."""
+ self.runtime.llm_task_manager.register_filter(filter_fn, name)
+ return self
+
+ def register_output_parser(self, output_parser: Callable, name: str) -> Self:
+ """Register a custom output parser."""
+ self.runtime.llm_task_manager.register_output_parser(output_parser, name)
+ return self
+
+ def register_prompt_context(self, name: str, value_or_fn: Any) -> Self:
+ """Register a value to be included in the prompt context."""
+ self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn)
+ return self
+
+ def register_embedding_search_provider(
+ self, name: str, cls: Type[EmbeddingsIndex]
+ ) -> Self:
+ """Register a new embedding search provider."""
+ self.embedding_search_providers[name] = cls
+ return self
+
+ def register_embedding_provider(
+ self, cls: Type[EmbeddingModel], name: Optional[str] = None
+ ) -> Self:
+ """Register a custom embedding provider."""
+ from nemoguardrails.embeddings.providers import register_embedding_provider
+
+ register_embedding_provider(engine_name=name, model=cls)
+ return self
+
+ def explain(self) -> ExplainInfo:
+ """Return the latest ExplainInfo object."""
+ if self.explain_info is None:
+ self.explain_info = self._ensure_explain_info()
+ return self.explain_info
+
+ def update_llm(self, llm: Union[BaseLLM, BaseChatModel]):
+ """Replace the main LLM with the provided one."""
+ self.model_factory.update_main_llm(llm, self.action_param_registry)
+ self.llm_generation_actions.llm = llm
diff --git a/nemoguardrails/rails/llm/response_assembler.py b/nemoguardrails/rails/llm/response_assembler.py
new file mode 100644
index 000000000..807eb8013
--- /dev/null
+++ b/nemoguardrails/rails/llm/response_assembler.py
@@ -0,0 +1,364 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Response assembler for generating GenerationResponse objects."""
+
+import logging
+import re
+from typing import Any, Dict, List, Optional
+
+from nemoguardrails.actions.llm.utils import (
+ extract_bot_thinking_from_events,
+ extract_tool_calls_from_events,
+ get_and_clear_response_metadata_contextvar,
+ get_colang_history,
+)
+from nemoguardrails.colang.v1_0.runtime.flows import compute_context
+from nemoguardrails.logging.processing_log import compute_generation_log
+from nemoguardrails.rails.llm.config import RailsConfig
+from nemoguardrails.rails.llm.options import (
+ GenerationLog,
+ GenerationOptions,
+ GenerationResponse,
+)
+
+log = logging.getLogger(__name__)
+
+
+class ResponseAssembler:
+ """Assembles responses from events into GenerationResponse objects."""
+
+ def __init__(self, config: RailsConfig):
+ """Initialize the ResponseAssembler.
+
+ Args:
+ config: The rails configuration.
+ """
+ self.config = config
+
+ def assemble_response(
+ self,
+ new_events: List[dict],
+ all_events: List[dict],
+ output_state: Optional[Any],
+ processing_log: List[dict],
+ gen_options: Optional[GenerationOptions],
+ prompt: Optional[str] = None,
+ ) -> GenerationResponse:
+ """Assemble a GenerationResponse from events.
+
+ Args:
+ new_events: The newly generated events.
+ all_events: All events including history.
+ output_state: The output state (for Colang 2.x).
+ processing_log: The processing log.
+ gen_options: Generation options.
+ prompt: Optional prompt (if used instead of messages).
+
+ Returns:
+ A GenerationResponse object.
+ """
+ # Extract responses and metadata from events
+ (
+ responses,
+ response_tool_calls,
+ response_events,
+ exception,
+ ) = self._extract_from_events(new_events)
+
+ # Build the new message
+ new_message = self._build_message(
+ responses, response_tool_calls, response_events, exception
+ )
+
+ # Extract additional metadata
+ tool_calls = extract_tool_calls_from_events(new_events)
+ llm_metadata = get_and_clear_response_metadata_contextvar()
+ reasoning_content = extract_bot_thinking_from_events(new_events)
+
+ # Create response object
+ if prompt:
+ res = GenerationResponse(response=new_message["content"])
+ else:
+ res = GenerationResponse(response=[new_message])
+
+ if reasoning_content:
+ res.reasoning_content = reasoning_content
+
+ if tool_calls:
+ res.tool_calls = tool_calls
+
+ if llm_metadata:
+ res.llm_metadata = llm_metadata
+
+ # Add version-specific information
+ if self.config.colang_version == "1.0":
+ self._add_v1_specific_info(res, all_events, processing_log, gen_options)
+ else:
+ self._validate_v2_options(gen_options)
+
+ # Include the state
+ if output_state is not None:
+ res.state = output_state
+
+ return res
+
+ def _extract_from_events(
+ self, new_events: List[dict]
+ ) -> tuple[List[str], List[dict], List[dict], Optional[dict]]:
+ """Extract responses, tool calls, and events from new events.
+
+ Args:
+ new_events: The newly generated events.
+
+ Returns:
+ Tuple of (responses, response_tool_calls, response_events, exception).
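+
+ For Colang 1.0, for example, a "StartUtteranceBotAction" event with
+ script "Hello!" contributes "Hello!" to `responses`, while an event whose
+ type ends in "Exception" is returned as `exception`.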
+ """
+ responses = []
+ response_tool_calls = []
+ response_events = []
+ exception = None
+
+ if self.config.colang_version == "1.0":
+ for event in new_events:
+ if event["type"] == "StartUtteranceBotAction":
+ # Check if we need to remove a message
+ if event["script"] == "(remove last message)":
+ responses = responses[0:-1]
+ else:
+ responses.append(event["script"])
+ elif event["type"].endswith("Exception"):
+ exception = event
+ else:
+ for event in new_events:
+ start_action_match = re.match(r"Start(.*Action)", event["type"])
+
+ if start_action_match:
+ action_name = start_action_match[1]
+ # Extract arguments
+ arguments = {
+ k: v
+ for k, v in event.items()
+ if k
+ not in [
+ "type",
+ "uid",
+ "event_created_at",
+ "source_uid",
+ "action_uid",
+ ]
+ }
+ response_tool_calls.append(
+ {
+ "id": event["action_uid"],
+ "type": "function",
+ "function": {"name": action_name, "arguments": arguments},
+ }
+ )
+
+ elif event["type"] == "UtteranceBotActionFinished":
+ responses.append(event["final_script"])
+ else:
+ response_events.append(event)
+
+ return responses, response_tool_calls, response_events, exception
+
+ def _build_message(
+ self,
+ responses: List[str],
+ response_tool_calls: List[dict],
+ response_events: List[dict],
+ exception: Optional[dict],
+ ) -> dict:
+ """Build a message from response components.
+
+ Args:
+ responses: List of response strings.
+ response_tool_calls: List of tool calls.
+ response_events: List of response events.
+ exception: Optional exception event.
+
+ Returns:
+ A message dictionary.
+ """
+ new_message: Dict[str, Any]
+ if exception:
+ new_message = {"role": "exception", "content": exception}
+ else:
+ # Ensure all items in responses are strings
+ responses = [
+ str(response) if not isinstance(response, str) else response
+ for response in responses
+ ]
+ new_message = {"role": "assistant", "content": "\n".join(responses)}
+
+ if response_tool_calls:
+ new_message["tool_calls"] = response_tool_calls
+
+ if response_events:
+ new_message["events"] = response_events
+
+ return new_message
+
+ def _add_v1_specific_info(
+ self,
+ res: GenerationResponse,
+ all_events: List[dict],
+ processing_log: List[dict],
+ gen_options: Optional[GenerationOptions],
+ ):
+ """Add Colang 1.0 specific information to the response.
+
+ Args:
+ res: The GenerationResponse to update.
+ all_events: All events including history.
+ processing_log: The processing log.
+ gen_options: Generation options.
+ """
+ # Extract output variables if specified
+ if gen_options and gen_options.output_vars:
+ context = compute_context(all_events)
+ output_vars = gen_options.output_vars
+ if isinstance(output_vars, list):
+ res.output_data = {k: context.get(k) for k in output_vars}
+ else:
+ res.output_data = context
+
+ # Add logging information
+ _log = compute_generation_log(processing_log)
+ log_options = gen_options.log if gen_options else None
+
+ if log_options and (log_options.activated_rails or log_options.llm_calls):
+ res.log = GenerationLog()
+ res.log.stats = _log.stats
+
+ if log_options.activated_rails:
+ res.log.activated_rails = _log.activated_rails
+ else:
+ # Keep as empty list when not requested
+ res.log.activated_rails = []
+
+ if log_options.llm_calls:
+ res.log.llm_calls = []
+ for activated_rail in _log.activated_rails:
+ for executed_action in activated_rail.executed_actions:
+ res.log.llm_calls.extend(executed_action.llm_calls)
+ else:
+ # Set to empty list instead of None when not requested
+ res.log.llm_calls = []
+
+ # Include internal events if requested
+ if log_options and log_options.internal_events:
+ if res.log is None:
+ res.log = GenerationLog()
+ res.log.internal_events = all_events
+ elif res.log is not None:
+ # Set to empty list instead of None when not requested but log exists
+ res.log.internal_events = []
+
+ # Include Colang history if requested
+ if log_options and log_options.colang_history:
+ if res.log is None:
+ res.log = GenerationLog()
+ res.log.colang_history = get_colang_history(all_events)
+
+ # Normalize list fields: ensure they're empty lists instead of None when log exists
+ if res.log is not None:
+ if res.log.llm_calls is None:
+ res.log.llm_calls = []
+ if res.log.internal_events is None:
+ res.log.internal_events = []
+
+ # Include raw LLM output if requested
+ if gen_options and gen_options.llm_output:
+ for activated_rail in _log.activated_rails:
+ if activated_rail.type == "generation":
+ for executed_action in activated_rail.executed_actions:
+ for llm_call in executed_action.llm_calls:
+ res.llm_output = llm_call.raw_response
+
+ def _validate_v2_options(self, gen_options: Optional[GenerationOptions]):
+ """Validate that unsupported options are not used for Colang 2.x.
+
+ Args:
+ gen_options: Generation options to validate.
+
+ Raises:
+ ValueError: If unsupported options are used.
+ """
+ if not gen_options:
+ return
+
+ if gen_options.output_vars:
+ raise ValueError(
+ "The `output_vars` option is not supported for Colang 2.0 configurations."
+ )
+
+ log_options = gen_options.log
+ if log_options and (
+ log_options.activated_rails
+ or log_options.llm_calls
+ or log_options.internal_events
+ or log_options.colang_history
+ ):
+ raise ValueError(
+ "The `log` option is not supported for Colang 2.0 configurations."
+ )
+
+ if gen_options.llm_output:
+ raise ValueError(
+ "The `llm_output` option is not supported for Colang 2.0 configurations."
+ )
+
+ def assemble_simple_response(
+ self,
+ new_events: List[dict],
+ prompt: Optional[str] = None,
+ ) -> dict:
+ """Assemble a simple response (non-GenerationResponse mode).
+
+ Args:
+ new_events: The newly generated events.
+ prompt: Optional prompt (if used instead of messages).
+
+ Returns:
+ A message dictionary or content string.
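+
+ For example, when `prompt` was used the return value is just the bot
+ message text; otherwise it is a dict such as
+ {"role": "assistant", "content": "..."}, plus "tool_calls" when present.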
+ """
+ (
+ responses,
+ response_tool_calls,
+ response_events,
+ exception,
+ ) = self._extract_from_events(new_events)
+
+ new_message = self._build_message(
+ responses, response_tool_calls, response_events, exception
+ )
+
+ # Add thinking trace if present
+ reasoning_content = extract_bot_thinking_from_events(new_events)
+ if reasoning_content:
+ thinking_trace = f"{reasoning_content}\n"
+ new_message["content"] = thinking_trace + new_message["content"]
+
+ # Add tool calls if present
+ tool_calls = extract_tool_calls_from_events(new_events)
+ if tool_calls:
+ new_message["tool_calls"] = tool_calls
+
+ if prompt:
+ return new_message["content"]
+ else:
+ return new_message
diff --git a/nemoguardrails/rails/llm/runtime_orchestrator.py b/nemoguardrails/rails/llm/runtime_orchestrator.py
new file mode 100644
index 000000000..1d581afde
--- /dev/null
+++ b/nemoguardrails/rails/llm/runtime_orchestrator.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runtime orchestrator for managing Colang runtime execution."""
+
+import asyncio
+import logging
+from typing import Any, List, Optional, Tuple, cast
+
+from nemoguardrails.colang.runtime import Runtime
+from nemoguardrails.colang.v1_0.runtime.runtime import RuntimeV1_0
+from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
+from nemoguardrails.colang.v2_x.runtime.serialization import state_to_json
+from nemoguardrails.logging.verbose import set_verbose
+from nemoguardrails.rails.llm.config import RailsConfig
+
+log = logging.getLogger(__name__)
+
+# Semaphore for protecting process_events calls
+process_events_semaphore = asyncio.Semaphore(1)
+
+
+class RuntimeOrchestrator:
+ """Orchestrates the Colang runtime execution.
+
+ Handles runtime initialization, event generation, and process coordination.
+ """
+
+ def __init__(self, config: RailsConfig, verbose: bool = False):
+ """Initialize the RuntimeOrchestrator.
+
+ Args:
+ config: The rails configuration.
+ verbose: Whether to enable verbose logging.
+ """
+ self.config = config
+ self.verbose = verbose
+ self.runtime: Runtime
+
+ if self.verbose:
+ set_verbose(True, llm_calls=True)
+
+ # Initialize the appropriate runtime based on Colang version
+ if config.colang_version == "1.0":
+ self.runtime = RuntimeV1_0(config=config, verbose=verbose)
+ elif config.colang_version == "2.x":
+ self.runtime = RuntimeV2_x(config=config, verbose=verbose)
+ else:
+ raise ValueError(f"Unsupported colang version: {config.colang_version}.")
+
+ async def generate_events(
+ self,
+ events: List[dict],
+ state: Optional[Any] = None,
+ ) -> Tuple[List[dict], Optional[Any], List[dict]]:
+ """Generate new events based on the history of events.
+
+ Args:
+ events: The history of events.
+ state: Optional state object (for Colang 2.x).
+
+ Returns:
+ Tuple of (new_events, output_state, processing_log).
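+
+ For Colang 1.0, `output_state` is always None; for Colang 2.x it is a
+ serializable dict of the form {"state": "<json>", "version": "2.x"}.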
+ """
+ processing_log = []
+
+ if self.config.colang_version == "1.0":
+ # For Colang 1.0, we use generate_events
+ state_events = []
+ if state:
+ assert isinstance(state, dict)
+ state_events = state.get("events", [])
+
+ # Extract any context variables from state (keys other than "events")
+ # and create a ContextUpdate event to make them available in Colang
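+ # e.g. a caller passing state={"events": [...], "skip_output_rails": True}
+ # would yield context_vars = {"skip_output_rails": True}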
+ context_vars = {k: v for k, v in state.items() if k != "events"}
+ if context_vars:
+ log.info(
+ f"DEBUG: Creating ContextUpdate event from state with context_vars={context_vars}"
+ )
+ from nemoguardrails.utils import new_event_dict
+
+ context_update_event = new_event_dict(
+ "ContextUpdate", data=context_vars
+ )
+ log.info(
+ f"DEBUG: Created ContextUpdate event: {context_update_event}"
+ )
+ state_events = [context_update_event] + state_events
+
+ all_events = state_events + events
+ log.info(
+ f"DEBUG: Calling runtime.generate_events with {len(all_events)} events, first few: {all_events[:3]}"
+ )
+ new_events = await self.runtime.generate_events(
+ all_events, processing_log=processing_log
+ )
+ output_state = None
+
+ else:
+ # For Colang 2.x, we use process_events
+ instant_actions = ["UtteranceBotAction"]
+ if self.config.rails.actions.instant_actions is not None:
+ instant_actions = self.config.rails.actions.instant_actions
+
+ runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime)
+
+ new_events, output_state = await runtime.process_events(
+ events, state=state, instant_actions=instant_actions, blocking=True
+ )
+
+ # Encode output state as JSON
+ output_state = {"state": state_to_json(output_state), "version": "2.x"}
+
+ return new_events, output_state, processing_log
+
+ async def process_events_async(
+ self,
+ events: List[dict],
+ state: Optional[dict] = None,
+ blocking: bool = False,
+ ) -> Tuple[List[dict], Any]:
+ """Process a sequence of events in a given state.
+
+ Args:
+ events: A sequence of events to process.
+ state: The starting state.
+ blocking: Whether to block on all actions.
+
+ Returns:
+ Tuple of (output_events, output_state).
+ """
+ # Ensure process_events is called only once at a time
+ async with process_events_semaphore:
+ if self.config.colang_version == "1.0":
+ # For Colang 1.0, use generate_events
+ state_events = []
+ if state:
+ assert isinstance(state, dict)
+ state_events = state.get("events", [])
+
+ output_events = await self.runtime.generate_events(
+ state_events + events
+ )
+ output_state = None
+ else:
+ # For Colang 2.x, use process_events
+ output_events, output_state = await self.runtime.process_events(
+ events, state, blocking
+ )
+
+ return (output_events, output_state)
diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py
index 6769dec1e..d8a94bc02 100644
--- a/nemoguardrails/server/api.py
+++ b/nemoguardrails/server/api.py
@@ -197,8 +197,7 @@ class RequestBody(BaseModel):
)
config_ids: Optional[List[str]] = Field(
default=None,
- description="The list of configuration ids to be used. "
- "If set, the configurations will be combined.",
+ description="The list of configuration ids to be used. If set, the configurations will be combined.",
# alias="guardrails",
validate_default=True,
)
@@ -592,8 +591,7 @@ def on_any_event(self, event):
except ImportError:
# Since this is running in a separate thread, we just print the error.
print(
- "The auto-reload feature requires `watchdog`. "
- "Please install using `pip install watchdog`."
+ "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`."
)
# Force close everything.
os._exit(-1)
diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py
index 06ad3ee93..98c807405 100644
--- a/nemoguardrails/streaming.py
+++ b/nemoguardrails/streaming.py
@@ -259,7 +259,7 @@ async def _process(
async def push_chunk(
self,
- chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None],
+ chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, Any],
generation_info: Optional[Dict[str, Any]] = None,
):
"""Push a new chunk to the stream."""
diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py
index aa3be4789..fa715aba9 100644
--- a/tests/rails/llm/test_config.py
+++ b/tests/rails/llm/test_config.py
@@ -21,7 +21,7 @@
from pydantic import ValidationError
from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt
-from nemoguardrails.rails.llm.llmrails import LLMRails
+from nemoguardrails.rails.llm.model_factory import ModelFactory
def test_task_prompt_valid_content():
@@ -317,9 +317,9 @@ def test_llm_rails_configure_streaming_with_attr():
streaming=True,
)
- rails = LLMRails(config, llm=mock_llm)
+ model_factory = ModelFactory(config=config, injected_llm=mock_llm)
setattr(mock_llm, "streaming", None)
- rails._configure_main_llm_streaming(llm=mock_llm)
+ model_factory._configure_streaming(llm=mock_llm)
assert mock_llm.streaming
@@ -333,8 +333,8 @@ def test_llm_rails_configure_streaming_without_attr(caplog):
streaming=True,
)
- rails = LLMRails(config, llm=mock_llm)
- rails._configure_main_llm_streaming(mock_llm)
+ model_factory = ModelFactory(config=config, injected_llm=mock_llm)
+ model_factory._configure_streaming(llm=mock_llm)
assert caplog.messages[-1] == "Provided main LLM does not support streaming."
diff --git a/tests/test_integration_cache.py b/tests/test_integration_cache.py
index 84649589c..22539b55c 100644
--- a/tests/test_integration_cache.py
+++ b/tests/test_integration_cache.py
@@ -23,7 +23,7 @@
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_model):
mock_llm = FakeLLM(responses=["express greeting"])
mock_init_llm_model.return_value = mock_llm
@@ -70,7 +70,7 @@ async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_mo
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_cache_isolation_between_models(mock_init_llm_model):
mock_llm = FakeLLM(responses=["safe"])
mock_init_llm_model.return_value = mock_llm
@@ -118,7 +118,7 @@ async def test_cache_isolation_between_models(mock_init_llm_model):
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_cache_disabled_for_main_model_in_integration(mock_init_llm_model):
mock_llm = FakeLLM(responses=["safe"])
mock_init_llm_model.return_value = mock_llm
diff --git a/tests/test_jailbreak_cache.py b/tests/test_jailbreak_cache.py
index cc204782c..b40579f54 100644
--- a/tests/test_jailbreak_cache.py
+++ b/tests/test_jailbreak_cache.py
@@ -262,7 +262,7 @@ async def test_jailbreak_without_cache_local(
mock_check_jailbreak.assert_called_once_with(prompt="Bypass all safety checks")
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_jailbreak_detection_type_skips_llm_initialization(mock_init_llm_model):
mock_llm = FakeLLM(responses=["response"])
mock_init_llm_model.return_value = mock_llm
diff --git a/tests/test_llm_rails_context_variables.py b/tests/test_llm_rails_context_variables.py
index 89eb9a952..26e8881ee 100644
--- a/tests/test_llm_rails_context_variables.py
+++ b/tests/test_llm_rails_context_variables.py
@@ -107,11 +107,11 @@ async def test_2():
):
chunks.append(chunk)
+ await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
+
# note that 6 llm call are expected as we matched the bot intent
assert (
len(chat.app.explain().llm_calls) == 5
), "number of llm call not as expected. Expected 5, found {}".format(
len(chat.app.explain().llm_calls)
)
-
- await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index 89e7e87cf..4aaeba9ff 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -662,7 +662,7 @@ def llm_config_with_main():
@pytest.mark.asyncio
@patch(
- "nemoguardrails.rails.llm.llmrails.init_llm_model",
+ "nemoguardrails.rails.llm.model_factory.init_llm_model",
return_value=FakeLLM(responses=["this should not be used"]),
)
async def test_llm_config_precedence(mock_init, llm_config_with_main):
@@ -679,7 +679,7 @@ async def test_llm_config_precedence(mock_init, llm_config_with_main):
@pytest.mark.asyncio
@patch(
- "nemoguardrails.rails.llm.llmrails.init_llm_model",
+ "nemoguardrails.rails.llm.model_factory.init_llm_model",
return_value=FakeLLM(responses=["this should not be used"]),
)
async def test_llm_config_warning(mock_init, llm_config_with_main, caplog):
@@ -728,7 +728,7 @@ def llm_config_with_multiple_models():
@pytest.mark.asyncio
@patch(
- "nemoguardrails.rails.llm.llmrails.init_llm_model",
+ "nemoguardrails.rails.llm.model_factory.init_llm_model",
return_value=FakeLLM(responses=["content safety response"]),
)
async def test_other_models_honored(mock_init, llm_config_with_multiple_models):
@@ -779,7 +779,7 @@ async def test_llm_constructor_with_empty_models_config():
@pytest.mark.asyncio
@patch(
- "nemoguardrails.rails.llm.llmrails.init_llm_model",
+ "nemoguardrails.rails.llm.model_factory.init_llm_model",
return_value=FakeLLM(responses=["safe"]),
)
async def test_main_llm_from_config_registered_as_action_param(
@@ -839,7 +839,7 @@ async def test_llm_action(llm: BaseLLM):
assert action_finished_event["return_value"] == "llm_action_success"
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
@patch.dict(os.environ, {"TEST_OPENAI_KEY": "secret-api-key-from-env"})
def test_api_key_environment_variable_passed_to_init_llm_model(mock_init_llm_model):
"""Test that API keys from environment variables are passed to init_llm_model."""
@@ -873,7 +873,7 @@ def test_api_key_environment_variable_passed_to_init_llm_model(mock_init_llm_mod
assert call_args.kwargs["mode"] == "chat"
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
@patch.dict(os.environ, {"CONTENT_SAFETY_KEY": "safety-key-from-env"})
def test_api_key_environment_variable_for_non_main_models(mock_init_llm_model):
"""Test that API keys from environment variables work for non-main models too.
@@ -915,7 +915,7 @@ def test_api_key_environment_variable_for_non_main_models(mock_init_llm_model):
assert safety_call_args.kwargs["kwargs"]["temperature"] == 0.0
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_missing_api_key_environment_variable_graceful_handling(mock_init_llm_model):
"""Test that missing environment variables are handled gracefully during LLM initialization.
@@ -992,7 +992,7 @@ def __init__(self):
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_stream_usage_enabled_for_streaming_supported_providers(
mock_init_llm_model,
):
@@ -1020,7 +1020,7 @@ async def test_stream_usage_enabled_for_streaming_supported_providers(
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_stream_usage_not_set_without_streaming(mock_init_llm_model):
"""Test that stream_usage is not set when streaming is disabled."""
config = RailsConfig.from_content(
@@ -1046,7 +1046,7 @@ async def test_stream_usage_not_set_without_streaming(mock_init_llm_model):
@pytest.mark.asyncio
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
async def test_stream_usage_enabled_for_all_providers_when_streaming(
mock_init_llm_model,
):
@@ -1190,7 +1190,7 @@ def test_explain_calls_ensure_explain_info():
assert rails.explain_info == ExplainInfo()
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_initialization_disabled_by_default(mock_init_llm_model):
mock_llm = FakeLLM(responses=["response"])
mock_init_llm_model.return_value = mock_llm
@@ -1216,7 +1216,7 @@ def test_cache_initialization_disabled_by_default(mock_init_llm_model):
assert model_caches is None or len(model_caches) == 0
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_initialization_with_enabled_cache(mock_init_llm_model):
from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig
@@ -1251,7 +1251,7 @@ def test_cache_initialization_with_enabled_cache(mock_init_llm_model):
assert model_caches["content_safety"].maxsize == 1000
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model):
from nemoguardrails.rails.llm.config import ModelCacheConfig
@@ -1282,7 +1282,7 @@ def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model):
assert "embeddings" not in model_caches
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model):
from nemoguardrails.rails.llm.config import ModelCacheConfig
@@ -1304,7 +1304,7 @@ def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model
LLMRails(config=config, verbose=False)
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_initialization_with_stats_enabled(mock_init_llm_model):
from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig
@@ -1336,7 +1336,7 @@ def test_cache_initialization_with_stats_enabled(mock_init_llm_model):
assert cache.supports_stats_logging() is True
-@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+@patch("nemoguardrails.rails.llm.model_factory.init_llm_model")
def test_cache_initialization_with_multiple_models(mock_init_llm_model):
from nemoguardrails.rails.llm.config import ModelCacheConfig
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 8fb1ac22a..7bdfc5472 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -16,6 +16,7 @@
import asyncio
import json
import math
+from typing import Optional
import pytest
@@ -283,15 +284,18 @@ def output_rails_streaming_config():
define flow
user express greeting
bot tell joke
+
+ define flow self check output
+ execute self_check_output
""",
)
@action(is_system_action=True, output_mapping=lambda result: not result)
-def self_check_output(**params):
+def self_check_output(context: Optional[dict] = None):
"""A dummy self check action that checks if the bot message contains the BLOCK keyword."""
- if params.get("context", {}).get("bot_message"):
- bot_message_chunk = params.get("context", {}).get("bot_message")
+ if context and context.get("bot_message"):
+ bot_message_chunk = context.get("bot_message")
if "BLOCK" in bot_message_chunk:
return False
diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py
index 8583a2fb9..a6f582692 100644
--- a/tests/test_streaming_output_rails.py
+++ b/tests/test_streaming_output_rails.py
@@ -18,7 +18,7 @@
import asyncio
import json
from json.decoder import JSONDecodeError
-from typing import AsyncIterator
+from typing import AsyncIterator, Optional
import pytest
@@ -57,6 +57,9 @@ def output_rails_streaming_config():
define flow
user express greeting
bot tell joke
+
+ define flow self check output
+ execute self_check_output
""",
)
@@ -83,6 +86,9 @@ def output_rails_streaming_config_default():
define flow
user express greeting
bot tell joke
+
+ define flow self check output
+ execute self_check_output
""",
)
@@ -100,11 +106,11 @@ async def test_stream_async_streaming_enabled(output_rails_streaming_config):
@action(is_system_action=True, output_mapping=lambda result: not result)
-def self_check_output(**params):
+def self_check_output(context: Optional[dict] = None):
"""A dummy self check action that checks if the bot message contains the BLOCK keyword."""
- if params.get("context", {}).get("bot_message"):
- bot_message_chunk = params.get("context", {}).get("bot_message")
+ if context and context.get("bot_message"):
+ bot_message_chunk = context.get("bot_message")
print(f"bot_message_chunk: {bot_message_chunk}")
if "BLOCK" in bot_message_chunk:
return False
@@ -322,10 +328,8 @@ async def test_external_generator_with_output_rails_blocked():
rails = LLMRails(config)
@action(name="self_check_output")
- async def self_check_output(**kwargs):
- bot_message = kwargs.get(
- "bot_message", kwargs.get("context", {}).get("bot_message", "")
- )
+ async def self_check_output(context: Optional[dict] = None):
+ bot_message = context.get("bot_message", "") if context else ""
# block if message contains "offensive" or "idiot"
if "offensive" in bot_message.lower() or "idiot" in bot_message.lower():
return False
diff --git a/tests/test_system_message_conversion.py b/tests/test_system_message_conversion.py
index 08d1f6797..7bda08de6 100644
--- a/tests/test_system_message_conversion.py
+++ b/tests/test_system_message_conversion.py
@@ -44,7 +44,7 @@ async def test_system_message_conversion_v1():
{"role": "user", "content": "Hello!"},
]
- events = llm_rails._get_events_for_messages(messages, None)
+ events = llm_rails.event_translator.messages_to_events(messages, None)
system_messages = [event for event in events if event["type"] == "SystemMessage"]
assert len(system_messages) == 1
@@ -76,7 +76,7 @@ async def test_system_message_conversion_v2x():
{"role": "user", "content": "Hello!"},
]
- events = llm_rails._get_events_for_messages(messages, None)
+ events = llm_rails.event_translator.messages_to_events(messages, None)
system_messages = [event for event in events if event["type"] == "SystemMessage"]
assert len(system_messages) == 1
@@ -108,7 +108,7 @@ async def test_system_message_conversion_multiple():
{"role": "user", "content": "Hello!"},
]
- events = llm_rails._get_events_for_messages(messages, None)
+ events = llm_rails.event_translator.messages_to_events(messages, None)
system_messages = [event for event in events if event["type"] == "SystemMessage"]
assert len(system_messages) == 2
diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py
index e3f5713b1..0db5490b7 100644
--- a/tests/v2_x/chat.py
+++ b/tests/v2_x/chat.py
@@ -18,10 +18,10 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional
-import nemoguardrails.rails.llm.llmrails
from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.cli.chat import extract_scene_text_content, parse_events_inputs
from nemoguardrails.colang.v2_x.runtime.flows import State
+from nemoguardrails.rails.llm import runtime_orchestrator
from nemoguardrails.utils import new_event_dict, new_uuid
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -49,9 +49,7 @@ def __init__(self, rails_app: LLMRails):
asyncio.create_task(self.run())
# Ensure that the semaphore is assigned to the same loop that we just created
- nemoguardrails.rails.llm.llmrails.process_events_semaphore = asyncio.Semaphore(
- 1
- )
+ runtime_orchestrator.process_events_semaphore = asyncio.Semaphore(1)
self.output_summary: list[str] = []
self.should_terminate = False
self.enable_input = asyncio.Event()