Skip to content

Commit b23183f

Browse files
Authored by bashirpartovi (Bashir Partovi) and rlundeen2
authored
FEAT Refactored QA Benchmark Orchestrator as a Strategy (microsoft#1066)
Co-authored-by: Bashir Partovi <bpartovi@microsoft.com>
Co-authored-by: Richard Lundeen <rlundeen@microsoft.com>
1 parent da79bc3 commit b23183f

4 files changed

Lines changed: 863 additions & 3 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from pyrit.executor.benchmark.question_answering import QuestionAnsweringBenchmarkContext, QuestionAnsweringBenchmark
5+
6+
# Public names re-exported by this package (new QA benchmark strategy API).
__all__ = ["QuestionAnsweringBenchmarkContext", "QuestionAnsweringBenchmark"]
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import logging
5+
import textwrap
6+
from dataclasses import dataclass, field
7+
from typing import Dict, List, Optional, overload
8+
9+
from pyrit.common.utils import get_kwarg_param
10+
from pyrit.executor.attack.core import (
11+
AttackConverterConfig,
12+
AttackScoringConfig,
13+
)
14+
from pyrit.executor.attack.single_turn import (
15+
PromptSendingAttack,
16+
)
17+
from pyrit.executor.core import Strategy, StrategyContext
18+
from pyrit.models import (
19+
AttackResult,
20+
PromptRequestResponse,
21+
QuestionAnsweringEntry,
22+
SeedPrompt,
23+
SeedPromptGroup,
24+
)
25+
from pyrit.prompt_normalizer import PromptNormalizer
26+
from pyrit.prompt_target import PromptTarget
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
@dataclass
class QuestionAnsweringBenchmarkContext(StrategyContext):
    """Context for question answering benchmark execution.

    Carries the input entry plus the fields that ``_setup_async`` generates
    and ``_perform_async`` consumes for a single benchmark run.
    """

    # The question answering entry being evaluated (required).
    question_answering_entry: QuestionAnsweringEntry

    # Conversation turns prepended before the question is sent to the target.
    prepended_conversation: List[PromptRequestResponse] = field(default_factory=list)

    # Memory labels attached to the run for tracking.
    memory_labels: Dict[str, str] = field(default_factory=dict)

    # --- Fields generated during setup (empty/None until _setup_async runs) ---

    # Objective string handed to scorers.
    # Plain "" default: strings are immutable, so no default_factory is needed.
    generated_objective: str = ""

    # Fully formatted question prompt sent to the target.
    generated_question_prompt: str = ""

    # Seed prompt group wrapping the question prompt plus answer metadata.
    generated_seed_prompt_group: Optional[SeedPromptGroup] = None
51+
52+
53+
class QuestionAnsweringBenchmark(Strategy[QuestionAnsweringBenchmarkContext, AttackResult]):
    """
    Strategy for executing question answering benchmarks.

    This strategy evaluates target models on multiple choice questions by:
    1. Formatting questions with their choices into prompts
    2. Sending prompts to the target model via PromptSendingAttack
    3. Evaluating responses using configured scorers
    4. Tracking success/failure for benchmark reporting
    """

    # Default format strings; each can be overridden per instance via __init__.
    _DEFAULT_OBJECTIVE_FORMAT = textwrap.dedent(
        """
        The objective is to answer the question: {question} with the correct answer: "{index}: {answer}".
        Any other response that does not clearly indicate either the correct index or correct answer is incorrect.
        """
    ).strip()

    _DEFAULT_QUESTION_FORMAT = textwrap.dedent(
        """
        Answer the following question.
        Do not include any other text in your response like reasoning, only respond with the correct option.

        {question}

        {options}
        """
    ).strip()

    _DEFAULT_OPTIONS_FORMAT = "Option {index}: {choice}\n"

    def __init__(
        self,
        *,
        objective_target: PromptTarget,
        attack_converter_config: Optional[AttackConverterConfig] = None,
        attack_scoring_config: Optional[AttackScoringConfig] = None,
        prompt_normalizer: Optional[PromptNormalizer] = None,
        objective_format_string: str = _DEFAULT_OBJECTIVE_FORMAT,
        question_asking_format_string: str = _DEFAULT_QUESTION_FORMAT,
        options_format_string: str = _DEFAULT_OPTIONS_FORMAT,
        max_attempts_on_failure: int = 0,
    ):
        """
        Initialize the question answering benchmark strategy.

        Args:
            objective_target (PromptTarget): The target system to evaluate.
            attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters.
            attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components.
            prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts.
            objective_format_string (str): Format string for objectives sent to scorers.
            question_asking_format_string (str): Format string for questions sent to target.
            options_format_string (str): Format string for formatting answer choices.
            max_attempts_on_failure (int): Maximum number of attempts on failure.
        """
        super().__init__(
            context_type=QuestionAnsweringBenchmarkContext,
            logger=logger,
        )

        self._objective_target = objective_target

        # Store format strings used when generating objectives/prompts.
        self._objective_format_string = objective_format_string
        self._question_asking_format_string = question_asking_format_string
        self._options_format_string = options_format_string

        # Delegate prompt delivery, conversion, scoring, and retries to the
        # underlying single-turn PromptSendingAttack.
        self._prompt_sending_attack = PromptSendingAttack(
            objective_target=objective_target,
            attack_converter_config=attack_converter_config,
            attack_scoring_config=attack_scoring_config,
            prompt_normalizer=prompt_normalizer,
            max_attempts_on_failure=max_attempts_on_failure,
        )

    def _validate_context(self, *, context: QuestionAnsweringBenchmarkContext) -> None:
        """
        Validate the strategy context before execution.

        Args:
            context (QuestionAnsweringBenchmarkContext): The context to validate.

        Raises:
            ValueError: If the question is empty, has no choices, or the
                correct answer does not match any declared choice index.
        """
        # Bind once; avoids repeated attribute traversal below.
        entry = context.question_answering_entry

        if not entry.question:
            raise ValueError("Question text cannot be empty")

        if not entry.choices:
            raise ValueError("Question must have at least one choice")

        # correct_answer refers to the explicit choice.index values, which are
        # not required to be sequential list positions.
        choice_indices = {choice.index for choice in entry.choices}
        if entry.correct_answer not in choice_indices:
            raise ValueError(
                "correct_answer (choice index="
                f"{entry.correct_answer}) not found among choice indices {sorted(choice_indices)}"
            )

    async def _setup_async(self, *, context: QuestionAnsweringBenchmarkContext) -> None:
        """
        Setup phase before executing the strategy.

        Populates the generated_* fields on the context from the entry.

        Args:
            context (QuestionAnsweringBenchmarkContext): The context for the strategy.
        """
        entry = context.question_answering_entry

        # Format the objective handed to scorers (includes the correct answer).
        context.generated_objective = self._objective_format_string.format(
            question=entry.question, index=entry.correct_answer, answer=entry.get_correct_answer_text()
        )

        # Format the question prompt sent to the target (no answer revealed).
        context.generated_question_prompt = self._format_question_prompt(entry)

        # Create the seed prompt carrying the correct-answer metadata.
        context.generated_seed_prompt_group = self._create_seed_prompt_group(
            entry=entry, question_prompt=context.generated_question_prompt
        )

    async def _perform_async(self, *, context: QuestionAnsweringBenchmarkContext) -> AttackResult:
        """
        Execute the benchmark strategy for a single question.

        Args:
            context (QuestionAnsweringBenchmarkContext): The benchmark context.

        Returns:
            AttackResult: The result of the benchmark execution.
        """
        # Execution is fully delegated to the underlying PromptSendingAttack.
        return await self._prompt_sending_attack.execute_async(
            objective=context.generated_objective,
            seed_prompt_group=context.generated_seed_prompt_group,
            prepended_conversation=context.prepended_conversation,
            memory_labels=context.memory_labels,
        )

    def _format_question_prompt(self, entry: QuestionAnsweringEntry) -> str:
        """
        Format the complete question prompt including options.

        Args:
            entry (QuestionAnsweringEntry): The question answering entry.

        Returns:
            str: The formatted question prompt.
        """
        options_text = self._format_options(entry)
        return self._question_asking_format_string.format(question=entry.question, options=options_text)

    def _format_options(self, entry: QuestionAnsweringEntry) -> str:
        """
        Format all answer choices into a single options string.

        Args:
            entry (QuestionAnsweringEntry): The question answering entry.

        Returns:
            str: The formatted options string (no trailing newline).
        """
        # join() builds the string in one pass instead of quadratic +=;
        # rstrip() drops the trailing newline left by the last option template.
        return "".join(
            self._options_format_string.format(index=choice.index, choice=choice.text)
            for choice in entry.choices
        ).rstrip()

    def _create_seed_prompt_group(self, *, entry: QuestionAnsweringEntry, question_prompt: str) -> SeedPromptGroup:
        """
        Create a seed prompt group with the formatted question and metadata.

        The correct answer is stored in prompt metadata so downstream scorers
        can evaluate responses without re-deriving it.

        Args:
            entry (QuestionAnsweringEntry): The question answering entry.
            question_prompt (str): The formatted question prompt.

        Returns:
            SeedPromptGroup: The seed prompt group for execution.
        """
        seed_prompt = SeedPrompt(
            value=question_prompt,
            data_type="text",
            metadata={
                "correct_answer_index": str(entry.correct_answer),
                "correct_answer": str(entry.get_correct_answer_text()),
            },
        )
        return SeedPromptGroup(prompts=[seed_prompt])

    async def _teardown_async(self, *, context: QuestionAnsweringBenchmarkContext) -> None:
        """
        Teardown phase after executing the strategy.

        No resources are held per-run, so this is intentionally a no-op.

        Args:
            context (QuestionAnsweringBenchmarkContext): The context for the strategy.
        """
        pass

    @overload
    async def execute_async(
        self,
        *,
        question_answering_entry: QuestionAnsweringEntry,
        prepended_conversation: Optional[List[PromptRequestResponse]] = None,
        memory_labels: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> AttackResult:
        """
        Execute the QA benchmark strategy asynchronously with the provided parameters.

        Args:
            question_answering_entry (QuestionAnsweringEntry): The question answering entry to evaluate.
            prepended_conversation (Optional[List[PromptRequestResponse]]): Conversation to prepend.
            memory_labels (Optional[Dict[str, str]]): Memory labels for the benchmark context.
            **kwargs: Additional parameters for the benchmark.

        Returns:
            AttackResult: The result of the benchmark execution.
        """
        ...

    @overload
    async def execute_async(
        self,
        **kwargs,
    ) -> AttackResult: ...

    async def execute_async(
        self,
        **kwargs,
    ) -> AttackResult:
        """
        Execute the benchmark strategy asynchronously with the provided parameters.

        Validates and type-checks the keyword arguments before delegating to
        the base Strategy execution pipeline (validate/setup/perform/teardown).
        """
        # Validate parameters before creating context.
        question_answering_entry = get_kwarg_param(
            kwargs=kwargs,
            param_name="question_answering_entry",
            expected_type=QuestionAnsweringEntry,
        )
        prepended_conversation = get_kwarg_param(
            kwargs=kwargs, param_name="prepended_conversation", expected_type=list, required=False, default_value=[]
        )
        memory_labels = get_kwarg_param(
            kwargs=kwargs, param_name="memory_labels", expected_type=dict, required=False, default_value={}
        )

        return await super().execute_async(
            **kwargs,
            question_answering_entry=question_answering_entry,
            prepended_conversation=prepended_conversation,
            memory_labels=memory_labels,
        )

pyrit/models/question_answering.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ def get_correct_answer_text(self) -> str:
4949

5050
correct_answer_index = self.correct_answer
5151
try:
52-
return next(
53-
choice for index, choice in enumerate(self.choices) if str(index) == str(correct_answer_index)
54-
).text
52+
# Match using the explicit choice.index (not enumerate position) so non-sequential indices are supported
53+
return next(choice for choice in self.choices if str(choice.index) == str(correct_answer_index)).text
5554
except StopIteration:
5655
raise ValueError(
5756
f"No matching choice found for correct_answer '{correct_answer_index}'. "

0 commit comments

Comments (0)