Skip to content

Commit dfe9688

Browse files
authored
fix(llmrails): skip output rails when dialog disabled and no bot_message provided (NVIDIA-NeMo#1518)
* fix(llmrails): skip output rails when dialog disabled and no bot_message provided When `generate_async` is called with `options={"dialog": False, "output": True}` and no `bot_message` is provided in context, output rails were incorrectly running and checking `None`. This fix ensures output rails only run when there's actual output to check.
1 parent 2df7a14 commit dfe9688

File tree

3 files changed

+78
-5
lines changed

3 files changed

+78
-5
lines changed

nemoguardrails/rails/llm/llm_flows.co

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ define flow run dialog rails
2828

2929
# If the dialog_rails are disabled
3030
if $generation_options and $generation_options.rails.dialog == False
31-
# If the output rails are also disabled, we just return user message.
32-
if $generation_options.rails.output == False
31+
# If output rails are disabled or there's no bot message to check, skip output rails.
32+
if $generation_options.rails.output == False or $bot_message is None
3333
create event StartUtteranceBotAction(script=$user_message)
3434
else
35-
# we take the $bot_message from context.
3635
create event BotMessage(text=$bot_message)
3736
else
3837
# If not, we continue the usual process

tests/test_generation_options.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,74 @@ def test_generation_log_print_summary(capsys):
344344
capture_lines[8]
345345
== "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens."
346346
)
347+
348+
349+
@pytest.mark.parametrize(
350+
"input_opt,output_opt,dialog_opt,expect_input,expect_output",
351+
[
352+
(True, True, True, True, True),
353+
(True, True, False, True, False),
354+
(True, False, True, True, False),
355+
(True, False, False, True, False),
356+
(False, True, True, False, True),
357+
(False, True, False, False, False),
358+
(False, False, True, False, False),
359+
(False, False, False, False, False),
360+
],
361+
)
362+
@pytest.mark.asyncio
363+
async def test_rails_options_combinations(input_opt, output_opt, dialog_opt, expect_input, expect_output):
364+
"""
365+
Test all combinations of input/output/dialog options.
366+
When dialog=False and no bot_message is provided, output rails should skip.
367+
"""
368+
config = RailsConfig.from_content(
369+
colang_content="""
370+
define user express greeting
371+
"hi"
372+
373+
define flow
374+
user express greeting
375+
bot express greeting
376+
377+
define subflow dummy input rail
378+
if "block" in $user_message
379+
bot refuse to respond
380+
stop
381+
382+
define subflow dummy output rail
383+
if "block" in $bot_message
384+
bot refuse to respond
385+
stop
386+
""",
387+
yaml_content="""
388+
rails:
389+
input:
390+
flows:
391+
- dummy input rail
392+
output:
393+
flows:
394+
- dummy output rail
395+
""",
396+
)
397+
chat = TestChat(
398+
config,
399+
llm_completions=[" express greeting", ' "Hello!"'] if dialog_opt else [],
400+
)
401+
402+
res: GenerationResponse = await chat.app.generate_async(
403+
"Hello!",
404+
options={
405+
"rails": {"input": input_opt, "output": output_opt, "dialog": dialog_opt},
406+
"log": {"activated_rails": True},
407+
},
408+
)
409+
410+
activated_rails = res.log.activated_rails if res.log else []
411+
rail_names = [r.name for r in activated_rails]
412+
413+
input_rails_ran = any("input" in name.lower() for name in rail_names)
414+
output_rails_ran = any("output" in name.lower() for name in rail_names)
415+
416+
assert input_rails_ran == expect_input, f"Input rails: expected {expect_input}, got {rail_names}"
417+
assert output_rails_ran == expect_output, f"Output rails: expected {expect_output}, got {rail_names}"

tests/test_parallel_rails_exceptions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,11 @@ async def test_output_rails_only_parallel_with_exceptions():
375375
},
376376
}
377377

378-
chat >> "Hello"
379-
result = await chat.app.generate_async(messages=chat.history, options=options_output_only)
378+
messages = [
379+
{"role": "user", "content": "Hello"},
380+
{"role": "assistant", "content": "This response contains harmful content"},
381+
]
382+
result = await chat.app.generate_async(messages=messages, options=options_output_only)
380383

381384
input_rails = [r for r in result.log.activated_rails if r.type == "input"]
382385
output_rails = [r for r in result.log.activated_rails if r.type == "output"]

0 commit comments

Comments
 (0)