diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index c7d390aaf..039b1888b 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -887,9 +887,22 @@ async def generate_bot_message( # of course, it does not work when passed as context in `run_output_rails_in_streaming` # streaming_handler is set when stream_async method is used + has_streaming_handler = streaming_handler is not None + has_output_streaming = ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ) + log.info( + f"generate_bot_message: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}" + ) # if streaming_handler and len(self.config.rails.output.flows) > 0: - if streaming_handler and self.config.rails.output.streaming.enabled: + if streaming_handler and has_output_streaming: + log.info("Setting skip_output_rails = True") context_updates["skip_output_rails"] = True + else: + log.info( + f"NOT setting skip_output_rails: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}" + ) if bot_intent in self.config.bot_messages: # Choose a message randomly from self.config.bot_messages[bot_message] @@ -970,7 +983,9 @@ async def generate_bot_message( new_event_dict("BotMessage", text=text) ) - return ActionResult(events=output_events) + return ActionResult( + events=output_events, context_updates=context_updates + ) else: if streaming_handler: await streaming_handler.push_chunk( @@ -987,7 +1002,9 @@ async def generate_bot_message( ) output_events.append(bot_message_event) - return ActionResult(events=output_events) + return ActionResult( + events=output_events, context_updates=context_updates + ) # If we are in passthrough mode, we just use the input for prompting if self.config.passthrough: diff --git a/nemoguardrails/colang/v1_0/runtime/flows.py b/nemoguardrails/colang/v1_0/runtime/flows.py index 9f4f67791..64e81dcc6 100644 --- a/nemoguardrails/colang/v1_0/runtime/flows.py +++ b/nemoguardrails/colang/v1_0/runtime/flows.py @@ -15,6 +15,7 @@ """A simplified modeling of the CoFlows engine.""" +import logging import uuid from dataclasses import dataclass, field from enum import Enum @@ -25,6 +26,8 @@ from nemoguardrails.colang.v1_0.runtime.sliding import slide from nemoguardrails.utils import new_event_dict, new_uuid +log = logging.getLogger(__name__) + @dataclass class FlowConfig: @@ -356,7 +359,15 @@ def compute_next_state(state: State, event: dict) -> State: if event["type"] == "ContextUpdate": # TODO: add support to also remove keys from the context. # maybe with a special context key e.g. "__remove__": ["key1", "key2"] + if "skip_output_rails" in event["data"]: + log.info( + f"ContextUpdate setting skip_output_rails={event['data']['skip_output_rails']}" + ) state.context.update(event["data"]) + if "skip_output_rails" in state.context: + log.info( + f"After update, context skip_output_rails={state.context.get('skip_output_rails')}" + ) state.context_updates = {} state.next_step = None return state @@ -415,6 +426,12 @@ def compute_next_state(state: State, event: dict) -> State: _record_next_step(new_state, flow_state, flow_config, priority_modifier=0.9) continue + # Debug logging for BotMessage event and skip_output_rails + if event["type"] == "BotMessage": + log.info( + f"BotMessage event processing for flow '{flow_config.id}', skip_output_rails in context: {flow_state.context.get('skip_output_rails', 'NOT SET')}" + ) + # If we're at a branching point, we look at all individual heads. matching_head = None diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index ffb9c3e64..3a8859dda 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -523,6 +523,9 @@ async def _run_output_rails_in_parallel_streaming( flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict} events: The events list for context """ + # Compute context from events so actions can access bot_message + context = compute_context(events) + tasks = [] async def run_single_rail(flow_id: str, action_info: dict) -> tuple: @@ -532,8 +535,11 @@ async def run_single_rail(flow_id: str, action_info: dict) -> tuple: action_name = action_info["action_name"] params = action_info["params"] + # Merge context into params so actions have access to bot_message + params_with_context = {**params, "context": context} + result_tuple = await self.action_dispatcher.execute_action( - action_name, params + action_name, params_with_context ) result, status = result_tuple @@ -731,10 +737,19 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]: return_events = [] context_updates = {} + if action_name == "generate_bot_message": + log.info( + f"DEBUG: generate_bot_message returned, isinstance(ActionResult)={isinstance(result, ActionResult)}" + ) + if isinstance(result, ActionResult): return_value = result.return_value return_events = result.events context_updates.update(result.context_updates) + if action_name == "generate_bot_message": + log.info( + f"generate_bot_message ActionResult: context_updates={context_updates}, skip_output_rails={'skip_output_rails' in context_updates}" + ) # If we have an action result key, we also record the update. if action_result_key: diff --git a/nemoguardrails/rails/llm/config_loader.py b/nemoguardrails/rails/llm/config_loader.py new file mode 100644 index 000000000..9c1d9f136 --- /dev/null +++ b/nemoguardrails/rails/llm/config_loader.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config loader for loading default flows, library content, and config.py modules.""" + +import importlib.util +import logging +import os +from typing import List, Any + +from nemoguardrails.colang import parse_colang_file +from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id +from nemoguardrails.rails.llm.config import RailsConfig + +log = logging.getLogger(__name__) + + +class ConfigLoader: + """Enriches RailsConfig with default flows, library content, and config.py modules.""" + + @staticmethod + def load_config(config: RailsConfig) -> List[Any]: + """Enrich the config with default flows and library content. + + Args: + config: The rails configuration to enrich. + + Returns: + List of config modules loaded from config.py files. + """ + # Load default flows for Colang 1.0 + if config.colang_version == "1.0": + ConfigLoader._load_default_flows(config) + ConfigLoader._load_library_content(config) + + # Mark rail flows as system flows + ConfigLoader._mark_rail_flows_as_system(config) + + # Load and execute config.py modules + config_modules = ConfigLoader._load_config_modules(config) + + # Validate config + ConfigLoader._validate_config(config) + + return config_modules + + @staticmethod + def _load_default_flows(config: RailsConfig): + """Load default LLM flows for Colang 1.0. + + Args: + config: The rails configuration. + """ + current_folder = os.path.dirname(__file__) + default_flows_file = "llm_flows.co" + default_flows_path = os.path.join(current_folder, default_flows_file) + + with open(default_flows_path, "r") as f: + default_flows_content = f.read() + default_flows = parse_colang_file( + default_flows_file, default_flows_content + )["flows"] + + # Mark all default flows as system flows + for flow_config in default_flows: + flow_config["is_system_flow"] = True + + # Add default flows to config + config.flows.extend(default_flows) + log.debug(f"Loaded {len(default_flows)} default flows") + + @staticmethod + def _load_library_content(config: RailsConfig): + """Load content from the components library. + + Args: + config: The rails configuration. + """ + library_path = os.path.join(os.path.dirname(__file__), "../../library") + loaded_files = 0 + + for root, dirs, files in os.walk(library_path): + for file in files: + full_path = os.path.join(root, file) + if file.endswith(".co"): + log.debug(f"Loading library file: {full_path}") + with open(full_path, "r", encoding="utf-8") as f: + content = parse_colang_file( + file, content=f.read(), version=config.colang_version + ) + if not content: + continue + + # Mark all library flows as system flows + for flow_config in content["flows"]: + flow_config["is_system_flow"] = True + + # Load all the flows + config.flows.extend(content["flows"]) + + # Load bot messages if not overwritten + for message_id, utterances in content.get( + "bot_messages", {} + ).items(): + if message_id not in config.bot_messages: + config.bot_messages[message_id] = utterances + + loaded_files += 1 + + log.debug(f"Loaded {loaded_files} library files") + + @staticmethod + def _mark_rail_flows_as_system(config: RailsConfig): + """Mark all flows used in rails as system flows. + + Args: + config: The rails configuration. + """ + rail_flow_ids = ( + config.rails.input.flows + + config.rails.output.flows + + config.rails.retrieval.flows + ) + + for flow_config in config.flows: + if flow_config.get("id") in rail_flow_ids: + flow_config["is_system_flow"] = True + # Mark them as subflows by default to simplify syntax + flow_config["is_subflow"] = True + + @staticmethod + def _load_config_modules(config: RailsConfig) -> List[AttributeError]: + """Load and execute config.py modules. + + Args: + config: The rails configuration. + + Returns: + List of loaded config modules. + """ + config_modules = [] + paths = list( + config.imported_paths.values() if config.imported_paths else [] + ) + [config.config_path] + + for _path in paths: + if _path: + filepath = os.path.join(_path, "config.py") + if os.path.exists(filepath): + filename = os.path.basename(filepath) + spec = importlib.util.spec_from_file_location(filename, filepath) + if spec and spec.loader: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + config_modules.append(config_module) + log.debug(f"Loaded config module from: {filepath}") + + return config_modules + + @staticmethod + def _validate_config(config: RailsConfig): + """Run validation checks on the config. + + Args: + config: The rails configuration to validate. + + Raises: + ValueError: If validation fails. + """ + if config.colang_version == "1.0": + existing_flows_names = set([flow.get("id") for flow in config.flows]) + else: + existing_flows_names = set([flow.get("name") for flow in config.flows]) + + # Validate input rail flows + for flow_name in config.rails.input.flows: + flow_name = _normalize_flow_id(flow_name) + if flow_name not in existing_flows_names: + raise ValueError( + f"The provided input rail flow `{flow_name}` does not exist" + ) + + # Validate output rail flows + for flow_name in config.rails.output.flows: + flow_name = _normalize_flow_id(flow_name) + if flow_name not in existing_flows_names: + raise ValueError( + f"The provided output rail flow `{flow_name}` does not exist" + ) + + # Validate retrieval rail flows + for flow_name in config.rails.retrieval.flows: + if flow_name not in existing_flows_names: + raise ValueError( + f"The provided retrieval rail flow `{flow_name}` does not exist" + ) + + # Check for conflicting modes + if config.passthrough and config.rails.dialog.single_call.enabled: + raise ValueError( + "The passthrough mode and the single call dialog rails mode can't be used at the same time. " + "The single call mode needs to use an altered prompt when prompting the LLM." + ) diff --git a/nemoguardrails/rails/llm/event_translator.py b/nemoguardrails/rails/llm/event_translator.py new file mode 100644 index 000000000..d67f8fbe0 --- /dev/null +++ b/nemoguardrails/rails/llm/event_translator.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Event translator for converting between messages and events.""" + +import logging +from typing import Any, Dict, List, Optional + +from nemoguardrails.colang.v2_x.runtime.flows import Action, State +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.utils import get_history_cache_key +from nemoguardrails.utils import new_event_dict, new_uuid + +log = logging.getLogger(__name__) + + +class EventTranslator: + """Translates between messages and Colang events.""" + + def __init__(self, config: RailsConfig): + """Initialize the EventTranslator. + + Args: + config: The rails configuration. + """ + self.config = config + self.events_history_cache: Dict[str, List[dict]] = {} + + def messages_to_events( + self, messages: List[dict[str, Any]], state: Optional[Any] = None + ) -> List[dict]: + """Convert messages to events. + + Tries to find a prefix of messages for which we have already a list of events + in the cache. For the rest, they are converted as is. + + Args: + messages: The list of messages. + state: Optional state object (used for Colang 2.x). + + Returns: + A list of events. + """ + events = [] + + if self.config.colang_version == "1.0": + events = self._messages_to_events_v1(messages) + else: + events = self._messages_to_events_v2(messages, state) + + return events + + def _messages_to_events_v1(self, messages: List[dict]) -> List[dict]: + """Convert messages to events for Colang 1.0. + + Args: + messages: The list of messages. + + Returns: + A list of events. + """ + events = [] + + # Try to find the longest prefix of messages for which we have a cache + p = len(messages) - 1 + while p > 0: + cache_key = get_history_cache_key(messages[0:p]) + if cache_key in self.events_history_cache: + events = self.events_history_cache[cache_key].copy() + break + p -= 1 + + # For the rest of the messages, transform them directly into events + for idx in range(p, len(messages)): + msg = messages[idx] + if msg["role"] == "user": + events.append( + { + "type": "UtteranceUserActionFinished", + "final_transcript": msg["content"], + } + ) + + # If it's not the last message, also add the `UserMessage` event + if idx != len(messages) - 1: + events.append( + { + "type": "UserMessage", + "text": msg["content"], + } + ) + + elif msg["role"] == "assistant": + if msg.get("tool_calls"): + events.append( + {"type": "BotToolCalls", "tool_calls": msg["tool_calls"]} + ) + else: + action_uid = new_uuid() + start_event = new_event_dict( + "StartUtteranceBotAction", + script=msg["content"], + action_uid=action_uid, + ) + finished_event = new_event_dict( + "UtteranceBotActionFinished", + final_script=msg["content"], + is_success=True, + action_uid=action_uid, + ) + events.extend([start_event, finished_event]) + + elif msg["role"] == "context": + events.append({"type": "ContextUpdate", "data": msg["content"]}) + + elif msg["role"] == "event": + events.append(msg["event"]) + + elif msg["role"] == "system": + events.append({"type": "SystemMessage", "content": msg["content"]}) + + elif msg["role"] == "tool": + # For the last tool message, create grouped tool event or synthetic UserMessage + if idx == len(messages) - 1: + # Find the original user message for response generation + user_message = None + for prev_msg in reversed(messages[:idx]): + if prev_msg["role"] == "user": + user_message = prev_msg["content"] + break + + if user_message: + # If tool input rails are configured, group all tool messages + if self.config.rails.tool_input.flows: + tool_messages = [] + for tool_idx in range(len(messages)): + if messages[tool_idx]["role"] == "tool": + tool_messages.append( + { + "content": messages[tool_idx]["content"], + "name": messages[tool_idx].get( + "name", "unknown" + ), + "tool_call_id": messages[tool_idx].get( + "tool_call_id", "" + ), + } + ) + + events.append( + { + "type": "UserToolMessages", + "tool_messages": tool_messages, + } + ) + else: + events.append({"type": "UserMessage", "text": user_message}) + + return events + + def _messages_to_events_v2( + self, messages: List[dict[str, Any]], state: Optional[Any] + ) -> List[dict]: + """Convert messages to events for Colang 2.x. + + Args: + messages: The list of messages. + state: The state object. + + Returns: + A list of events. + """ + events = [] + + for idx in range(len(messages)): + msg = messages[idx] + if msg["role"] == "user": + events.append( + { + "type": "UtteranceUserActionFinished", + "final_transcript": msg["content"], + } + ) + + elif msg["role"] == "assistant": + raise ValueError( + "Providing `assistant` messages as input is not supported for Colang 2.0 configurations." + ) + + elif msg["role"] == "context": + events.append({"type": "ContextUpdate", "data": msg["content"]}) + + elif msg["role"] == "event": + events.append(msg["event"]) + + elif msg["role"] == "system": + events.append({"type": "SystemMessage", "content": msg["content"]}) + + elif msg["role"] == "tool": + if state is None: + raise ValueError( + "State object is required for tool messages in Colang 2.0" + ) + action_uid = msg["tool_call_id"] + return_value = msg["content"] + action: Action = state.actions[action_uid] + events.append( + new_event_dict( + f"{action.name}Finished", + action_uid=action_uid, + action_name=action.name, + status="success", + is_success=True, + return_value=return_value, + events=[], + ) + ) + + return events + + def cache_events(self, messages: List[dict], events: List[dict]): + """Cache events for a sequence of messages. + + Args: + messages: The list of messages. + events: The corresponding events. + """ + if self.config.colang_version == "1.0": + cache_key = get_history_cache_key(messages) + self.events_history_cache[cache_key] = events + + def clear_cache(self): + """Clear the events history cache.""" + self.events_history_cache.clear() diff --git a/nemoguardrails/rails/llm/kb_builder.py b/nemoguardrails/rails/llm/kb_builder.py new file mode 100644 index 000000000..201f990e3 --- /dev/null +++ b/nemoguardrails/rails/llm/kb_builder.py @@ -0,0 +1,59 @@ +"""Knowledge Base builder component.""" + +import logging +from typing import Callable, Optional + +from nemoguardrails.embeddings.index import EmbeddingsIndex +from nemoguardrails.kb.kb import KnowledgeBase +from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig + +log = logging.getLogger(__name__) + + +class KnowledgeBaseBuilder: + """Builder for initializing and managing knowledge bases.""" + + def __init__( + self, + config: RailsConfig, + get_embeddings_search_provider_instance: Callable[ + [Optional[EmbeddingSearchProvider]], EmbeddingsIndex + ], + ): + """Initialize the KnowledgeBaseBuilder. + + Args: + config: The rails configuration. + get_embeddings_search_provider_instance: Function to get embeddings search provider. + """ + self.config = config + self.get_embeddings_search_provider_instance = ( + get_embeddings_search_provider_instance + ) + self.kb: Optional[KnowledgeBase] = None + + async def build(self) -> Optional[KnowledgeBase]: + """Build the knowledge base from configuration. + + Returns: + The initialized KnowledgeBase or None if no docs configured. + """ + if not self.config.docs: + self.kb = None + return None + + documents = [doc.content for doc in self.config.docs] + self.kb = KnowledgeBase( + documents=documents, + config=self.config.knowledge_base, + get_embedding_search_provider_instance=self.get_embeddings_search_provider_instance, + ) + self.kb.init() + await self.kb.build() + + log.info("Knowledge base initialized with %d documents", len(documents)) + return self.kb + + def get_kb(self) -> Optional[KnowledgeBase]: + """Get the knowledge base instance.""" + return self.kb diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 187300aa2..d8b06a091 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -13,16 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LLM Rails entry point.""" +"""LLM Rails entry point - Refactored to use modular components.""" import asyncio -import importlib.util import json import logging -import os -import re import threading -import time from functools import partial from typing import ( Any, @@ -34,7 +30,6 @@ Tuple, Type, Union, - cast, ) from langchain_core.language_models import BaseChatModel @@ -42,42 +37,15 @@ from typing_extensions import Self from nemoguardrails.actions.llm.generation import LLMGenerationActions -from nemoguardrails.actions.llm.utils import ( - extract_bot_thinking_from_events, - extract_tool_calls_from_events, - get_and_clear_response_metadata_contextvar, - get_colang_history, -) from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx -from nemoguardrails.colang import parse_colang_file -from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context -from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 -from nemoguardrails.colang.v2_x.runtime.flows import Action, State -from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x -from nemoguardrails.colang.v2_x.runtime.serialization import ( - json_to_state, - state_to_json, -) -from nemoguardrails.context import ( - explain_info_var, - generation_options_var, - llm_stats_var, - raw_llm_request, - streaming_handler_var, -) +from nemoguardrails.colang.v1_0.runtime.runtime import Runtime +from nemoguardrails.colang.v2_x.runtime.flows import State +from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state from nemoguardrails.embeddings.index import EmbeddingsIndex from nemoguardrails.embeddings.providers import register_embedding_provider from nemoguardrails.embeddings.providers.base import EmbeddingModel -from nemoguardrails.kb.kb import KnowledgeBase -from nemoguardrails.llm.cache import CacheInterface, LFUCache -from nemoguardrails.llm.models.initializer import ( - ModelInitializationError, - init_llm_model, -) from nemoguardrails.logging.explain import ExplainInfo -from nemoguardrails.logging.processing_log import compute_generation_log -from nemoguardrails.logging.stats import LLMStats from nemoguardrails.logging.verbose import set_verbose from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop from nemoguardrails.rails.llm.buffer import get_buffer_strategy @@ -86,30 +54,35 @@ OutputRailsStreamingConfig, RailsConfig, ) +from nemoguardrails.rails.llm.config_loader import ConfigLoader +from nemoguardrails.rails.llm.event_translator import EventTranslator +from nemoguardrails.rails.llm.kb_builder import KnowledgeBaseBuilder +from nemoguardrails.rails.llm.model_factory import ModelFactory from nemoguardrails.rails.llm.options import ( - GenerationLog, GenerationOptions, + GenerationRailsOptions, GenerationResponse, ) -from nemoguardrails.rails.llm.utils import ( - get_action_details_from_flow_id, - get_history_cache_key, -) +from nemoguardrails.rails.llm.response_assembler import ResponseAssembler +from nemoguardrails.rails.llm.runtime_orchestrator import RuntimeOrchestrator +from nemoguardrails.rails.llm.utils import get_action_details_from_flow_id from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler -from nemoguardrails.utils import ( - extract_error_json, - get_or_create_event_loop, - new_event_dict, - new_uuid, -) +from nemoguardrails.utils import get_or_create_event_loop log = logging.getLogger(__name__) -process_events_semaphore = asyncio.Semaphore(1) - class LLMRails: - """Rails based on a given configuration.""" + """Rails based on a given configuration. + + Refactored to use modular components: + - ConfigEnricher: Loads and enriches configuration + - ModelFactory: Manages LLM instantiation + - KnowledgeBaseBuilder: Builds knowledge base + - EventTranslator: Converts messages to/from events + - RuntimeOrchestrator: Manages Colang runtime + - ResponseAssembler: Assembles responses + """ config: RailsConfig llm: Optional[Union[BaseLLM, BaseChatModel]] @@ -127,11 +100,11 @@ def __init__( config: A rails configuration. llm: An optional LLM engine to use. If provided, this will be used as the main LLM and will take precedence over any main LLM specified in the config. - verbose: Whether the logging should be verbose or not. + verbose: Whether the logging should be verbose. """ self.config = config - self.llm = llm self.verbose = verbose + self.explain_info: Optional[ExplainInfo] = None if self.verbose: set_verbose(True, llm_calls=True) @@ -140,116 +113,24 @@ def __init__( # an index of them. self.embedding_search_providers = {} - # The default embeddings model is using FastEmbed + # The default embeddings model is usining FastEmbed self.default_embedding_model = "all-MiniLM-L6-v2" self.default_embedding_engine = "FastEmbed" self.default_embedding_params = {} - # We keep a cache of the events history associated with a sequence of user messages. - # TODO: when we update the interface to allow to return a "state object", this - # should be removed - self.events_history_cache = {} - - # Weather the main LLM supports streaming - self.main_llm_supports_streaming = False - - # We also load the default flows from the `default_flows.yml` file in the current folder. - # But only for version 1.0. - # TODO: decide on the default flows for 2.x. - if config.colang_version == "1.0": - # We also load the default flows from the `llm_flows.co` file in the current folder. - current_folder = os.path.dirname(__file__) - default_flows_file = "llm_flows.co" - default_flows_path = os.path.join(current_folder, default_flows_file) - with open(default_flows_path, "r") as f: - default_flows_content = f.read() - default_flows = parse_colang_file( - default_flows_file, default_flows_content - )["flows"] - - # We mark all the default flows as system flows. - for flow_config in default_flows: - flow_config["is_system_flow"] = True - - # We add the default flows to the config. - self.config.flows.extend(default_flows) - - # We also need to load the content from the components library. - library_path = os.path.join(os.path.dirname(__file__), "../../library") - for root, dirs, files in os.walk(library_path): - for file in files: - # Extract the full path for the file - full_path = os.path.join(root, file) - if file.endswith(".co"): - log.debug(f"Loading file: {full_path}") - with open(full_path, "r", encoding="utf-8") as f: - content = parse_colang_file( - file, content=f.read(), version=config.colang_version - ) - if not content: - continue - - # We mark all the flows coming from the guardrails library as system flows. - for flow_config in content["flows"]: - flow_config["is_system_flow"] = True - - # We load all the flows - self.config.flows.extend(content["flows"]) - - # And all the messages as well, if they have not been overwritten - for message_id, utterances in content.get( - "bot_messages", {} - ).items(): - if message_id not in self.config.bot_messages: - self.config.bot_messages[message_id] = utterances - - # Last but not least, we mark all the flows that are used in any of the rails - # as system flows (so they don't end up in the prompt). - - rail_flow_ids = ( - config.rails.input.flows - + config.rails.output.flows - + config.rails.retrieval.flows - ) + # Load config with default flows and library content + config_modules = ConfigLoader.load_config(config) - for flow_config in self.config.flows: - if flow_config.get("id") in rail_flow_ids: - flow_config["is_system_flow"] = True - - # We also mark them as subflows by default, to simplify the syntax - flow_config["is_subflow"] = True - - # We check if the configuration or any of the imported ones have config.py modules. - config_modules = [] - for _path in list( - self.config.imported_paths.values() if self.config.imported_paths else [] - ) + [self.config.config_path]: - if _path: - filepath = os.path.join(_path, "config.py") - if os.path.exists(filepath): - filename = os.path.basename(filepath) - spec = importlib.util.spec_from_file_location(filename, filepath) - if spec and spec.loader: - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) - config_modules.append(config_module) - - # First, we initialize the runtime. - if config.colang_version == "1.0": - self.runtime = RuntimeV1_0(config=config, verbose=verbose) - elif config.colang_version == "2.x": - self.runtime = RuntimeV2_x(config=config, verbose=verbose) - else: - raise ValueError(f"Unsupported colang version: {config.colang_version}.") + # Initialize RuntimeOrchestrator + self.runtime_orchestrator = RuntimeOrchestrator(config=config, verbose=verbose) + self.runtime = self.runtime_orchestrator.runtime - # If we have a config_modules with an `init` function, we call it. - # We need to call this here because the `init` might register additional - # LLM providers. + # Execute config.py init functions for config_module in config_modules: if hasattr(config_module, "init"): config_module.init(self) - # If we have a customized embedding model, we'll use it. + # Update embedding model if specified in config for model in self.config.models: if model.type == "embeddings": self.default_embedding_model = model.model @@ -257,8 +138,7 @@ def __init__( self.default_embedding_params = model.parameters or {} break - # InteractionLogAdapters used for tracing - # We ensure that it is used after config.py is loaded + # Initialize tracing adapters if config.tracing: from nemoguardrails.tracing import create_log_adapters @@ -266,13 +146,24 @@ def __init__( else: self._log_adapters = None - # We run some additional checks on the config - self._validate_config() + # Initialize ModelFactory + self.model_factory = ModelFactory(config=config, injected_llm=llm) + self.action_param_registry: Dict[str, Any] = {} - # Next, we initialize the LLM engines (main engine and action engines if specified). - self._init_llms() + # Initialize models and register action parameters + self.model_factory.initialize_models(self.action_param_registry) + for param_name, param_value in self.action_param_registry.items(): + self.runtime.register_action_param(param_name, param_value) - # Next, we initialize the LLM Generate actions and register them. + # Store main LLM reference for backwards compatibility + self.llm = self.model_factory.get_main_llm() + self.main_llm_supports_streaming = self.model_factory.supports_streaming() + + # Expose specialized LLMs as attributes for convenient access + for llm_type, llm_instance in self.model_factory.specialized_llms.items(): + setattr(self, f"{llm_type}_llm", llm_instance) + + # Initialize LLM Generation Actions llm_generation_actions_class = ( LLMGenerationActions if config.colang_version == "1.0" @@ -285,308 +176,39 @@ def __init__( get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance, verbose=verbose, ) - - # If there's already an action registered, we don't override. self.runtime.register_actions(self.llm_generation_actions, override=False) - # Next, we initialize the Knowledge Base - # There are still some edge cases not covered by nest_asyncio. - # Using a separate thread always for now. + # Initialize KnowledgeBaseBuilder + self.kb_builder = KnowledgeBaseBuilder( + config=self.config, + get_embeddings_search_provider_instance=self._get_embeddings_search_provider_instance, + ) + + # Build KB in separate thread loop = get_or_create_event_loop() if True or check_sync_call_from_async_loop(): - t = threading.Thread(target=asyncio.run, args=(self._init_kb(),)) + t = threading.Thread(target=asyncio.run, args=(self.kb_builder.build(),)) t.start() t.join() else: - loop.run_until_complete(self._init_kb()) + loop.run_until_complete(self.kb_builder.build()) - # We also register the kb as a parameter that can be passed to actions. + # Register KB as action parameter + self.kb = self.kb_builder.get_kb() self.runtime.register_action_param("kb", self.kb) - # Reference to the general ExplainInfo object. - self.explain_info = None - - def update_llm(self, llm): - """Replace the main LLM with the provided one. - - Arguments: - llm: The new LLM that should be used. - """ - self.llm = llm - self.llm_generation_actions.llm = llm - self.runtime.register_action_param("llm", llm) - - def _validate_config(self): - """Runs additional validation checks on the config.""" - - if self.config.colang_version == "1.0": - existing_flows_names = set([flow.get("id") for flow in self.config.flows]) - else: - existing_flows_names = set([flow.get("name") for flow in self.config.flows]) - - for flow_name in self.config.rails.input.flows: - # content safety check input/output flows are special as they have parameters - flow_name = _normalize_flow_id(flow_name) - if flow_name not in existing_flows_names: - raise ValueError( - f"The provided input rail flow `{flow_name}` does not exist" - ) - - for flow_name in self.config.rails.output.flows: - flow_name = _normalize_flow_id(flow_name) - if flow_name not in existing_flows_names: - raise ValueError( - f"The provided output rail flow `{flow_name}` does not exist" - ) - - for flow_name in self.config.rails.retrieval.flows: - if flow_name not in existing_flows_names: - raise ValueError( - f"The provided retrieval rail flow `{flow_name}` does not exist" - ) - - # If both passthrough mode and single call mode are specified, we raise an exception. - if self.config.passthrough and self.config.rails.dialog.single_call.enabled: - raise ValueError( - "The passthrough mode and the single call dialog rails mode can't be used at the same time. " - "The single call mode needs to use an altered prompt when prompting the LLM. " - ) - - async def _init_kb(self): - """Initializes the knowledge base.""" - self.kb = None - - if not self.config.docs: - return - - documents = [doc.content for doc in self.config.docs] - self.kb = KnowledgeBase( - documents=documents, - config=self.config.knowledge_base, - get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance, - ) - self.kb.init() - await self.kb.build() - - def _prepare_model_kwargs(self, model_config): - """ - Prepare kwargs for model initialization, including API key from environment variable. - - Args: - model_config: The model configuration object - - Returns: - dict: The prepared kwargs for model initialization - """ - kwargs = model_config.parameters or {} - - # If the optional API Key Environment Variable is set, add it to kwargs - if model_config.api_key_env_var: - api_key = os.environ.get(model_config.api_key_env_var) - if api_key: - kwargs["api_key"] = api_key - - # enable streaming token usage when streaming is enabled - # providers that don't support this parameter will simply ignore it - if self.config.streaming: - kwargs["stream_usage"] = True - - return kwargs - - def _configure_main_llm_streaming( - self, - llm: Union[BaseLLM, BaseChatModel], - model_name: Optional[str] = None, - provider_name: Optional[str] = None, - ): - """Configure streaming support for the main LLM. - - Args: - llm (Union[BaseLLM, BaseChatModel]): The main LLM model instance. - model_name (Optional[str], optional): Optional model name for logging. - provider_name (Optional[str], optional): Optional provider name for logging. - - """ - if not self.config.streaming: - return - - if hasattr(llm, "streaming"): - setattr(llm, "streaming", True) - self.main_llm_supports_streaming = True - else: - self.main_llm_supports_streaming = False - if model_name and provider_name: - log.warning( - "Model %s from provider %s does not support streaming.", - model_name, - provider_name, - ) - else: - log.warning("Provided main LLM does not support streaming.") - - def _init_llms(self): - """ - Initializes the right LLM engines based on the configuration. - There can be multiple LLM engines and types that can be specified in the config. - The main LLM engine is the one that will be used for all the core guardrails generations. - Other LLM engines can be specified for use in specific actions. - - The reason we provide an option for decoupling the main LLM engine from the action LLM - is to allow for flexibility in using specialized LLM engines for specific actions. - - Raises: - ModelInitializationError: If any model initialization fails - """ - # If the user supplied an already-constructed LLM via the constructor we - # treat it as the *main* model, but **still** iterate through the - # configuration to load any additional models (e.g. `content_safety`). - - if self.llm: - # If an LLM was provided via constructor, use it as the main LLM - # Log a warning if a main LLM is also specified in the config - if any(model.type == "main" for model in self.config.models): - log.warning( - "Both an LLM was provided via constructor and a main LLM is specified in the config. " - "The LLM provided via constructor will be used and the main LLM from config will be ignored." - ) - self.runtime.register_action_param("llm", self.llm) - - self._configure_main_llm_streaming(self.llm) - else: - # Otherwise, initialize the main LLM from the config - main_model = next( - (model for model in self.config.models if model.type == "main"), None - ) - - if main_model and main_model.model: - kwargs = self._prepare_model_kwargs(main_model) - self.llm = init_llm_model( - model_name=main_model.model, - provider_name=main_model.engine, - mode="chat", - kwargs=kwargs, - ) - self.runtime.register_action_param("llm", self.llm) - - self._configure_main_llm_streaming( - self.llm, - model_name=main_model.model, - provider_name=main_model.engine, - ) - else: - log.warning( - "No main LLM specified in the config and no LLM provided via constructor." - ) - - llms = dict() - - for llm_config in self.config.models: - if llm_config.type in ["embeddings", "jailbreak_detection"]: - continue + # Initialize EventTranslator + self.event_translator = EventTranslator(config=self.config) + # For backwards compatibility + self.events_history_cache = self.event_translator.events_history_cache - # If a constructor LLM is provided, skip initializing any 'main' model from config - if self.llm and llm_config.type == "main": - continue - - try: - model_name = llm_config.model - if not model_name: - raise ValueError("LLM Config model field not set") - - provider_name = llm_config.engine - kwargs = self._prepare_model_kwargs(llm_config) - mode = llm_config.mode - - llm_model = init_llm_model( - model_name=model_name, - provider_name=provider_name, - mode=mode, - kwargs=kwargs, - ) - - # Configure the model based on its type - if llm_config.type == "main": - # If a main LLM was already injected, skip creating another - # one. Otherwise, create and register it. - if not self.llm: - self.llm = llm_model - self.runtime.register_action_param("llm", self.llm) - else: - model_name = f"{llm_config.type}_llm" - if not hasattr(self, model_name): - setattr(self, model_name, llm_model) - self.runtime.register_action_param( - model_name, getattr(self, model_name) - ) - # this is used for content safety and topic control - llms[llm_config.type] = getattr(self, model_name) - - except ModelInitializationError as e: - log.error("Failed to initialize model: %s", str(e)) - raise - except Exception as e: - log.error("Unexpected error initializing model: %s", str(e)) - raise - - self.runtime.register_action_param("llms", llms) - - self._initialize_model_caches() - - def _create_model_cache(self, model) -> LFUCache: - """ - Create cache instance for a model based on its configuration. - - Args: - model: The model configuration object - - Returns: - LFUCache: The cache instance - """ - - if model.cache.maxsize <= 0: - raise ValueError( - f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. " - "Capacity must be greater than 0. Skipping cache creation." - ) - - stats_logging_interval = None - if model.cache.stats.enabled and model.cache.stats.log_interval is not None: - stats_logging_interval = model.cache.stats.log_interval - - cache = LFUCache( - maxsize=model.cache.maxsize, - track_stats=model.cache.stats.enabled, - stats_logging_interval=stats_logging_interval, - ) - - log.info( - f"Created cache for model '{model.type}' with maxsize {model.cache.maxsize}" - ) - - return cache - - def _initialize_model_caches(self) -> None: - """Initialize caches for configured models.""" - model_caches: Optional[Dict[str, CacheInterface]] = dict() - for model in self.config.models: - if model.type in ["main", "embeddings"]: - continue - - if model.cache and model.cache.enabled: - cache = self._create_model_cache(model) - model_caches[model.type] = cache - - log.info( - f"Initialized model '{model.type}' with cache %s", - "enabled" if cache else "disabled", - ) - - if model_caches: - self.runtime.register_action_param("model_caches", model_caches) + # Initialize ResponseAssembler + self.response_assembler = ResponseAssembler(config=self.config) def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: + """Get an embeddings search provider instance.""" if esp_config is None: esp_config = EmbeddingSearchProvider() @@ -604,7 +226,6 @@ def _get_embeddings_search_provider_instance( "embedding_parameters", self.default_embedding_params ), cache_config=esp_config.cache, - # We make sure we also pass additional relevant params. **{ k: v for k, v in esp_config.parameters.items() @@ -625,181 +246,11 @@ def _get_embeddings_search_provider_instance( kwargs = esp_config.parameters return self.embedding_search_providers[esp_config.name](**kwargs) - def _get_events_for_messages(self, messages: List[dict], state: Any): - """Return the list of events corresponding to the provided messages. - - Tries to find a prefix of messages for which we have already a list of events - in the cache. For the rest, they are converted as is. - - The reason this cache exists is that we want to benefit from events generated in - previous turns, which can't be computed again because it would be expensive (e.g., - involving multiple LLM calls). - - When an explicit state object will be added, this mechanism can be removed. - - Args: - messages: The list of messages. - - Returns: - A list of events. - """ - events = [] - - if self.config.colang_version == "1.0": - # We try to find the longest prefix of messages for which we have a cache - # of events. - p = len(messages) - 1 - while p > 0: - cache_key = get_history_cache_key(messages[0:p]) - if cache_key in self.events_history_cache: - events = self.events_history_cache[cache_key].copy() - break - - p -= 1 - - # For the rest of the messages, we transform them directly into events. - # TODO: Move this to separate function once more types of messages are supported. - for idx in range(p, len(messages)): - msg = messages[idx] - if msg["role"] == "user": - events.append( - { - "type": "UtteranceUserActionFinished", - "final_transcript": msg["content"], - } - ) - - # If it's not the last message, we also need to add the `UserMessage` event - if idx != len(messages) - 1: - events.append( - { - "type": "UserMessage", - "text": msg["content"], - } - ) - - elif msg["role"] == "assistant": - if msg.get("tool_calls"): - events.append( - {"type": "BotToolCalls", "tool_calls": msg["tool_calls"]} - ) - else: - action_uid = new_uuid() - start_event = new_event_dict( - "StartUtteranceBotAction", - script=msg["content"], - action_uid=action_uid, - ) - finished_event = new_event_dict( - "UtteranceBotActionFinished", - final_script=msg["content"], - is_success=True, - action_uid=action_uid, - ) - events.extend([start_event, finished_event]) - elif msg["role"] == "context": - events.append({"type": "ContextUpdate", "data": msg["content"]}) - elif msg["role"] == "event": - events.append(msg["event"]) - elif msg["role"] == "system": - # Handle system messages - convert them to SystemMessage events - events.append({"type": "SystemMessage", "content": msg["content"]}) - elif msg["role"] == "tool": - # For the last tool message, create grouped tool event and synthetic UserMessage - if idx == len(messages) - 1: - # Find the original user message for response generation - user_message = None - for prev_msg in reversed(messages[:idx]): - if prev_msg["role"] == "user": - user_message = prev_msg["content"] - break - - if user_message: - # If tool input rails are configured, group all tool messages - if self.config.rails.tool_input.flows: - # Collect all tool messages for grouped processing - tool_messages = [] - for tool_idx in range(len(messages)): - if messages[tool_idx]["role"] == "tool": - tool_messages.append( - { - "content": messages[tool_idx][ - "content" - ], - "name": messages[tool_idx].get( - "name", "unknown" - ), - "tool_call_id": messages[tool_idx].get( - "tool_call_id", "" - ), - } - ) - - events.append( - { - "type": "UserToolMessages", - "tool_messages": tool_messages, - } - ) - - else: - events.append( - {"type": "UserMessage", "text": user_message} - ) - - else: - for idx in range(len(messages)): - msg = messages[idx] - if msg["role"] == "user": - events.append( - { - "type": "UtteranceUserActionFinished", - "final_transcript": msg["content"], - } - ) - - elif msg["role"] == "assistant": - raise ValueError( - "Providing `assistant` messages as input is not supported for Colang 2.0 configurations." - ) - elif msg["role"] == "context": - events.append({"type": "ContextUpdate", "data": msg["content"]}) - elif msg["role"] == "event": - events.append(msg["event"]) - elif msg["role"] == "system": - # Handle system messages - convert them to SystemMessage events - events.append({"type": "SystemMessage", "content": msg["content"]}) - elif msg["role"] == "tool": - action_uid = msg["tool_call_id"] - return_value = msg["content"] - action: Action = state.actions[action_uid] - events.append( - new_event_dict( - f"{action.name}Finished", - action_uid=action_uid, - action_name=action.name, - status="success", - is_success=True, - return_value=return_value, - events=[], - ) - ) - - return events - - @staticmethod - def _ensure_explain_info() -> ExplainInfo: - """Ensure that the ExplainInfo variable is present in the current context - - Returns: - A ExplainInfo class containing the llm calls' statistics - """ - explain_info = explain_info_var.get() - if explain_info is None: - explain_info = ExplainInfo() - explain_info_var.set(explain_info) - - return explain_info + def update_llm(self, llm: Union[BaseLLM, BaseChatModel]): + """Replace the main LLM with the provided one.""" + self.llm = llm + self.model_factory.update_main_llm(llm, self.action_param_registry) + self.llm_generation_actions.llm = llm async def generate_async( self, @@ -809,85 +260,50 @@ async def generate_async( state: Optional[Union[dict, State]] = None, streaming_handler: Optional[StreamingHandler] = None, ) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]: - """Generate a completion or a next message. - - The format for messages is the following: + """Generate a completion or next message. - ```python - [ - {"role": "context", "content": {"user_name": "John"}}, - {"role": "user", "content": "Hello! How are you?"}, - {"role": "assistant", "content": "I am fine, thank you!"}, - {"role": "event", "event": {"type": "UserSilent"}}, - ... - ] - ``` - - Args: - prompt: The prompt to be used for completion. - messages: The history of messages to be used to generate the next message. - options: Options specific for the generation. - state: The state object that should be used as the starting point. - streaming_handler: If specified, and the config supports streaming, the - provided handler will be used for streaming. - - Returns: - The completion (when a prompt is provided) or the next message. - - System messages are not yet supported.""" - # convert options to gen_options of type GenerationOptions - gen_options: Optional[GenerationOptions] = None + Delegates to components for actual processing. + """ + import time + + from nemoguardrails.actions.llm.utils import get_colang_history + from nemoguardrails.context import ( + explain_info_var, + generation_options_var, + llm_stats_var, + raw_llm_request, + streaming_handler_var, + ) + from nemoguardrails.logging.stats import LLMStats + from nemoguardrails.utils import extract_error_json + # Input validation if prompt is None and messages is None: raise ValueError("Either prompt or messages must be provided.") - if prompt is not None and messages is not None: raise ValueError("Only one of prompt or messages can be provided.") + # Convert prompt to messages if prompt is not None: - # Currently, we transform the prompt request into a single turn conversation messages = [{"role": "user", "content": prompt}] - # If a state object is specified, then we switch to "generation options" mode. - # This is because we want the output to be a GenerationResponse which will contain - # the output state. + # Deserialize state if needed if state is not None: - # We deserialize the state if needed. if isinstance(state, dict) and state.get("version", "1.0") == "2.x": state = json_to_state(state["state"]) - if options is None: - gen_options = GenerationOptions() - elif isinstance(options, dict): - gen_options = GenerationOptions(**options) - else: - gen_options = options - else: - # We allow options to be specified both as a dict and as an object. - if options and isinstance(options, dict): - gen_options = GenerationOptions(**options) - elif isinstance(options, GenerationOptions): - gen_options = options - elif options is None: - gen_options = None - else: - raise TypeError("options must be a dict or GenerationOptions") - - # Save the generation options in the current async context. - # At this point, gen_options is either None or GenerationOptions + # Process options + gen_options = self._process_options(options, state) generation_options_var.set(gen_options) if streaming_handler: streaming_handler_var.set(streaming_handler) - # Initialize the object with additional explanation information. - # We allow this to also be set externally. This is useful when multiple parallel - # requests are made. + # Initialize explain info self.explain_info = self._ensure_explain_info() - raw_llm_request.set(messages) - # If we have generation options, we also add them to the context + # Inject generation options into messages if gen_options: messages = [ { @@ -896,157 +312,72 @@ async def generate_async( } ] + (messages or []) - # If the last message is from the assistant, rather than the user, then - # we move that to the `$bot_message` variable. This is to enable a more - # convenient interface. (only when dialog rails are disabled) + # Handle bot message in context for non-dialog mode if ( messages and messages[-1]["role"] == "assistant" and gen_options and gen_options.rails.dialog is False ): - # We already have the first message with a context update, so we use that messages[0]["content"]["bot_message"] = messages[-1]["content"] messages = messages[0:-1] - # TODO: Add support to load back history of events, next to history of messages - # This is important as without it, the LLM prediction is not as good. - t0 = time.time() - # Initialize the LLM stats + # Initialize LLM stats llm_stats = LLMStats() llm_stats_var.set(llm_stats) - processing_log = [] - - # The array of events corresponding to the provided sequence of messages. - events = self._get_events_for_messages(messages, state) # type: ignore - if self.config.colang_version == "1.0": - # If we had a state object, we also need to prepend the events from the state. - state_events = [] - if state: - assert isinstance(state, dict) - state_events = state["events"] - - new_events = [] - # Compute the new events. - try: - new_events = await self.runtime.generate_events( - state_events + events, processing_log=processing_log - ) - output_state = None - - except Exception as e: - log.error("Error in generate_async: %s", e, exc_info=True) - streaming_handler = streaming_handler_var.get() - if streaming_handler: - # Push an error chunk instead of None. - error_message = str(e) - error_dict = extract_error_json(error_message) - error_payload: str = json.dumps(error_dict) - await streaming_handler.push_chunk(error_payload) - # push a termination signal - await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore - # Re-raise the exact exception - raise + # Translate messages to events + if messages is None: + raise ValueError("messages must be provided") else: - # In generation mode, by default the bot response is an instant action. - instant_actions = ["UtteranceBotAction"] - if self.config.rails.actions.instant_actions is not None: - instant_actions = self.config.rails.actions.instant_actions - - # Cast this explicitly to avoid certain warnings - runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime) - - # Compute the new events. - # In generation mode, the processing is always blocking, i.e., it waits for - # all local actions (sync and async). - new_events, output_state = await runtime.process_events( - events, state=state, instant_actions=instant_actions, blocking=True - ) - # We also encode the output state as a JSON - output_state = {"state": state_to_json(output_state), "version": "2.x"} - - # Extract and join all the messages from StartUtteranceBotAction events as the response. - responses = [] - response_tool_calls = [] - response_events = [] - new_extra_events = [] - exception = None + events = self.event_translator.messages_to_events(messages, state) - # The processing is different for Colang 1.0 and 2.0 - if self.config.colang_version == "1.0": - for event in new_events: - if event["type"] == "StartUtteranceBotAction": - # Check if we need to remove a message - if event["script"] == "(remove last message)": - responses = responses[0:-1] - else: - responses.append(event["script"]) - elif event["type"].endswith("Exception"): - exception = event - - else: - for event in new_events: - start_action_match = re.match(r"Start(.*Action)", event["type"]) - - if start_action_match: - action_name = start_action_match[1] - # TODO: is there an elegant way to extract just the arguments? - arguments = { - k: v - for k, v in event.items() - if k != "type" - and k != "uid" - and k != "event_created_at" - and k != "source_uid" - and k != "action_uid" - } - response_tool_calls.append( - { - "id": event["action_uid"], - "type": "function", - "function": {"name": action_name, "arguments": arguments}, - } - ) - - elif event["type"] == "UtteranceBotActionFinished": - responses.append(event["final_script"]) - else: - # We just append the event - response_events.append(event) - - if exception: - new_message: dict = {"role": "exception", "content": exception} - - else: - # Ensure all items in responses are strings - responses = [ - str(response) if not isinstance(response, str) else response - for response in responses - ] - new_message: dict = {"role": "assistant", "content": "\n".join(responses)} - if response_tool_calls: - new_message["tool_calls"] = response_tool_calls - if response_events: - new_message["events"] = response_events + # Generate new events using runtime orchestrator + try: + log.info( + f"DEBUG: Calling generate_events with state containing skip_output_rails: {state.get('skip_output_rails') if isinstance(state, dict) else 'state is not dict'}" + ) + ( + new_events, + output_state, + processing_log, + ) = await self.runtime_orchestrator.generate_events( + events=events, state=state + ) + except Exception as e: + log.error("Error in generate_async: %s", e, exc_info=True) + streaming_handler = streaming_handler_var.get() + if streaming_handler: + error_message = str(e) + error_dict = extract_error_json(error_message) + error_payload = json.dumps(error_dict) + await streaming_handler.push_chunk(error_payload) + await streaming_handler.push_chunk(END_OF_STREAM) + raise + # Cache events for Colang 1.0 if self.config.colang_version == "1.0": - events.extend(new_events) - events.extend(new_extra_events) + responses, _, _, _ = self.response_assembler._extract_from_events( + new_events + ) + responses = [str(r) if not isinstance(r, str) else r for r in responses] + new_message = {"role": "assistant", "content": "\n".join(responses)} - # If a state object is not used, then we use the implicit caching if state is None: - # Save the new events in the history and update the cache - cache_key = get_history_cache_key((messages) + [new_message]) # type: ignore - self.events_history_cache[cache_key] = events + all_events = events + new_events + self.event_translator.cache_events( + (messages or []) + [new_message], all_events + ) else: - output_state = {"events": events} + output_state = {"events": events + new_events} - # If logging is enabled, we log the conversation - # TODO: add support for logging flag - self.explain_info.colang_history = get_colang_history(events) + # Log conversation history + all_events = ( + events + new_events if self.config.colang_version == "1.0" else new_events + ) + self.explain_info.colang_history = get_colang_history(all_events) if self.verbose: log.info( f"Conversation history so far: \n{self.explain_info.colang_history}" @@ -1058,192 +389,160 @@ async def generate_async( % (total_time, llm_stats) ) - # If there is a streaming handler, we make sure we close it now + # Close streaming handler streaming_handler = streaming_handler_var.get() if streaming_handler: - # print("Closing the stream handler explicitly") - await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore - - # IF tracing is enabled we need to set GenerationLog attrs - original_log_options = None - if self.config.tracing.enabled: - if gen_options is None: - gen_options = GenerationOptions() - else: - # create a copy of the gen_options to avoid modifying the original - gen_options = gen_options.model_copy(deep=True) - original_log_options = gen_options.log.model_copy(deep=True) + await streaming_handler.push_chunk(END_OF_STREAM) - # enable log options - # it is aggressive, but these are required for tracing - if ( - not gen_options.log.activated_rails - or not gen_options.log.llm_calls - or not gen_options.log.internal_events - ): - gen_options.log.activated_rails = True - gen_options.log.llm_calls = True - gen_options.log.internal_events = True - - tool_calls = extract_tool_calls_from_events(new_events) - llm_metadata = get_and_clear_response_metadata_contextvar() - reasoning_content = extract_bot_thinking_from_events(new_events) - # If we have generation options, we prepare a GenerationResponse instance. + # Assemble response if gen_options: - # If a prompt was used, we only need to return the content of the message. - if prompt: - res = GenerationResponse(response=new_message["content"]) - else: - res = GenerationResponse(response=[new_message]) - - if reasoning_content: - res.reasoning_content = reasoning_content + res = self.response_assembler.assemble_response( + new_events=new_events, + all_events=all_events, + output_state=output_state, + processing_log=processing_log, + gen_options=gen_options, + prompt=prompt, + ) - if tool_calls: - res.tool_calls = tool_calls + # Handle tracing if enabled + if self.config.tracing.enabled and messages is not None: + await self._handle_tracing( + messages, res, gen_options, processing_log, all_events + ) - if llm_metadata: - res.llm_metadata = llm_metadata + return res + else: + simple_res = self.response_assembler.assemble_simple_response( + new_events=new_events, prompt=prompt + ) - if self.config.colang_version == "1.0": - # If output variables are specified, we extract their values - if gen_options and gen_options.output_vars: - context = compute_context(events) - output_vars = gen_options.output_vars - if isinstance(output_vars, list): - # If we have only a selection of keys, we filter to only that. - res.output_data = {k: context.get(k) for k in output_vars} - else: - # Otherwise, we return the full context - res.output_data = context + # Handle tracing if enabled (even when options is None) + if self.config.tracing.enabled and messages is not None: + # Convert simple response to GenerationResponse for tracing + if isinstance(simple_res, dict): + trace_res = GenerationResponse(response=[simple_res], log=None) + else: + trace_res = GenerationResponse(response=simple_res, log=None) + await self._handle_tracing( + messages, trace_res, None, processing_log, all_events + ) + # Return GenerationResponse when tracing is enabled + return trace_res - _log = compute_generation_log(processing_log) + return simple_res - # Include information about activated rails and LLM calls if requested - log_options = gen_options.log if gen_options else None - if log_options and ( - log_options.activated_rails or log_options.llm_calls - ): - res.log = GenerationLog() - - # We always include the stats - res.log.stats = _log.stats - - if log_options.activated_rails: - res.log.activated_rails = _log.activated_rails - - if log_options.llm_calls: - res.log.llm_calls = [] - for activated_rail in _log.activated_rails: - for executed_action in activated_rail.executed_actions: - res.log.llm_calls.extend(executed_action.llm_calls) - - # Include internal events if requested - if log_options and log_options.internal_events: - if res.log is None: - res.log = GenerationLog() - - res.log.internal_events = new_events - - # Include the Colang history if requested - if log_options and log_options.colang_history: - if res.log is None: - res.log = GenerationLog() - - res.log.colang_history = get_colang_history(events) - - # Include the raw llm output if requested - if gen_options and gen_options.llm_output: - # Currently, we include the output from the generation LLM calls. - for activated_rail in _log.activated_rails: - if activated_rail.type == "generation": - for executed_action in activated_rail.executed_actions: - for llm_call in executed_action.llm_calls: - res.llm_output = llm_call.raw_response + def _process_options( + self, + options: Optional[Union[dict, GenerationOptions]], + state: Optional[Any], + ) -> Optional[GenerationOptions]: + """Process and normalize generation options.""" + if state is not None: + if options is None: + return GenerationOptions() + elif isinstance(options, dict): + return GenerationOptions(**options) else: - if gen_options and gen_options.output_vars: - raise ValueError( - "The `output_vars` option is not supported for Colang 2.0 configurations." - ) - - log_options = gen_options.log if gen_options else None - if log_options and ( - log_options.activated_rails - or log_options.llm_calls - or log_options.internal_events - or log_options.colang_history - ): - raise ValueError( - "The `log` option is not supported for Colang 2.0 configurations." - ) - - if gen_options and gen_options.llm_output: - raise ValueError( - "The `llm_output` option is not supported for Colang 2.0 configurations." - ) - - # Include the state - if state is not None: - res.state = output_state + return options + else: + if options and isinstance(options, dict): + return GenerationOptions(**options) + elif isinstance(options, GenerationOptions): + return options + elif options is None: + return None + else: + raise TypeError("options must be a dict or GenerationOptions") - if self.config.tracing.enabled: - # TODO: move it to the top once resolved circular dependency of eval - # lazy import to avoid circular dependency - from nemoguardrails.tracing import Tracer + @staticmethod + def _ensure_explain_info() -> ExplainInfo: + """Ensure ExplainInfo variable is present in context.""" + from nemoguardrails.context import explain_info_var - span_format = getattr( - self.config.tracing, "span_format", "opentelemetry" - ) - enable_content_capture = getattr( - self.config.tracing, "enable_content_capture", False - ) - # Create a Tracer instance with instantiated adapters and span configuration - tracer = Tracer( - input=messages, - response=res, - adapters=self._log_adapters, - span_format=span_format, - enable_content_capture=enable_content_capture, - ) - await tracer.export_async() + explain_info = explain_info_var.get() + if explain_info is None: + explain_info = ExplainInfo() + explain_info_var.set(explain_info) + return explain_info - # respect original log specification, if tracing added information to the output - if original_log_options: - if not any( - ( - original_log_options.internal_events, - original_log_options.activated_rails, - original_log_options.llm_calls, - original_log_options.colang_history, - ) - ): - res.log = None - else: - # Ensure res.log exists before setting attributes - if res.log is not None: - if not original_log_options.internal_events: - res.log.internal_events = [] - if not original_log_options.activated_rails: - res.log.activated_rails = [] - if not original_log_options.llm_calls: - res.log.llm_calls = [] + async def _handle_tracing( + self, + messages: List[dict], + res: GenerationResponse, + gen_options: Optional[GenerationOptions], + processing_log: List[dict], + all_events: List[dict], + ): + """Handle tracing export.""" + from nemoguardrails.actions.llm.utils import get_colang_history + from nemoguardrails.logging.processing_log import compute_generation_log + from nemoguardrails.rails.llm.options import GenerationLog + from nemoguardrails.tracing import Tracer + + span_format = getattr(self.config.tracing, "span_format", "opentelemetry") + enable_content_capture = getattr( + self.config.tracing, "enable_content_capture", False + ) - return res + # If response.log is None but tracing is enabled, create a temporary log for tracing + # without attaching it to the response (to avoid mutating user's response) + if res.log is None: + # Create a log from processing_log for tracing purposes + _log = compute_generation_log(processing_log) + temp_log = GenerationLog() + temp_log.stats = _log.stats + temp_log.activated_rails = _log.activated_rails or [] + # Collect llm_calls from activated_rails + temp_log.llm_calls = [] + for activated_rail in _log.activated_rails or []: + for executed_action in activated_rail.executed_actions: + temp_log.llm_calls.extend(executed_action.llm_calls) + # Include internal events and colang history for comprehensive tracing + temp_log.internal_events = all_events + temp_log.colang_history = get_colang_history(all_events) + # Create a temporary response with the log for tracing + temp_response = GenerationResponse(response=res.response, log=temp_log) + tracer = Tracer( + input=messages, + response=temp_response, + adapters=self._log_adapters, + span_format=span_format, + enable_content_capture=enable_content_capture, + ) else: - # If a prompt is used, we only return the content of the message. + tracer = Tracer( + input=messages, + response=res, + adapters=self._log_adapters, + span_format=span_format, + enable_content_capture=enable_content_capture, + ) + await tracer.export_async() - if reasoning_content: - thinking_trace = f"{reasoning_content}\n" - new_message["content"] = thinking_trace + new_message["content"] + def generate( + self, + prompt: Optional[str] = None, + messages: Optional[List[dict]] = None, + options: Optional[Union[dict, GenerationOptions]] = None, + state: Optional[dict] = None, + ): + """Synchronous version of generate_async.""" + if check_sync_call_from_async_loop(): + raise RuntimeError( + "You are using the sync `generate` inside async code. " + "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`." + ) - if prompt: - return new_message["content"] - else: - if tool_calls: - new_message["tool_calls"] = tool_calls - return new_message + loop = get_or_create_event_loop() + return loop.run_until_complete( + self.generate_async( + prompt=prompt, messages=messages, options=options, state=state + ) + ) def _validate_streaming_with_output_rails(self) -> None: + """Validate streaming configuration with output rails.""" if len(self.config.rails.output.flows) > 0 and ( not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled @@ -1264,10 +563,10 @@ def stream_async( include_generation_metadata: Optional[bool] = False, generator: Optional[AsyncIterator[str]] = None, ) -> AsyncIterator[str]: - """Simplified interface for getting directly the streamed tokens from the LLM.""" - + """Simplified interface for getting streamed tokens from the LLM.""" self._validate_streaming_with_output_rails() - # if an external generator is provided, use it directly + + # if external generator provided, use it directly if generator: if ( self.config.rails.output.streaming @@ -1288,47 +587,66 @@ def stream_async( include_generation_metadata=include_generation_metadata ) - # Create a properly managed task with exception handling + # Create properly managed task with exception handling async def _generation_task(): try: + # When output rails streaming is enabled, we need to skip the normal + # output rails execution during the main flow, as they will be run + # separately in _run_output_rails_in_streaming + generation_options = options + if ( + self.config.rails.output.streaming + and self.config.rails.output.streaming.enabled + ): + # Disable output rails in generation_options so the llm_flows.co check prevents them from running + if generation_options is None: + generation_options = GenerationOptions( + rails=GenerationRailsOptions(output=False) + ) + elif isinstance(generation_options, dict): + generation_options = dict(generation_options) + if "rails" not in generation_options: + generation_options["rails"] = {} + generation_options["rails"]["output"] = False + else: + # It's a GenerationOptions object + generation_options = generation_options.model_copy(deep=True) + generation_options.rails.output = False + await self.generate_async( prompt=prompt, messages=messages, streaming_handler=streaming_handler, - options=options, + options=generation_options, state=state, ) except Exception as e: - # If an exception occurs during generation, push it to the streaming handler as a json string - # This ensures the streaming pipeline is properly terminated log.error(f"Error in generation task: {e}", exc_info=True) + from nemoguardrails.utils import extract_error_json + error_message = str(e) error_dict = extract_error_json(error_message) error_payload = json.dumps(error_dict) await streaming_handler.push_chunk(error_payload) - await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore + await streaming_handler.push_chunk(END_OF_STREAM) task = asyncio.create_task(_generation_task()) - # Store task reference to prevent garbage collection and ensure proper cleanup + # Store task reference if not hasattr(self, "_active_tasks"): self._active_tasks = set() self._active_tasks.add(task) - # Clean up task when it's done def task_done_callback(task): self._active_tasks.discard(task) task.add_done_callback(task_done_callback) - # when we have output rails we wrap the streaming handler - # if len(self.config.rails.output.flows) > 0: - # + # Wrap with output rails if configured if ( self.config.rails.output.streaming and self.config.rails.output.streaming.enabled ): - # returns an async generator return self._run_output_rails_in_streaming( streaming_handler=streaming_handler, output_rails_streaming_config=self.config.rails.output.streaming, @@ -1338,69 +656,244 @@ def task_done_callback(task): else: return streaming_handler - def generate( + async def _run_output_rails_in_streaming( self, + streaming_handler: AsyncIterator[str], + output_rails_streaming_config: OutputRailsStreamingConfig, prompt: Optional[str] = None, messages: Optional[List[dict]] = None, - options: Optional[Union[dict, GenerationOptions]] = None, - state: Optional[dict] = None, - ): - """Synchronous version of generate_async.""" + stream_first: Optional[bool] = None, + ) -> AsyncIterator[str]: + """Run output rails in streaming mode.""" + from nemoguardrails.context import explain_info_var + from nemoguardrails.rails.llm.buffer import ChunkBatch + + # Ensure explain_info_var is set so LLM calls during streaming output rails + # are tracked properly. Use self.explain_info if available, otherwise ensure + # it's created and set in the context. We need to ensure both reference the + # same object so LLM calls are tracked correctly. + if self.explain_info is None: + self.explain_info = self._ensure_explain_info() + # Always set the context variable to point to self.explain_info + explain_info_var.set(self.explain_info) - if check_sync_call_from_async_loop(): - raise RuntimeError( - "You are using the sync `generate` inside async code. " - "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`." + # Get buffer strategy from config + buffer_strategy = get_buffer_strategy(output_rails_streaming_config) + + # Determine stream_first behavior + should_stream_first = ( + stream_first + if stream_first is not None + else output_rails_streaming_config.stream_first + ) + + # Get output flows + output_flows = self.config.rails.output.flows or [] + is_parallel = self.config.rails.output.parallel + + # Get action details for each flow if parallel + flows_with_params = {} + if is_parallel and output_flows: + try: + flows = self.config.flows + for flow_id in output_flows: + try: + action_name, action_params = get_action_details_from_flow_id( + flow_id, flows + ) + flows_with_params[flow_id] = { + "action_name": action_name, + "params": action_params, + } + except (ValueError, KeyError) as e: + log.warning( + f"Could not get action details for flow {flow_id}: {e}" + ) + except Exception as e: + log.error(f"Error getting flow details: {e}") + + # Process stream using buffer strategy + log.info("Starting to process stream with buffer strategy") + batch_count = 0 + async for chunk_batch in buffer_strategy.process_stream(streaming_handler): + batch_count += 1 + log.info( + f"Received chunk_batch #{batch_count} with {len(chunk_batch.user_output_chunks)} chunks" + ) + # Format the processing context + processing_text = buffer_strategy.format_chunks( + chunk_batch.processing_context ) - loop = get_or_create_event_loop() + # If stream_first is True, yield chunks immediately before checking rails + if should_stream_first: + for chunk in chunk_batch.user_output_chunks: + yield chunk - return loop.run_until_complete( - self.generate_async( - prompt=prompt, - messages=messages, - options=options, - state=state, + # Create events with bot_message for rail processing + events = [] + if messages: + events.extend(self.event_translator.messages_to_events(messages)) + events.append( + { + "type": "BotMessage", + "text": processing_text, + } + ) + # Add context update so actions can access bot_message + log.info( + f"Streaming output rail batch #{batch_count}: processing_text = {repr(processing_text)}" + ) + events.append( + { + "type": "ContextUpdate", + "data": {"bot_message": processing_text}, + } ) - ) - async def generate_events_async( - self, - events: List[dict], - ) -> List[dict]: - """Generate the next events based on the provided history. + # Run output rails + blocked = False + blocking_rail = None + internal_error = None - The format for events is the following: + if is_parallel and flows_with_params: + # Run parallel rails (only available in Colang 1.0) + if hasattr(self.runtime, "_run_output_rails_in_parallel_streaming"): + result = await self.runtime._run_output_rails_in_parallel_streaming( # type: ignore[attr-defined] + flows_with_params, events + ) + else: + raise RuntimeError( + "Parallel streaming output rails are only supported in Colang 1.0" + ) + if result.events: + event = result.events[0] + if event.get("error_type") == "internal_error": + internal_error = event.get("error_message") + blocked = True + blocking_rail = event.get("flow_id") + elif event.get("intent") == "stop": + blocked = True + blocking_rail = event.get("flow_id") + else: + # Run sequential rails + for flow_id in output_flows: + try: + # Ensure explain_info_var is set before processing events + # so LLM calls are tracked properly + if self.explain_info is None: + self.explain_info = self._ensure_explain_info() + explain_info_var.set(self.explain_info) + + # Create start event + start_event = { + "type": "StartOutputRail", + "flow_id": flow_id, + } + rail_events = events + [start_event] - ```python - [ - {"type": "...", ...}, - ... - ] - ``` + # Process the rail + ( + new_events, + _, + ) = await self.runtime_orchestrator.process_events_async( + rail_events, blocking=True + ) - Args: - events: The history of events to be used to generate the next events. - options: The options to be used for the generation. + # Check if rail blocked (look for stop intent or exception) + for event in new_events: + if ( + event.get("type") == "BotIntent" + and event.get("intent") == "stop" + ): + blocked = True + blocking_rail = flow_id + break + elif event.get("type", "").endswith("Exception"): + blocked = True + blocking_rail = flow_id + break + # Also check for action results with output_mapping + elif event.get("type") == "InternalSystemActionFinished": + action_name = event.get("action_name") + return_value = event.get("return_value") + log.info( + f"Sequential rail {flow_id}: action {action_name} returned {return_value}" + ) + if action_name and return_value is not None: + action_func = ( + self.runtime.action_dispatcher.get_action( + action_name + ) + ) + if action_func: + is_blocked = is_output_blocked( + return_value, action_func + ) + log.info( + f"Action {action_name} output_blocked={is_blocked}" + ) + if is_blocked: + blocked = True + blocking_rail = flow_id + break + + if blocked: + break + except Exception as e: + log.error(f"Error in sequential rail {flow_id}: {e}") + blocked = True + blocking_rail = flow_id + internal_error = str(e) + break + + # If blocked, yield error and stop + if blocked: + # Create error message + if internal_error: + error_dict = { + "error": { + "message": f"Internal error in {blocking_rail} rail: {internal_error}", + "type": "internal_error", + "param": blocking_rail, + "code": "rail_execution_failure", + } + } + else: + error_dict = { + "error": { + "message": f"Blocked by {blocking_rail} rails.", + "type": "guardrails_violation", + "param": blocking_rail, + "code": "content_blocked", + } + } + yield json.dumps(error_dict) + return - Returns: - The newly generate event(s). + # If not blocked and stream_first is False, yield chunks after rails check + if not should_stream_first: + for chunk in chunk_batch.user_output_chunks: + yield chunk + + async def generate_events_async(self, events: List[dict]) -> List[dict]: + """Generate the next events based on the provided history.""" + import time + + from nemoguardrails.actions.llm.utils import get_colang_history + from nemoguardrails.context import llm_stats_var + from nemoguardrails.logging.stats import LLMStats - """ t0 = time.time() - # Initialize the LLM stats llm_stats = LLMStats() llm_stats_var.set(llm_stats) - # Compute the new events. processing_log = [] new_events = await self.runtime.generate_events( events, processing_log=processing_log ) - # If logging is enabled, we log the conversation - # TODO: add support for logging flag if self.verbose: history = get_colang_history(events) log.info(f"Conversation history so far: \n{history}") @@ -1410,16 +903,12 @@ async def generate_events_async( return new_events - def generate_events( - self, - events: List[dict], - ) -> List[dict]: - """Synchronous version of `LLMRails.generate_events_async`.""" - + def generate_events(self, events: List[dict]) -> List[dict]: + """Synchronous version of generate_events_async.""" if check_sync_call_from_async_loop(): raise RuntimeError( "You are using the sync `generate_events` inside async code. " - "You should replace with `await generate_events_async(...)` or use `nest_asyncio.apply()`." + "You should replace with `await generate_events_async(...)`." ) loop = get_or_create_event_loop() @@ -1431,38 +920,10 @@ async def process_events_async( state: Optional[dict] = None, blocking: bool = False, ) -> Tuple[List[dict], dict]: - """Process a sequence of events in a given state. - - The events will be processed one by one, in the input order. - - Args: - events: A sequence of events that needs to be processed. - state: The state that should be used as the starting point. If not provided, - a clean state will be used. - - Returns: - (output_events, output_state) Returns a sequence of output events and an output - state. - """ - t0 = time.time() - llm_stats = LLMStats() - llm_stats_var.set(llm_stats) - - # Compute the new events. - # We need to protect 'process_events' to be called only once at a time - # TODO (cschueller): Why is this? - async with process_events_semaphore: - output_events, output_state = await self.runtime.process_events( - events, state, blocking - ) - - took = time.time() - t0 - # Small tweak, disable this when there were no events (or it was just too fast). - if took > 0.1: - log.info("--- :: Total processing took %.2f seconds." % took) - log.info("--- :: Stats: %s" % llm_stats) - - return output_events, output_state + """Process a sequence of events in a given state.""" + return await self.runtime_orchestrator.process_events_async( + events, state, blocking + ) def process_events( self, @@ -1470,12 +931,11 @@ def process_events( state: Optional[dict] = None, blocking: bool = False, ) -> Tuple[List[dict], dict]: - """Synchronous version of `LLMRails.process_events_async`.""" - + """Synchronous version of process_events_async.""" if check_sync_call_from_async_loop(): raise RuntimeError( - "You are using the sync `generate_events` inside async code. " - "You should replace with `await generate_events_async(...)." + "You are using the sync `process_events` inside async code. " + "You should replace with `await process_events_async(...)`." ) loop = get_or_create_event_loop() @@ -1483,70 +943,67 @@ def process_events( self.process_events_async(events, state, blocking) ) + # Registration methods + def register_action(self, action: Callable, name: Optional[str] = None) -> Self: - """Register a custom action for the rails configuration.""" + """Register a custom action.""" self.runtime.register_action(action, name) return self def register_action_param(self, name: str, value: Any) -> Self: - """Registers a custom action parameter.""" + """Register a custom action parameter.""" self.runtime.register_action_param(name, value) return self def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self: - """Register a custom filter for the rails configuration.""" + """Register a custom filter.""" self.runtime.llm_task_manager.register_filter(filter_fn, name) return self def register_output_parser(self, output_parser: Callable, name: str) -> Self: - """Register a custom output parser for the rails configuration.""" + """Register a custom output parser.""" self.runtime.llm_task_manager.register_output_parser(output_parser, name) return self def register_prompt_context(self, name: str, value_or_fn: Any) -> Self: - """Register a value to be included in the prompt context. - - :name: The name of the variable or function that will be used. - :value_or_fn: The value or function that will be used to generate the value. - """ + """Register a value to be included in the prompt context.""" self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn) return self def register_embedding_search_provider( self, name: str, cls: Type[EmbeddingsIndex] ) -> Self: - """Register a new embedding search provider. - - Args: - name: The name of the embedding search provider that will be used. - cls: The class that will be used to generate and search embedding - """ - + """Register a new embedding search provider.""" self.embedding_search_providers[name] = cls return self def register_embedding_provider( self, cls: Type[EmbeddingModel], name: Optional[str] = None ) -> Self: - """Register a custom embedding provider. - - Args: - model (Type[EmbeddingModel]): The embedding model class. - name (str): The name of the embedding engine. If available in the model, it will be used. - - Raises: - ValueError: If the engine name is not provided and the model does not have an engine name. - ValueError: If the model does not have 'encode' or 'encode_async' methods. - """ + """Register a custom embedding provider.""" register_embedding_provider(engine_name=name, model=cls) return self def explain(self) -> ExplainInfo: - """Helper function to return the latest ExplainInfo object.""" + """Return the latest ExplainInfo object.""" if self.explain_info is None: self.explain_info = self._ensure_explain_info() return self.explain_info + def _prepare_model_kwargs(self, model_config) -> dict: + """Prepare kwargs for model initialization, including API key from environment. + + This method is maintained for backwards compatibility. + It delegates to ModelFactory._prepare_model_kwargs. + + Args: + model_config: The model configuration object. + + Returns: + Dictionary of kwargs for model initialization. + """ + return self.model_factory._prepare_model_kwargs(model_config) + def __getstate__(self): return {"config": self.config} @@ -1557,270 +1014,11 @@ def __setstate__(self, state): config = state["config"] self.__init__(config=config, verbose=False) - async def _run_output_rails_in_streaming( - self, - streaming_handler: AsyncIterator[str], - output_rails_streaming_config: OutputRailsStreamingConfig, - prompt: Optional[str] = None, - messages: Optional[List[dict]] = None, - stream_first: Optional[bool] = None, - ) -> AsyncIterator[str]: - """ - 1. Buffers tokens from 'streaming_handler' via BufferStrategy. - 2. Runs sequential (parallel for colang 2.0 in future) flows for each chunk. - 3. Yields the chunk if not blocked, or STOP if blocked. - """ - - def _get_last_context_message( - messages: Optional[List[dict]] = None, - ) -> dict: - if messages is None: - return {} - - for message in reversed(messages): - if message.get("role") == "context": - return message - return {} - - def _get_latest_user_message( - messages: Optional[List[dict]] = None, - ) -> dict: - if messages is None: - return {} - for message in reversed(messages): - if message.get("role") == "user": - return message - return {} - - def _prepare_context_for_parallel_rails( - chunk_str: str, - prompt: Optional[str] = None, - messages: Optional[List[dict]] = None, - ) -> dict: - """Prepare context for parallel rails execution.""" - context_message = _get_last_context_message(messages) - user_message = prompt or _get_latest_user_message(messages) - - context = { - "user_message": user_message, - "bot_message": chunk_str, - } - - if context_message: - context.update(context_message["content"]) - - return context - - def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]: - """Create events for running output rails on a chunk.""" - return [ - {"type": "ContextUpdate", "data": context}, - {"type": "BotMessage", "text": chunk_str}, - ] - - def _prepare_params( - flow_id: str, - action_name: str, - bot_response_chunk: str, - prompt: Optional[str] = None, - messages: Optional[List[dict]] = None, - action_params: Dict[str, Any] = {}, - ): - context_message = _get_last_context_message(messages) - user_message = prompt or _get_latest_user_message(messages) - - context = { - "user_message": user_message, - "bot_message": bot_response_chunk, - } - - if context_message: - context.update(context_message["content"]) - - model_name = flow_id.split("$")[-1].split("=")[-1].strip('"') - - # we pass action params that are defined in the flow - # caveate, e.g. prmpt_security uses bot_response=$bot_message - # to resolve replace placeholders in action_params - for key, value in action_params.items(): - if value == "$bot_message": - action_params[key] = bot_response_chunk - elif value == "$user_message": - action_params[key] = user_message - - return { - # TODO:: are there other context variables that need to be passed? - # passing events to compute context was not successful - # context var failed due to different context - "context": context, - "llm_task_manager": self.runtime.llm_task_manager, - "config": self.config, - "model_name": model_name, - "llms": self.runtime.registered_action_params.get("llms", {}), - "llm": self.runtime.registered_action_params.get( - f"{action_name}_llm", self.llm - ), - **action_params, - } - - buffer_strategy = get_buffer_strategy(output_rails_streaming_config) - output_rails_flows_id = self.config.rails.output.flows - stream_first = stream_first or output_rails_streaming_config.stream_first - get_action_details = partial( - get_action_details_from_flow_id, flows=self.config.flows - ) - - parallel_mode = getattr(self.config.rails.output, "parallel", False) - - async for chunk_batch in buffer_strategy(streaming_handler): - user_output_chunks = chunk_batch.user_output_chunks - # format processing_context for output rails processing (needs full context) - bot_response_chunk = buffer_strategy.format_chunks( - chunk_batch.processing_context - ) - - # check if user_output_chunks is a list of individual chunks - # or if it's a JSON string, by convention this means an error occurred and the error dict is stored as a JSON - if not isinstance(user_output_chunks, list): - try: - json.loads(user_output_chunks) - yield user_output_chunks - return - except (json.JSONDecodeError, TypeError): - # if it's not JSON, treat it as empty list - user_output_chunks = [] - - if stream_first: - # yield the individual chunks directly from the buffer strategy - for chunk in user_output_chunks: - yield chunk - - if parallel_mode: - try: - context = _prepare_context_for_parallel_rails( - bot_response_chunk, prompt, messages - ) - events = _create_events_for_chunk(bot_response_chunk, context) - - flows_with_params = {} - for flow_id in output_rails_flows_id: - action_name, action_params = get_action_details(flow_id) - params = _prepare_params( - flow_id=flow_id, - action_name=action_name, - bot_response_chunk=bot_response_chunk, - prompt=prompt, - messages=messages, - action_params=action_params, - ) - flows_with_params[flow_id] = { - "action_name": action_name, - "params": params, - } - - result_tuple = await self.runtime.action_dispatcher.execute_action( - "run_output_rails_in_parallel_streaming", - { - "flows_with_params": flows_with_params, - "events": events, - }, - ) - # ActionDispatcher.execute_action always returns (result, status) - result, status = result_tuple - - if status != "success": - log.error( - f"Parallel rails execution failed with status: {status}" - ) - # continue processing the chunk even if rails fail - pass - else: - # if there are any stop events, content was blocked or internal error occurred - result_events = getattr(result, "events", None) - if result_events: - # extract the flow info from the first stop event - stop_event = result_events[0] - blocked_flow = stop_event.get("flow_id", "output rails") - error_type = stop_event.get("error_type") - - if error_type == "internal_error": - error_message = stop_event.get( - "error_message", "Unknown error" - ) - reason = f"Internal error in {blocked_flow} rail: {error_message}" - error_code = "rail_execution_failure" - error_type = "internal_error" - else: - reason = f"Blocked by {blocked_flow} rails." - error_code = "content_blocked" - error_type = "guardrails_violation" - - error_data = { - "error": { - "message": reason, - "type": error_type, - "param": blocked_flow, - "code": error_code, - } - } - yield json.dumps(error_data) - return - - except Exception as e: - log.error(f"Error in parallel rail execution: {e}") - # don't block the stream for rail execution errors - # continue processing the chunk - pass - - # update explain info for parallel mode - self.explain_info = self._ensure_explain_info() - - else: - for flow_id in output_rails_flows_id: - action_name, action_params = get_action_details(flow_id) - - params = _prepare_params( - flow_id=flow_id, - action_name=action_name, - bot_response_chunk=bot_response_chunk, - prompt=prompt, - messages=messages, - action_params=action_params, - ) - - result = await self.runtime.action_dispatcher.execute_action( - action_name, params - ) - self.explain_info = self._ensure_explain_info() - - action_func = self.runtime.action_dispatcher.get_action(action_name) - - # Use the mapping to decide if the result indicates blocked content. - if is_output_blocked(result, action_func): - reason = f"Blocked by {flow_id} rails." - - # return the error as a plain JSON string (not in SSE format) - # NOTE: When integrating with the OpenAI Python client, the server code should: - # 1. detect this JSON error object in the stream - # 2. terminate the stream - # 3. format the error following OpenAI's SSE format - # the OpenAI client will then properly raise an APIError with this error message - - error_data = { - "error": { - "message": reason, - "type": "guardrails_violation", - "param": flow_id, - "code": "content_blocked", - } - } - - # return as plain JSON: the server should detect this JSON and convert it to an HTTP error - yield json.dumps(error_data) - return - - if not stream_first: - # yield the individual chunks directly from the buffer strategy - for chunk in user_output_chunks: - yield chunk +# Re-export for backwards compatibility +__all__ = [ + "LLMRails", + "get_action_details_from_flow_id", + "GenerationOptions", + "GenerationResponse", +] diff --git a/nemoguardrails/rails/llm/model_factory.py b/nemoguardrails/rails/llm/model_factory.py new file mode 100644 index 000000000..48e0a851e --- /dev/null +++ b/nemoguardrails/rails/llm/model_factory.py @@ -0,0 +1,261 @@ +import logging +import os +from typing import Any, Dict, Optional, Union + +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.llms import BaseLLM + +from nemoguardrails.llm.cache import CacheInterface, LFUCache +from nemoguardrails.llm.models.initializer import init_llm_model +from nemoguardrails.rails.llm.config import RailsConfig + +log = logging.getLogger(__name__) + + +class ModelFactory: + """Factory for initializing and configuring LLM models.""" + + def __init__( + self, + config: RailsConfig, + injected_llm: Optional[Union[BaseLLM, BaseChatModel]] = None, + ): + """Initialize the ModelFactory. + + Args: + config: The rails configuration. + injected_llm: An optional LLM provided via constructor that takes precedence. + """ + self.config = config + self.injected_llm = injected_llm + self.main_llm: Optional[Union[BaseLLM, BaseChatModel]] = None + self.specialized_llms: Dict[str, Union[BaseLLM, BaseChatModel]] = {} + self.model_caches: Dict[str, CacheInterface] = {} + self.main_llm_supports_streaming = False + + def initialize_models( + self, action_param_registry: Dict[str, Any] + ) -> Dict[str, Union[BaseLLM, BaseChatModel]]: + """Initialize all LLM models from configuration. + + Args: + action_param_registry: Registry to store action parameters. + + Returns: + Dictionary of specialized LLMs (excluding main). + """ + # Handle injected LLM first + if self.injected_llm: + self.main_llm = self.injected_llm + action_param_registry["llm"] = self.main_llm + self._configure_streaming(self.main_llm) + + # Warn if main LLM also specified in config + if any(model.type == "main" for model in self.config.models): + log.warning( + "Both an LLM was provided via constructor and a main LLM is specified in the config. " + "The LLM provided via constructor will be used and the main LLM from config will be ignored." + ) + else: + # Initialize main LLM from config + main_model = next( + (model for model in self.config.models if model.type == "main"), None + ) + + if main_model and main_model.model: + kwargs = self._prepare_model_kwargs(main_model) + self.main_llm = init_llm_model( + model_name=main_model.model, + provider_name=main_model.engine, + mode="chat", + kwargs=kwargs, + ) + action_param_registry["llm"] = self.main_llm + self._configure_streaming( + self.main_llm, + model_name=main_model.model, + provider_name=main_model.engine, + ) + else: + log.warning( + "No main LLM specified in the config and no LLM provided via constructor." + ) + + # Initialize specialized LLMs + for llm_config in self.config.models: + if llm_config.type in ["embeddings", "jailbreak_detection"]: + continue + + # Skip main model - already initialized above + if llm_config.type == "main": + continue + + model_name = llm_config.model + if not model_name: + raise ValueError( + f"LLM Config model field not set for {llm_config.type}" + ) + + provider_name = llm_config.engine + kwargs = self._prepare_model_kwargs(llm_config) + mode = llm_config.mode + + llm_model = init_llm_model( + model_name=model_name, + provider_name=provider_name, + mode=mode, + kwargs=kwargs, + ) + + # Configure based on type + if llm_config.type == "main": + if not self.main_llm: + self.main_llm = llm_model + action_param_registry["llm"] = self.main_llm + else: + param_name = f"{llm_config.type}_llm" + self.specialized_llms[llm_config.type] = llm_model + action_param_registry[param_name] = llm_model + + # Register specialized LLMs dictionary + action_param_registry["llms"] = self.specialized_llms + + # Initialize model caches + self._initialize_model_caches(action_param_registry) + + return self.specialized_llms + + def _prepare_model_kwargs(self, model_config) -> dict: + """Prepare kwargs for model initialization, including API key from environment. + + Args: + model_config: The model configuration object. + + Returns: + Dictionary of kwargs for model initialization. + """ + kwargs = model_config.parameters or {} + + # Add API key from environment if specified + if model_config.api_key_env_var: + api_key = os.environ.get(model_config.api_key_env_var) + if api_key: + kwargs["api_key"] = api_key + + # Enable streaming token usage when streaming is enabled + if self.config.streaming: + kwargs["stream_usage"] = True + + return kwargs + + def _configure_streaming( + self, + llm: Union[BaseLLM, BaseChatModel], + model_name: Optional[str] = None, + provider_name: Optional[str] = None, + ): + """Configure streaming support for the LLM. + + Args: + llm: The LLM model instance. + model_name: Optional model name for logging. + provider_name: Optional provider name for logging. + """ + if not self.config.streaming: + return + + if hasattr(llm, "streaming"): + setattr(llm, "streaming", True) + self.main_llm_supports_streaming = True + else: + self.main_llm_supports_streaming = False + if model_name and provider_name: + log.warning( + "Model %s from provider %s does not support streaming.", + model_name, + provider_name, + ) + else: + log.warning("Provided main LLM does not support streaming.") + + def _create_model_cache(self, model) -> LFUCache: + """Create cache instance for a model based on its configuration. + + Args: + model: The model configuration object. + + Returns: + LFUCache: The cache instance. + """ + if model.cache.maxsize <= 0: + raise ValueError( + f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. " + "Capacity must be greater than 0. Skipping cache creation." + ) + + stats_logging_interval = None + if model.cache.stats.enabled and model.cache.stats.log_interval is not None: + stats_logging_interval = model.cache.stats.log_interval + + cache = LFUCache( + maxsize=model.cache.maxsize, + track_stats=model.cache.stats.enabled, + stats_logging_interval=stats_logging_interval, + ) + + log.info( + "Created cache for model '%s' with maxsize %s", + model.type, + model.cache.maxsize, + ) + + return cache + + def _initialize_model_caches(self, action_param_registry: Dict[str, Any]) -> None: + """Initialize caches for configured models. + + Args: + action_param_registry: Registry to store action parameters. + """ + for model in self.config.models: + if model.type in ["main", "embeddings"]: + continue + + if model.cache and model.cache.enabled: + cache = self._create_model_cache(model) + self.model_caches[model.type] = cache + + log.info( + "Initialized model '%s' with cache %s", + model.type, + "enabled" if cache else "disabled", + ) + + if self.model_caches: + action_param_registry["model_caches"] = self.model_caches + + def get_main_llm(self) -> Optional[Union[BaseLLM, BaseChatModel]]: + """Get the main LLM instance.""" + return self.main_llm + + def get_specialized_llm( + self, llm_type: str + ) -> Optional[Union[BaseLLM, BaseChatModel]]: + """Get a specialized LLM by type.""" + return self.specialized_llms.get(llm_type) + + def supports_streaming(self) -> bool: + """Check if the main LLM supports streaming.""" + return self.main_llm_supports_streaming + + def update_main_llm( + self, llm: Union[BaseLLM, BaseChatModel], action_param_registry: Dict[str, Any] + ): + """Update the main LLM instance. + + Args: + llm: The new LLM instance. + action_param_registry: Registry to update action parameters. + """ + self.main_llm = llm + action_param_registry["llm"] = llm diff --git a/nemoguardrails/rails/llm/rails_api.py b/nemoguardrails/rails/llm/rails_api.py new file mode 100644 index 000000000..f447fb178 --- /dev/null +++ b/nemoguardrails/rails/llm/rails_api.py @@ -0,0 +1,570 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rails API - Main facade for NeMo Guardrails functionality.""" + +import asyncio +import logging +import threading +import time +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Type, Union + +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.llms import BaseLLM +from typing_extensions import Self + +from nemoguardrails.actions.llm.generation import LLMGenerationActions +from nemoguardrails.actions.llm.utils import get_colang_history +from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx +from nemoguardrails.colang.v1_0.runtime.runtime import Runtime +from nemoguardrails.colang.v2_x.runtime.flows import State +from nemoguardrails.context import ( + explain_info_var, + generation_options_var, + llm_stats_var, + raw_llm_request, + streaming_handler_var, +) +from nemoguardrails.embeddings.index import EmbeddingsIndex +from nemoguardrails.embeddings.providers.base import EmbeddingModel +from nemoguardrails.logging.explain import ExplainInfo +from nemoguardrails.logging.stats import LLMStats +from nemoguardrails.logging.verbose import set_verbose +from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.event_translator import EventTranslator +from nemoguardrails.rails.llm.kb_builder import KnowledgeBaseBuilder +from nemoguardrails.rails.llm.model_factory import ModelFactory +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse +from nemoguardrails.rails.llm.response_assembler import ResponseAssembler +from nemoguardrails.rails.llm.runtime_orchestrator import RuntimeOrchestrator +from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler +from nemoguardrails.utils import get_or_create_event_loop +from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state + +log = logging.getLogger(__name__) + + +class RailsAPI: + """ + Main API facade for NeMo Guardrails. + """ + + def __init__( + self, + config: RailsConfig, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, + verbose: bool = False, + ): + """Initialize the RailsAPI. + + Args: + config: A rails configuration. + llm: An optional LLM engine to use. + verbose: Whether the logging should be verbose. + """ + self.config = config + self.verbose = verbose + self.explain_info: Optional[ExplainInfo] = None + + if self.verbose: + set_verbose(True, llm_calls=True) + + # Initialize embedding configuration + self.embedding_search_providers = {} + self.default_embedding_model = "all-MiniLM-L6-v2" + self.default_embedding_engine = "FastEmbed" + self.default_embedding_params = {} + + # Initialize components + self._init_components(llm) + + def _init_components(self, llm: Optional[Union[BaseLLM, BaseChatModel]]): + """Initialize all components. + + Args: + llm: Optional LLM to inject. + """ + # 1. is already done via config parameter + # Additional config loading/processing could be added here + + # 2. Initialize RuntimeOrchestrator first (needs to be available for other components) + self.runtime_orchestrator = RuntimeOrchestrator( + config=self.config, verbose=self.verbose + ) + self.runtime: Runtime = self.runtime_orchestrator.runtime + + # Registry for action parameters + self.action_param_registry: Dict[str, Any] = {} + + # 3. Initialize ModelFactory + self.model_factory = ModelFactory(config=self.config, injected_llm=llm) + + # Update embedding model if specified in config + for model in self.config.models: + if model.type == "embeddings": + self.default_embedding_model = model.model + self.default_embedding_engine = model.engine + self.default_embedding_params = model.parameters or {} + break + + # Initialize models and populate action registry + self.model_factory.initialize_models(self.action_param_registry) + + # Register action parameters with runtime + for param_name, param_value in self.action_param_registry.items(): + self.runtime.register_action_param(param_name, param_value) + + # 4. Initialize LLM Generation Actions + llm_generation_actions_class = ( + LLMGenerationActions + if self.config.colang_version == "1.0" + else LLMGenerationActionsV2dotx + ) + self.llm_generation_actions = llm_generation_actions_class( + config=self.config, + llm=self.model_factory.get_main_llm(), + llm_task_manager=self.runtime.llm_task_manager, + get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance, + verbose=self.verbose, + ) + self.runtime.register_actions(self.llm_generation_actions, override=False) + + # 5. Initialize KnowledgeBaseBuilder + self.kb_builder = KnowledgeBaseBuilder( + config=self.config, + get_embeddings_search_provider_instance=self._get_embeddings_search_provider_instance, + ) + + # Initialize KB (in separate thread to avoid async issues) + loop = get_or_create_event_loop() + if True or check_sync_call_from_async_loop(): + t = threading.Thread(target=asyncio.run, args=(self.kb_builder.build(),)) + t.start() + t.join() + else: + loop.run_until_complete(self.kb_builder.build()) + + # Register KB as action parameter + self.runtime.register_action_param("kb", self.kb_builder.get_kb()) + + # 6. Initialize EventTranslator + self.event_translator = EventTranslator(config=self.config) + + # 7. Initialize ResponseAssembler + self.response_assembler = ResponseAssembler(config=self.config) + + # Initialize tracing adapters if configured + if self.config.tracing: + from nemoguardrails.tracing import create_log_adapters + + self._log_adapters = create_log_adapters(self.config.tracing) + else: + self._log_adapters = None + + def _get_embeddings_search_provider_instance(self, esp_config=None): + """Get an embeddings search provider instance.""" + from nemoguardrails.rails.llm.config import EmbeddingSearchProvider + + if esp_config is None: + esp_config = EmbeddingSearchProvider() + + if esp_config.name == "default": + from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex + + return BasicEmbeddingsIndex( + embedding_model=esp_config.parameters.get( + "embedding_model", self.default_embedding_model + ), + embedding_engine=esp_config.parameters.get( + "embedding_engine", self.default_embedding_engine + ), + embedding_params=esp_config.parameters.get( + "embedding_parameters", self.default_embedding_params + ), + cache_config=esp_config.cache, + **{ + k: v + for k, v in esp_config.parameters.items() + if k + in [ + "use_batching", + "max_batch_size", + "matx_batch_hold", + "search_threshold", + ] + and v is not None + }, + ) + else: + if esp_config.name not in self.embedding_search_providers: + raise Exception(f"Unknown embedding search provider: {esp_config.name}") + else: + kwargs = esp_config.parameters + return self.embedding_search_providers[esp_config.name](**kwargs) + + @staticmethod + def _ensure_explain_info() -> ExplainInfo: + """Ensure that the ExplainInfo variable is present in the current context.""" + explain_info = explain_info_var.get() + if explain_info is None: + explain_info = ExplainInfo() + explain_info_var.set(explain_info) + return explain_info + + async def generate_async( + self, + prompt: Optional[str] = None, + messages: Optional[List[dict]] = None, + options: Optional[Union[dict, GenerationOptions]] = None, + state: Optional[Union[dict, State]] = None, + streaming_handler: Optional[StreamingHandler] = None, + ) -> Union[str, dict, GenerationResponse]: + """Generate a completion or next message. + + Args: + prompt: The prompt to be used for completion. + messages: The history of messages. + options: Options specific for the generation. + state: The state object. + streaming_handler: Optional streaming handler. + + Returns: + The completion or next message. + """ + # Input validation + if prompt is None and messages is None: + raise ValueError("Either prompt or messages must be provided.") + if prompt is not None and messages is not None: + raise ValueError("Only one of prompt or messages can be provided.") + + # Convert prompt to messages format + if prompt is not None: + messages = [{"role": "user", "content": prompt}] + + # Handle state deserialization + if state is not None: + if isinstance(state, dict) and state.get("version", "1.0") == "2.x": + state = json_to_state(state["state"]) + + # Process options + gen_options = self._process_options(options, state) + generation_options_var.set(gen_options) + + if streaming_handler: + streaming_handler_var.set(streaming_handler) + + # Initialize explain info + self.explain_info = self._ensure_explain_info() + raw_llm_request.set(messages) + + # Inject generation options into messages + if gen_options: + messages = [ + { + "role": "context", + "content": {"generation_options": gen_options.model_dump()}, + } + ] + (messages or []) + + # Handle bot message in context for non-dialog mode + if ( + messages + and messages[-1]["role"] == "assistant" + and gen_options + and gen_options.rails.dialog is False + ): + messages[0]["content"]["bot_message"] = messages[-1]["content"] + messages = messages[0:-1] + + t0 = time.time() + + # Initialize LLM stats + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + # Translate messages to events + if messages is None: + raise ValueError("messages must be provided") + else: + events = self.event_translator.messages_to_events(messages, state) + + # Generate new events using runtime orchestrator + try: + ( + new_events, + output_state, + processing_log, + ) = await self.runtime_orchestrator.generate_events( + events=events, state=state + ) + except Exception as e: + log.error("Error in generate_async: %s", e, exc_info=True) + streaming_handler = streaming_handler_var.get() + if streaming_handler: + import json + + from nemoguardrails.utils import extract_error_json + + error_message = str(e) + error_dict = extract_error_json(error_message) + error_payload = json.dumps(error_dict) + await streaming_handler.push_chunk(error_payload) + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore + raise + + # Update event translator cache for Colang 1.0 + if self.config.colang_version == "1.0": + # Build the new message for caching + responses, _, _, _ = self.response_assembler._extract_from_events( + new_events + ) + responses = [str(r) if not isinstance(r, str) else r for r in responses] + new_message = {"role": "assistant", "content": "\n".join(responses)} + + if state is None: + all_events = events + new_events + self.event_translator.cache_events( + (messages or []) + [new_message], all_events + ) + + # Log conversation history + all_events = ( + events + new_events if self.config.colang_version == "1.0" else new_events + ) + self.explain_info.colang_history = get_colang_history(all_events) + if self.verbose: + log.info( + f"Conversation history so far: \n{self.explain_info.colang_history}" + ) + + total_time = time.time() - t0 + log.info( + "--- :: Total processing took %.2f seconds. LLM Stats: %s" + % (total_time, llm_stats) + ) + + # Close streaming handler if present + streaming_handler = streaming_handler_var.get() + if streaming_handler: + await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore + + # Assemble response + if gen_options: + res = self.response_assembler.assemble_response( + new_events=new_events, + all_events=all_events, + output_state=output_state, + processing_log=processing_log, + gen_options=gen_options, + prompt=prompt, + ) + + # Handle tracing if enabled + if self.config.tracing.enabled: + await self._handle_tracing(messages, res) + + return res + else: + return self.response_assembler.assemble_simple_response( + new_events=new_events, prompt=prompt + ) + + def _process_options( + self, + options: Optional[Union[dict, GenerationOptions]], + state: Optional[Any], + ) -> Optional[GenerationOptions]: + """Process and normalize generation options. + + Args: + options: Raw options. + state: State object. + + Returns: + Normalized GenerationOptions or None. + """ + if state is not None: + if options is None: + return GenerationOptions() + elif isinstance(options, dict): + return GenerationOptions(**options) + else: + return options + else: + if options and isinstance(options, dict): + return GenerationOptions(**options) + elif isinstance(options, GenerationOptions): + return options + elif options is None: + return None + else: + raise TypeError("options must be a dict or GenerationOptions") + + async def _handle_tracing( + self, + messages: List[dict[str, Any]], + res: GenerationResponse, + ): + """Handle tracing export. + + Args: + messages: Input messages. + res: Generation response. + """ + from nemoguardrails.tracing import Tracer + + span_format = getattr(self.config.tracing, "span_format", "opentelemetry") + enable_content_capture = getattr( + self.config.tracing, "enable_content_capture", False + ) + + tracer = Tracer( + input=messages, + response=res, + adapters=self._log_adapters, + span_format=span_format, + enable_content_capture=enable_content_capture, + ) + await tracer.export_async() + + def generate( + self, + prompt: Optional[str] = None, + messages: Optional[List[dict]] = None, + options: Optional[Union[dict, GenerationOptions]] = None, + state: Optional[dict] = None, + ): + """Synchronous version of generate_async.""" + if check_sync_call_from_async_loop(): + raise RuntimeError( + "You are using the sync `generate` inside async code. " + "You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`." + ) + + loop = get_or_create_event_loop() + return loop.run_until_complete( + self.generate_async( + prompt=prompt, messages=messages, options=options, state=state + ) + ) + + def stream_async( + self, + prompt: Optional[str] = None, + messages: Optional[List[dict]] = None, + options: Optional[Union[dict, GenerationOptions]] = None, + state: Optional[Union[dict, State]] = None, + include_generation_metadata: Optional[bool] = False, + ) -> AsyncIterator[str]: + """Stream tokens from the LLM. + + Note: This is a simplified stub. Full streaming implementation would + require additional logic from the original LLMRails.stream_async method. + """ + # Simplified implementation - full version would include output rails streaming + self.explain_info = self._ensure_explain_info() + + streaming_handler = StreamingHandler( + include_generation_metadata=include_generation_metadata + ) + + async def _generation_task(): + try: + await self.generate_async( + prompt=prompt, + messages=messages, + streaming_handler=streaming_handler, + options=options, + state=state, + ) + except Exception as e: + log.error(f"Error in generation task: {e}", exc_info=True) + import json + + from nemoguardrails.utils import extract_error_json + + error_message = str(e) + error_dict = extract_error_json(error_message) + error_payload = json.dumps(error_dict) + await streaming_handler.push_chunk(error_payload) + await streaming_handler.push_chunk(END_OF_STREAM) + + task = asyncio.create_task(_generation_task()) + + if not hasattr(self, "_active_tasks"): + self._active_tasks = set() + self._active_tasks.add(task) + + def task_done_callback(task): + self._active_tasks.discard(task) + + task.add_done_callback(task_done_callback) + + return streaming_handler + + # Additional API methods for compatibility + + def register_action(self, action: Callable, name: Optional[str] = None) -> Self: + """Register a custom action.""" + self.runtime.register_action(action, name) + return self + + def register_action_param(self, name: str, value: Any) -> Self: + """Register a custom action parameter.""" + self.runtime.register_action_param(name, value) + return self + + def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self: + """Register a custom filter.""" + self.runtime.llm_task_manager.register_filter(filter_fn, name) + return self + + def register_output_parser(self, output_parser: Callable, name: str) -> Self: + """Register a custom output parser.""" + self.runtime.llm_task_manager.register_output_parser(output_parser, name) + return self + + def register_prompt_context(self, name: str, value_or_fn: Any) -> Self: + """Register a value to be included in the prompt context.""" + self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn) + return self + + def register_embedding_search_provider( + self, name: str, cls: Type[EmbeddingsIndex] + ) -> Self: + """Register a new embedding search provider.""" + self.embedding_search_providers[name] = cls + return self + + def register_embedding_provider( + self, cls: Type[EmbeddingModel], name: Optional[str] = None + ) -> Self: + """Register a custom embedding provider.""" + from nemoguardrails.embeddings.providers import register_embedding_provider + + register_embedding_provider(engine_name=name, model=cls) + return self + + def explain(self) -> ExplainInfo: + """Return the latest ExplainInfo object.""" + if self.explain_info is None: + self.explain_info = self._ensure_explain_info() + return self.explain_info + + def update_llm(self, llm: Union[BaseLLM, BaseChatModel]): + """Replace the main LLM with the provided one.""" + self.model_factory.update_main_llm(llm, self.action_param_registry) + self.llm_generation_actions.llm = llm diff --git a/nemoguardrails/rails/llm/response_assembler.py b/nemoguardrails/rails/llm/response_assembler.py new file mode 100644 index 000000000..807eb8013 --- /dev/null +++ b/nemoguardrails/rails/llm/response_assembler.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Response assembler for generating GenerationResponse objects.""" + +import logging +import re +from typing import Any, Dict, List, Optional + +from nemoguardrails.actions.llm.utils import ( + extract_bot_thinking_from_events, + extract_tool_calls_from_events, + get_and_clear_response_metadata_contextvar, + get_colang_history, +) +from nemoguardrails.colang.v1_0.runtime.flows import compute_context +from nemoguardrails.logging.processing_log import compute_generation_log +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.options import ( + GenerationLog, + GenerationOptions, + GenerationResponse, +) + +log = logging.getLogger(__name__) + + +class ResponseAssembler: + """Assembles responses from events into GenerationResponse objects.""" + + def __init__(self, config: RailsConfig): + """Initialize the ResponseAssembler. + + Args: + config: The rails configuration. + """ + self.config = config + + def assemble_response( + self, + new_events: List[dict], + all_events: List[dict], + output_state: Optional[Any], + processing_log: List[dict], + gen_options: Optional[GenerationOptions], + prompt: Optional[str] = None, + ) -> GenerationResponse: + """Assemble a GenerationResponse from events. + + Args: + new_events: The newly generated events. + all_events: All events including history. + output_state: The output state (for Colang 2.x). + processing_log: The processing log. + gen_options: Generation options. + prompt: Optional prompt (if used instead of messages). + + Returns: + A GenerationResponse object. + """ + # Extract responses and metadata from events + ( + responses, + response_tool_calls, + response_events, + exception, + ) = self._extract_from_events(new_events) + + # Build the new message + new_message = self._build_message( + responses, response_tool_calls, response_events, exception + ) + + # Extract additional metadata + tool_calls = extract_tool_calls_from_events(new_events) + llm_metadata = get_and_clear_response_metadata_contextvar() + reasoning_content = extract_bot_thinking_from_events(new_events) + + # Create response object + if prompt: + res = GenerationResponse(response=new_message["content"]) + else: + res = GenerationResponse(response=[new_message]) + + if reasoning_content: + res.reasoning_content = reasoning_content + + if tool_calls: + res.tool_calls = tool_calls + + if llm_metadata: + res.llm_metadata = llm_metadata + + # Add version-specific information + if self.config.colang_version == "1.0": + self._add_v1_specific_info(res, all_events, processing_log, gen_options) + else: + self._validate_v2_options(gen_options) + + # Include the state + if output_state is not None: + res.state = output_state + + return res + + def _extract_from_events( + self, new_events: List[dict] + ) -> tuple[List[str], List[dict], List[dict], Optional[dict]]: + """Extract responses, tool calls, and events from new events. + + Args: + new_events: The newly generated events. + + Returns: + Tuple of (responses, response_tool_calls, response_events, exception). + """ + responses = [] + response_tool_calls = [] + response_events = [] + exception = None + + if self.config.colang_version == "1.0": + for event in new_events: + if event["type"] == "StartUtteranceBotAction": + # Check if we need to remove a message + if event["script"] == "(remove last message)": + responses = responses[0:-1] + else: + responses.append(event["script"]) + elif event["type"].endswith("Exception"): + exception = event + else: + for event in new_events: + start_action_match = re.match(r"Start(.*Action)", event["type"]) + + if start_action_match: + action_name = start_action_match[1] + # Extract arguments + arguments = { + k: v + for k, v in event.items() + if k + not in [ + "type", + "uid", + "event_created_at", + "source_uid", + "action_uid", + ] + } + response_tool_calls.append( + { + "id": event["action_uid"], + "type": "function", + "function": {"name": action_name, "arguments": arguments}, + } + ) + + elif event["type"] == "UtteranceBotActionFinished": + responses.append(event["final_script"]) + else: + response_events.append(event) + + return responses, response_tool_calls, response_events, exception + + def _build_message( + self, + responses: List[str], + response_tool_calls: List[dict], + response_events: List[dict], + exception: Optional[dict], + ) -> dict: + """Build a message from response components. + + Args: + responses: List of response strings. + response_tool_calls: List of tool calls. + response_events: List of response events. + exception: Optional exception event. + + Returns: + A message dictionary. + """ + new_message: Dict[str, Any] + if exception: + new_message = {"role": "exception", "content": exception} + else: + # Ensure all items in responses are strings + responses = [ + str(response) if not isinstance(response, str) else response + for response in responses + ] + new_message = {"role": "assistant", "content": "\n".join(responses)} + + if response_tool_calls: + new_message["tool_calls"] = response_tool_calls + + if response_events: + new_message["events"] = response_events + + return new_message + + def _add_v1_specific_info( + self, + res: GenerationResponse, + all_events: List[dict], + processing_log: List[dict], + gen_options: Optional[GenerationOptions], + ): + """Add Colang 1.0 specific information to the response. + + Args: + res: The GenerationResponse to update. + all_events: All events including history. + processing_log: The processing log. + gen_options: Generation options. + """ + # Extract output variables if specified + if gen_options and gen_options.output_vars: + context = compute_context(all_events) + output_vars = gen_options.output_vars + if isinstance(output_vars, list): + res.output_data = {k: context.get(k) for k in output_vars} + else: + res.output_data = context + + # Add logging information + _log = compute_generation_log(processing_log) + log_options = gen_options.log if gen_options else None + + if log_options and (log_options.activated_rails or log_options.llm_calls): + res.log = GenerationLog() + res.log.stats = _log.stats + + if log_options.activated_rails: + res.log.activated_rails = _log.activated_rails + else: + # Keep as empty list when not requested + res.log.activated_rails = [] + + if log_options.llm_calls: + res.log.llm_calls = [] + for activated_rail in _log.activated_rails: + for executed_action in activated_rail.executed_actions: + res.log.llm_calls.extend(executed_action.llm_calls) + else: + # Set to empty list instead of None when not requested + res.log.llm_calls = [] + + # Include internal events if requested + if log_options and log_options.internal_events: + if res.log is None: + res.log = GenerationLog() + res.log.internal_events = all_events + elif res.log is not None: + # Set to empty list instead of None when not requested but log exists + res.log.internal_events = [] + + # Include Colang history if requested + if log_options and log_options.colang_history: + if res.log is None: + res.log = GenerationLog() + res.log.colang_history = get_colang_history(all_events) + + # Normalize list fields: ensure they're empty lists instead of None when log exists + if res.log is not None: + if res.log.llm_calls is None: + res.log.llm_calls = [] + if res.log.internal_events is None: + res.log.internal_events = [] + + # Include raw LLM output if requested + if gen_options and gen_options.llm_output: + for activated_rail in _log.activated_rails: + if activated_rail.type == "generation": + for executed_action in activated_rail.executed_actions: + for llm_call in executed_action.llm_calls: + res.llm_output = llm_call.raw_response + + def _validate_v2_options(self, gen_options: Optional[GenerationOptions]): + """Validate that unsupported options are not used for Colang 2.x. + + Args: + gen_options: Generation options to validate. + + Raises: + ValueError: If unsupported options are used. + """ + if not gen_options: + return + + if gen_options.output_vars: + raise ValueError( + "The `output_vars` option is not supported for Colang 2.0 configurations." + ) + + log_options = gen_options.log + if log_options and ( + log_options.activated_rails + or log_options.llm_calls + or log_options.internal_events + or log_options.colang_history + ): + raise ValueError( + "The `log` option is not supported for Colang 2.0 configurations." + ) + + if gen_options.llm_output: + raise ValueError( + "The `llm_output` option is not supported for Colang 2.0 configurations." + ) + + def assemble_simple_response( + self, + new_events: List[dict], + prompt: Optional[str] = None, + ) -> dict: + """Assemble a simple response (non-GenerationResponse mode). + + Args: + new_events: The newly generated events. + prompt: Optional prompt (if used instead of messages). + + Returns: + A message dictionary or content string. + """ + ( + responses, + response_tool_calls, + response_events, + exception, + ) = self._extract_from_events(new_events) + + new_message = self._build_message( + responses, response_tool_calls, response_events, exception + ) + + # Add thinking trace if present + reasoning_content = extract_bot_thinking_from_events(new_events) + if reasoning_content: + thinking_trace = f"{reasoning_content}\n" + new_message["content"] = thinking_trace + new_message["content"] + + # Add tool calls if present + tool_calls = extract_tool_calls_from_events(new_events) + if tool_calls: + new_message["tool_calls"] = tool_calls + + if prompt: + return new_message["content"] + else: + return new_message diff --git a/nemoguardrails/rails/llm/runtime_orchestrator.py b/nemoguardrails/rails/llm/runtime_orchestrator.py new file mode 100644 index 000000000..1d581afde --- /dev/null +++ b/nemoguardrails/rails/llm/runtime_orchestrator.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime orchestrator for managing Colang runtime execution.""" + +import asyncio +import logging +from typing import Any, List, Optional, Tuple, cast + +from nemoguardrails.colang.runtime import Runtime +from nemoguardrails.colang.v1_0.runtime.runtime import RuntimeV1_0 +from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x +from nemoguardrails.colang.v2_x.runtime.serialization import state_to_json +from nemoguardrails.logging.verbose import set_verbose +from nemoguardrails.rails.llm.config import RailsConfig + +log = logging.getLogger(__name__) + +# Semaphore for protecting process_events calls +process_events_semaphore = asyncio.Semaphore(1) + + +class RuntimeOrchestrator: + """Orchestrates the Colang runtime execution. + + Handles runtime initialization, event generation, and process coordination. + """ + + def __init__(self, config: RailsConfig, verbose: bool = False): + """Initialize the RuntimeOrchestrator. + + Args: + config: The rails configuration. + verbose: Whether to enable verbose logging. + """ + self.config = config + self.verbose = verbose + self.runtime: Runtime + + if self.verbose: + set_verbose(True, llm_calls=True) + + # Initialize the appropriate runtime based on Colang version + if config.colang_version == "1.0": + self.runtime = RuntimeV1_0(config=config, verbose=verbose) + elif config.colang_version == "2.x": + self.runtime = RuntimeV2_x(config=config, verbose=verbose) + else: + raise ValueError(f"Unsupported colang version: {config.colang_version}.") + + async def generate_events( + self, + events: List[dict], + state: Optional[Any] = None, + ) -> Tuple[List[dict], Optional[Any], List[dict]]: + """Generate new events based on the history of events. + + Args: + events: The history of events. + state: Optional state object (for Colang 2.x). + + Returns: + Tuple of (new_events, output_state, processing_log). + """ + processing_log = [] + + if self.config.colang_version == "1.0": + # For Colang 1.0, we use generate_events + state_events = [] + if state: + assert isinstance(state, dict) + state_events = state.get("events", []) + + # Extract any context variables from state (keys other than "events") + # and create a ContextUpdate event to make them available in Colang + context_vars = {k: v for k, v in state.items() if k != "events"} + if context_vars: + log.info( + f"DEBUG: Creating ContextUpdate event from state with context_vars={context_vars}" + ) + from nemoguardrails.utils import new_event_dict + + context_update_event = new_event_dict( + "ContextUpdate", data=context_vars + ) + log.info( + f"DEBUG: Created ContextUpdate event: {context_update_event}" + ) + state_events = [context_update_event] + state_events + + all_events = state_events + events + log.info( + f"DEBUG: Calling runtime.generate_events with {len(all_events)} events, first few: {all_events[:3]}" + ) + new_events = await self.runtime.generate_events( + all_events, processing_log=processing_log + ) + output_state = None + + else: + # For Colang 2.x, we use process_events + instant_actions = ["UtteranceBotAction"] + if self.config.rails.actions.instant_actions is not None: + instant_actions = self.config.rails.actions.instant_actions + + runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime) + + new_events, output_state = await runtime.process_events( + events, state=state, instant_actions=instant_actions, blocking=True + ) + + # Encode output state as JSON + output_state = {"state": state_to_json(output_state), "version": "2.x"} + + return new_events, output_state, processing_log + + async def process_events_async( + self, + events: List[dict], + state: Optional[dict] = None, + blocking: bool = False, + ) -> Tuple[List[dict], Any]: + """Process a sequence of events in a given state. + + Args: + events: A sequence of events to process. + state: The starting state. + blocking: Whether to block on all actions. + + Returns: + Tuple of (output_events, output_state). + """ + # Protect process_events to be called only once at a time + async with process_events_semaphore: + if self.config.colang_version == "1.0": + # For Colang 1.0, use generate_events + state_events = [] + if state: + assert isinstance(state, dict) + state_events = state.get("events", []) + + output_events = await self.runtime.generate_events( + state_events + events + ) + output_state = None + else: + # For Colang 2.x, use process_events + output_events, output_state = await self.runtime.process_events( + events, state, blocking + ) + + return (output_events, output_state) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 6769dec1e..d8a94bc02 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -197,8 +197,7 @@ class RequestBody(BaseModel): ) config_ids: Optional[List[str]] = Field( default=None, - description="The list of configuration ids to be used. " - "If set, the configurations will be combined.", + description="The list of configuration ids to be used. If set, the configurations will be combined.", # alias="guardrails", validate_default=True, ) @@ -592,8 +591,7 @@ def on_any_event(self, event): except ImportError: # Since this is running in a separate thread, we just print the error. print( - "The auto-reload feature requires `watchdog`. " - "Please install using `pip install watchdog`." + "The auto-reload feature requires `watchdog`. Please install using `pip install watchdog`." ) # Force close everything. os._exit(-1) diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 06ad3ee93..98c807405 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -259,7 +259,7 @@ async def _process( async def push_chunk( self, - chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None], + chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, Any], generation_info: Optional[Dict[str, Any]] = None, ): """Push a new chunk to the stream.""" diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py index aa3be4789..fa715aba9 100644 --- a/tests/rails/llm/test_config.py +++ b/tests/rails/llm/test_config.py @@ -21,7 +21,7 @@ from pydantic import ValidationError from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt -from nemoguardrails.rails.llm.llmrails import LLMRails +from nemoguardrails.rails.llm.model_factory import ModelFactory def test_task_prompt_valid_content(): @@ -317,9 +317,9 @@ def test_llm_rails_configure_streaming_with_attr(): streaming=True, ) - rails = LLMRails(config, llm=mock_llm) + model_factory = ModelFactory(config=config, injected_llm=mock_llm) setattr(mock_llm, "streaming", None) - rails._configure_main_llm_streaming(llm=mock_llm) + model_factory._configure_streaming(llm=mock_llm) assert mock_llm.streaming @@ -333,8 +333,8 @@ def test_llm_rails_configure_streaming_without_attr(caplog): streaming=True, ) - rails = LLMRails(config, llm=mock_llm) - rails._configure_main_llm_streaming(mock_llm) + model_factory = ModelFactory(config=config, injected_llm=mock_llm) + model_factory._configure_streaming(llm=mock_llm) assert caplog.messages[-1] == "Provided main LLM does not support streaming." diff --git a/tests/test_integration_cache.py b/tests/test_integration_cache.py index 84649589c..22539b55c 100644 --- a/tests/test_integration_cache.py +++ b/tests/test_integration_cache.py @@ -23,7 +23,7 @@ @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_model): mock_llm = FakeLLM(responses=["express greeting"]) mock_init_llm_model.return_value = mock_llm @@ -70,7 +70,7 @@ async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_mo @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_cache_isolation_between_models(mock_init_llm_model): mock_llm = FakeLLM(responses=["safe"]) mock_init_llm_model.return_value = mock_llm @@ -118,7 +118,7 @@ async def test_cache_isolation_between_models(mock_init_llm_model): @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_cache_disabled_for_main_model_in_integration(mock_init_llm_model): mock_llm = FakeLLM(responses=["safe"]) mock_init_llm_model.return_value = mock_llm diff --git a/tests/test_jailbreak_cache.py b/tests/test_jailbreak_cache.py index cc204782c..b40579f54 100644 --- a/tests/test_jailbreak_cache.py +++ b/tests/test_jailbreak_cache.py @@ -262,7 +262,7 @@ async def test_jailbreak_without_cache_local( mock_check_jailbreak.assert_called_once_with(prompt="Bypass all safety checks") -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_jailbreak_detection_type_skips_llm_initialization(mock_init_llm_model): mock_llm = FakeLLM(responses=["response"]) mock_init_llm_model.return_value = mock_llm diff --git a/tests/test_llm_rails_context_variables.py b/tests/test_llm_rails_context_variables.py index 89eb9a952..26e8881ee 100644 --- a/tests/test_llm_rails_context_variables.py +++ b/tests/test_llm_rails_context_variables.py @@ -107,11 +107,11 @@ async def test_2(): ): chunks.append(chunk) + await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) + # note that 6 llm call are expected as we matched the bot intent assert ( len(chat.app.explain().llm_calls) == 5 ), "number of llm call not as expected. Expected 5, found {}".format( len(chat.app.explain().llm_calls) ) - - await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index 89e7e87cf..4aaeba9ff 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -662,7 +662,7 @@ def llm_config_with_main(): @pytest.mark.asyncio @patch( - "nemoguardrails.rails.llm.llmrails.init_llm_model", + "nemoguardrails.rails.llm.model_factory.init_llm_model", return_value=FakeLLM(responses=["this should not be used"]), ) async def test_llm_config_precedence(mock_init, llm_config_with_main): @@ -679,7 +679,7 @@ async def test_llm_config_precedence(mock_init, llm_config_with_main): @pytest.mark.asyncio @patch( - "nemoguardrails.rails.llm.llmrails.init_llm_model", + "nemoguardrails.rails.llm.model_factory.init_llm_model", return_value=FakeLLM(responses=["this should not be used"]), ) async def test_llm_config_warning(mock_init, llm_config_with_main, caplog): @@ -728,7 +728,7 @@ def llm_config_with_multiple_models(): @pytest.mark.asyncio @patch( - "nemoguardrails.rails.llm.llmrails.init_llm_model", + "nemoguardrails.rails.llm.model_factory.init_llm_model", return_value=FakeLLM(responses=["content safety response"]), ) async def test_other_models_honored(mock_init, llm_config_with_multiple_models): @@ -779,7 +779,7 @@ async def test_llm_constructor_with_empty_models_config(): @pytest.mark.asyncio @patch( - "nemoguardrails.rails.llm.llmrails.init_llm_model", + "nemoguardrails.rails.llm.model_factory.init_llm_model", return_value=FakeLLM(responses=["safe"]), ) async def test_main_llm_from_config_registered_as_action_param( @@ -839,7 +839,7 @@ async def test_llm_action(llm: BaseLLM): assert action_finished_event["return_value"] == "llm_action_success" -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") @patch.dict(os.environ, {"TEST_OPENAI_KEY": "secret-api-key-from-env"}) def test_api_key_environment_variable_passed_to_init_llm_model(mock_init_llm_model): """Test that API keys from environment variables are passed to init_llm_model.""" @@ -873,7 +873,7 @@ def test_api_key_environment_variable_passed_to_init_llm_model(mock_init_llm_mod assert call_args.kwargs["mode"] == "chat" -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") @patch.dict(os.environ, {"CONTENT_SAFETY_KEY": "safety-key-from-env"}) def test_api_key_environment_variable_for_non_main_models(mock_init_llm_model): """Test that API keys from environment variables work for non-main models too. @@ -915,7 +915,7 @@ def test_api_key_environment_variable_for_non_main_models(mock_init_llm_model): assert safety_call_args.kwargs["kwargs"]["temperature"] == 0.0 -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_missing_api_key_environment_variable_graceful_handling(mock_init_llm_model): """Test that missing environment variables are handled gracefully during LLM initialization. @@ -992,7 +992,7 @@ def __init__(self): @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_stream_usage_enabled_for_streaming_supported_providers( mock_init_llm_model, ): @@ -1020,7 +1020,7 @@ async def test_stream_usage_enabled_for_streaming_supported_providers( @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_stream_usage_not_set_without_streaming(mock_init_llm_model): """Test that stream_usage is not set when streaming is disabled.""" config = RailsConfig.from_content( @@ -1046,7 +1046,7 @@ async def test_stream_usage_not_set_without_streaming(mock_init_llm_model): @pytest.mark.asyncio -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") async def test_stream_usage_enabled_for_all_providers_when_streaming( mock_init_llm_model, ): @@ -1190,7 +1190,7 @@ def test_explain_calls_ensure_explain_info(): assert rails.explain_info == ExplainInfo() -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_initialization_disabled_by_default(mock_init_llm_model): mock_llm = FakeLLM(responses=["response"]) mock_init_llm_model.return_value = mock_llm @@ -1216,7 +1216,7 @@ def test_cache_initialization_disabled_by_default(mock_init_llm_model): assert model_caches is None or len(model_caches) == 0 -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_initialization_with_enabled_cache(mock_init_llm_model): from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig @@ -1251,7 +1251,7 @@ def test_cache_initialization_with_enabled_cache(mock_init_llm_model): assert model_caches["content_safety"].maxsize == 1000 -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model): from nemoguardrails.rails.llm.config import ModelCacheConfig @@ -1282,7 +1282,7 @@ def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model): assert "embeddings" not in model_caches -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model): from nemoguardrails.rails.llm.config import ModelCacheConfig @@ -1304,7 +1304,7 @@ def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model LLMRails(config=config, verbose=False) -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_initialization_with_stats_enabled(mock_init_llm_model): from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig @@ -1336,7 +1336,7 @@ def test_cache_initialization_with_stats_enabled(mock_init_llm_model): assert cache.supports_stats_logging() is True -@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +@patch("nemoguardrails.rails.llm.model_factory.init_llm_model") def test_cache_initialization_with_multiple_models(mock_init_llm_model): from nemoguardrails.rails.llm.config import ModelCacheConfig diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 8fb1ac22a..7bdfc5472 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -16,6 +16,7 @@ import asyncio import json import math +from typing import Optional import pytest @@ -283,15 +284,18 @@ def output_rails_streaming_config(): define flow user express greeting bot tell joke + + define flow self check output + execute self_check_output """, ) @action(is_system_action=True, output_mapping=lambda result: not result) -def self_check_output(**params): +def self_check_output(context: Optional[dict] = None): """A dummy self check action that checks if the bot message contains the BLOCK keyword.""" - if params.get("context", {}).get("bot_message"): - bot_message_chunk = params.get("context", {}).get("bot_message") + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") if "BLOCK" in bot_message_chunk: return False diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py index 8583a2fb9..a6f582692 100644 --- a/tests/test_streaming_output_rails.py +++ b/tests/test_streaming_output_rails.py @@ -18,7 +18,7 @@ import asyncio import json from json.decoder import JSONDecodeError -from typing import AsyncIterator +from typing import AsyncIterator, Optional import pytest @@ -57,6 +57,9 @@ def output_rails_streaming_config(): define flow user express greeting bot tell joke + + define flow self check output + execute self_check_output """, ) @@ -83,6 +86,9 @@ def output_rails_streaming_config_default(): define flow user express greeting bot tell joke + + define flow self check output + execute self_check_output """, ) @@ -100,11 +106,11 @@ async def test_stream_async_streaming_enabled(output_rails_streaming_config): @action(is_system_action=True, output_mapping=lambda result: not result) -def self_check_output(**params): +def self_check_output(context: Optional[dict] = None): """A dummy self check action that checks if the bot message contains the BLOCK keyword.""" - if params.get("context", {}).get("bot_message"): - bot_message_chunk = params.get("context", {}).get("bot_message") + if context and context.get("bot_message"): + bot_message_chunk = context.get("bot_message") print(f"bot_message_chunk: {bot_message_chunk}") if "BLOCK" in bot_message_chunk: return False @@ -322,10 +328,8 @@ async def test_external_generator_with_output_rails_blocked(): rails = LLMRails(config) @action(name="self_check_output") - async def self_check_output(**kwargs): - bot_message = kwargs.get( - "bot_message", kwargs.get("context", {}).get("bot_message", "") - ) + async def self_check_output(context: Optional[dict] = None): + bot_message = context.get("bot_message", "") if context else "" # block if message contains "offensive" or "idiot" if "offensive" in bot_message.lower() or "idiot" in bot_message.lower(): return False diff --git a/tests/test_system_message_conversion.py b/tests/test_system_message_conversion.py index 08d1f6797..7bda08de6 100644 --- a/tests/test_system_message_conversion.py +++ b/tests/test_system_message_conversion.py @@ -44,7 +44,7 @@ async def test_system_message_conversion_v1(): {"role": "user", "content": "Hello!"}, ] - events = llm_rails._get_events_for_messages(messages, None) + events = llm_rails.event_translator.messages_to_events(messages, None) system_messages = [event for event in events if event["type"] == "SystemMessage"] assert len(system_messages) == 1 @@ -76,7 +76,7 @@ async def test_system_message_conversion_v2x(): {"role": "user", "content": "Hello!"}, ] - events = llm_rails._get_events_for_messages(messages, None) + events = llm_rails.event_translator.messages_to_events(messages, None) system_messages = [event for event in events if event["type"] == "SystemMessage"] assert len(system_messages) == 1 @@ -108,7 +108,7 @@ async def test_system_message_conversion_multiple(): {"role": "user", "content": "Hello!"}, ] - events = llm_rails._get_events_for_messages(messages, None) + events = llm_rails.event_translator.messages_to_events(messages, None) system_messages = [event for event in events if event["type"] == "SystemMessage"] assert len(system_messages) == 2 diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py index e3f5713b1..0db5490b7 100644 --- a/tests/v2_x/chat.py +++ b/tests/v2_x/chat.py @@ -18,10 +18,10 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional -import nemoguardrails.rails.llm.llmrails from nemoguardrails import LLMRails, RailsConfig from nemoguardrails.cli.chat import extract_scene_text_content, parse_events_inputs from nemoguardrails.colang.v2_x.runtime.flows import State +from nemoguardrails.rails.llm import runtime_orchestrator from nemoguardrails.utils import new_event_dict, new_uuid os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -49,9 +49,7 @@ def __init__(self, rails_app: LLMRails): asyncio.create_task(self.run()) # Ensure that the semaphore is assigned to the same loop that we just created - nemoguardrails.rails.llm.llmrails.process_events_semaphore = asyncio.Semaphore( - 1 - ) + runtime_orchestrator.process_events_semaphore = asyncio.Semaphore(1) self.output_summary: list[str] = [] self.should_terminate = False self.enable_input = asyncio.Event()