Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,22 @@ async def generate_bot_message(
# of course, it does not work when passed as context in `run_output_rails_in_streaming`
# streaming_handler is set when stream_async method is used

has_streaming_handler = streaming_handler is not None
has_output_streaming = (
self.config.rails.output.streaming
and self.config.rails.output.streaming.enabled
)
log.info(
f"generate_bot_message: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
)
# if streaming_handler and len(self.config.rails.output.flows) > 0:
if streaming_handler and self.config.rails.output.streaming.enabled:
if streaming_handler and has_output_streaming:
log.info("Setting skip_output_rails = True")
context_updates["skip_output_rails"] = True
else:
log.info(
f"NOT setting skip_output_rails: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
)

if bot_intent in self.config.bot_messages:
# Choose a message randomly from self.config.bot_messages[bot_message]
Expand Down Expand Up @@ -970,7 +983,9 @@ async def generate_bot_message(
new_event_dict("BotMessage", text=text)
)

return ActionResult(events=output_events)
return ActionResult(
events=output_events, context_updates=context_updates
)
else:
if streaming_handler:
await streaming_handler.push_chunk(
Expand All @@ -987,7 +1002,9 @@ async def generate_bot_message(
)
output_events.append(bot_message_event)

return ActionResult(events=output_events)
return ActionResult(
events=output_events, context_updates=context_updates
)

# If we are in passthrough mode, we just use the input for prompting
if self.config.passthrough:
Expand Down
17 changes: 17 additions & 0 deletions nemoguardrails/colang/v1_0/runtime/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""A simplified modeling of the CoFlows engine."""

import logging
import uuid
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -25,6 +26,8 @@
from nemoguardrails.colang.v1_0.runtime.sliding import slide
from nemoguardrails.utils import new_event_dict, new_uuid

log = logging.getLogger(__name__)


@dataclass
class FlowConfig:
Expand Down Expand Up @@ -356,7 +359,15 @@ def compute_next_state(state: State, event: dict) -> State:
if event["type"] == "ContextUpdate":
# TODO: add support to also remove keys from the context.
# maybe with a special context key e.g. "__remove__": ["key1", "key2"]
if "skip_output_rails" in event["data"]:
log.info(
f"ContextUpdate setting skip_output_rails={event['data']['skip_output_rails']}"
)
state.context.update(event["data"])
if "skip_output_rails" in state.context:
log.info(
f"After update, context skip_output_rails={state.context.get('skip_output_rails')}"
)
state.context_updates = {}
state.next_step = None
return state
Expand Down Expand Up @@ -415,6 +426,12 @@ def compute_next_state(state: State, event: dict) -> State:
_record_next_step(new_state, flow_state, flow_config, priority_modifier=0.9)
continue

# Debug logging for BotMessage event and skip_output_rails
if event["type"] == "BotMessage":
log.info(
f"BotMessage event processing for flow '{flow_config.id}', skip_output_rails in context: {flow_state.context.get('skip_output_rails', 'NOT SET')}"
)

# If we're at a branching point, we look at all individual heads.
matching_head = None

Expand Down
17 changes: 16 additions & 1 deletion nemoguardrails/colang/v1_0/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ async def _run_output_rails_in_parallel_streaming(
flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict}
events: The events list for context
"""
# Compute context from events so actions can access bot_message
context = compute_context(events)

tasks = []

async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
Expand All @@ -532,8 +535,11 @@ async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
action_name = action_info["action_name"]
params = action_info["params"]

# Merge context into params so actions have access to bot_message
params_with_context = {**params, "context": context}

result_tuple = await self.action_dispatcher.execute_action(
action_name, params
action_name, params_with_context
)
result, status = result_tuple

Expand Down Expand Up @@ -731,10 +737,19 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
return_events = []
context_updates = {}

if action_name == "generate_bot_message":
log.info(
f"DEBUG: generate_bot_message returned, isinstance(ActionResult)={isinstance(result, ActionResult)}"
)

if isinstance(result, ActionResult):
return_value = result.return_value
return_events = result.events
context_updates.update(result.context_updates)
if action_name == "generate_bot_message":
log.info(
f"generate_bot_message ActionResult: context_updates={context_updates}, skip_output_rails={'skip_output_rails' in context_updates}"
)

# If we have an action result key, we also record the update.
if action_result_key:
Expand Down
215 changes: 215 additions & 0 deletions nemoguardrails/rails/llm/config_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config loader for loading default flows, library content, and config.py modules."""

import importlib.util
import logging
import os
from typing import List, Any

from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id
from nemoguardrails.rails.llm.config import RailsConfig

log = logging.getLogger(__name__)


class ConfigLoader:
"""Enriches RailsConfig with default flows, library content, and config.py modules."""

@staticmethod
def load_config(config: RailsConfig) -> List[Any]:
"""Enrich the config with default flows and library content.

Args:
config: The rails configuration to enrich.

Returns:
List of config modules loaded from config.py files.
"""
# Load default flows for Colang 1.0
if config.colang_version == "1.0":
ConfigLoader._load_default_flows(config)
ConfigLoader._load_library_content(config)

# Mark rail flows as system flows
ConfigLoader._mark_rail_flows_as_system(config)

# Load and execute config.py modules
config_modules = ConfigLoader._load_config_modules(config)

# Validate config
ConfigLoader._validate_config(config)

return config_modules

@staticmethod
def _load_default_flows(config: RailsConfig):
"""Load default LLM flows for Colang 1.0.

Args:
config: The rails configuration.
"""
current_folder = os.path.dirname(__file__)
default_flows_file = "llm_flows.co"
default_flows_path = os.path.join(current_folder, default_flows_file)

with open(default_flows_path, "r") as f:
default_flows_content = f.read()
default_flows = parse_colang_file(
default_flows_file, default_flows_content
)["flows"]

# Mark all default flows as system flows
for flow_config in default_flows:
flow_config["is_system_flow"] = True

# Add default flows to config
config.flows.extend(default_flows)
log.debug(f"Loaded {len(default_flows)} default flows")

@staticmethod
def _load_library_content(config: RailsConfig):
"""Load content from the components library.

Args:
config: The rails configuration.
"""
library_path = os.path.join(os.path.dirname(__file__), "../../library")
loaded_files = 0

for root, dirs, files in os.walk(library_path):
for file in files:
full_path = os.path.join(root, file)
if file.endswith(".co"):
log.debug(f"Loading library file: {full_path}")
with open(full_path, "r", encoding="utf-8") as f:
content = parse_colang_file(
file, content=f.read(), version=config.colang_version
)
if not content:
continue

# Mark all library flows as system flows
for flow_config in content["flows"]:
flow_config["is_system_flow"] = True

# Load all the flows
config.flows.extend(content["flows"])

# Load bot messages if not overwritten
for message_id, utterances in content.get(
"bot_messages", {}
).items():
if message_id not in config.bot_messages:
config.bot_messages[message_id] = utterances

loaded_files += 1

log.debug(f"Loaded {loaded_files} library files")

@staticmethod
def _mark_rail_flows_as_system(config: RailsConfig):
"""Mark all flows used in rails as system flows.

Args:
config: The rails configuration.
"""
rail_flow_ids = (
config.rails.input.flows
+ config.rails.output.flows
+ config.rails.retrieval.flows
)

for flow_config in config.flows:
if flow_config.get("id") in rail_flow_ids:
flow_config["is_system_flow"] = True
# Mark them as subflows by default to simplify syntax
flow_config["is_subflow"] = True

@staticmethod
def _load_config_modules(config: RailsConfig) -> List[AttributeError]:
"""Load and execute config.py modules.

Args:
config: The rails configuration.

Returns:
List of loaded config modules.
"""
config_modules = []
paths = list(
config.imported_paths.values() if config.imported_paths else []
) + [config.config_path]

for _path in paths:
if _path:
filepath = os.path.join(_path, "config.py")
if os.path.exists(filepath):
filename = os.path.basename(filepath)
spec = importlib.util.spec_from_file_location(filename, filepath)
if spec and spec.loader:
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
config_modules.append(config_module)
log.debug(f"Loaded config module from: {filepath}")

return config_modules

@staticmethod
def _validate_config(config: RailsConfig):
"""Run validation checks on the config.

Args:
config: The rails configuration to validate.

Raises:
ValueError: If validation fails.
"""
if config.colang_version == "1.0":
existing_flows_names = set([flow.get("id") for flow in config.flows])
else:
existing_flows_names = set([flow.get("name") for flow in config.flows])

# Validate input rail flows
for flow_name in config.rails.input.flows:
flow_name = _normalize_flow_id(flow_name)
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided input rail flow `{flow_name}` does not exist"
)

# Validate output rail flows
for flow_name in config.rails.output.flows:
flow_name = _normalize_flow_id(flow_name)
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided output rail flow `{flow_name}` does not exist"
)

# Validate retrieval rail flows
for flow_name in config.rails.retrieval.flows:
if flow_name not in existing_flows_names:
raise ValueError(
f"The provided retrieval rail flow `{flow_name}` does not exist"
)

# Check for conflicting modes
if config.passthrough and config.rails.dialog.single_call.enabled:
raise ValueError(
"The passthrough mode and the single call dialog rails mode can't be used at the same time. "
"The single call mode needs to use an altered prompt when prompting the LLM."
)
Loading
Loading