Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ API Reference
TAPAttack
TAPAttackContext
TAPAttackResult
TAPAttackScoringConfig
TreeOfAttacksWithPruningAttack

:py:mod:`pyrit.executor.promptgen`
Expand Down
2 changes: 2 additions & 0 deletions pyrit/executor/attack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TAPAttack,
TAPAttackContext,
TAPAttackResult,
TAPAttackScoringConfig,
TreeOfAttacksWithPruningAttack,
generate_simulated_conversation_async,
)
Expand Down Expand Up @@ -68,6 +69,7 @@
"TreeOfAttacksWithPruningAttack",
"TAPAttackContext",
"TAPAttackResult",
"TAPAttackScoringConfig",
"SingleTurnAttackStrategy",
"SingleTurnAttackContext",
"PromptSendingAttack",
Expand Down
7 changes: 5 additions & 2 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,11 @@ async def _process_prepended_for_chat_target_async(
# Multi-part messages (e.g., text + image) may have scores on multiple pieces
last_message = valid_messages[-1]
if last_message.api_role == "assistant":
prompt_ids = [str(piece.original_prompt_id) for piece in last_message.message_pieces]
state.last_assistant_message_scores = list(self._memory.get_prompt_scores(prompt_ids=prompt_ids))
scores = []
for piece in last_message.message_pieces:
if piece.scores:
scores.extend(piece.scores)
state.last_assistant_message_scores = scores

return state

Expand Down
42 changes: 42 additions & 0 deletions pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,48 @@ def get_objective_target(self) -> PromptTarget:
"""
return self._objective_target

def _get_attack_result_metadata(
    self,
    *,
    context: AttackStrategyContextT,
    request_converters: Optional[list[Any]] = None,
) -> dict[str, Any]:
    """
    Assemble the metadata fields that every AttackResult shares.

    Each entry of ``request_converters`` is handled by shape: an entry exposing a
    ``converters`` attribute contributes every converter in that list, an entry
    exposing ``get_identifier`` is treated as a converter itself, and an entry
    with neither attribute is ignored.

    Args:
        context: The attack context; its ``memory_labels`` populate ``labels``.
        request_converters: Optional list of PromptConverterConfiguration objects
            (or converter objects) used in the attack.

    Returns:
        dict[str, Any]: Keyword arguments that can be unpacked into the AttackResult
            constructor: ``attack_identifier``, ``objective_target_identifier``,
            ``request_converter_identifiers`` and ``labels``. The last two are
            None when there is nothing to report.
    """
    flattened = []
    for entry in request_converters or []:
        if hasattr(entry, "converters"):
            # Configuration object: unwrap the converter instances it holds.
            flattened.extend(entry.converters)
        elif hasattr(entry, "get_identifier"):
            # Already a converter object.
            flattened.append(entry)

    converter_identifiers = [item.get_identifier() for item in flattened] if flattened else None

    return {
        "attack_identifier": self.get_identifier(),
        "objective_target_identifier": self.get_objective_target().get_identifier(),
        "request_converter_identifiers": converter_identifiers,
        "labels": context.memory_labels or None,
    }
Comment on lines +281 to +321
Copy link
Contributor

@bashirpartovi bashirpartovi Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few concerns with this:

  • request_converters: Optional[list[Any]] is too loosely typed. When you use Any, it removes IDE autocompletion/type checking and makes it really unclear what callers should pass.
  • Looking at the usage in the code, request_converters only accepts list[PromptConverterConfiguration] or list[PromptConverter]. We should make sure the signature explicitly uses a union (e.g., Optional[list[PromptConverterConfiguration | PromptConverter]]).
  • Using hasattr here is not ideal because it fails silently if the attribute is ever renamed. For example, if PromptConverterConfiguration.converters were renamed to convs, this code would still type-check and run, but it would silently skip those converters, which is a hidden bug. With a Union in the signature, it is safer to branch with isinstance() and handle each type explicitly.

I would push back a bit on adding all the identifier fields directly to AttackResult and removing the dataclass. It bloats the class, loses the dataclass benefits (__repr__, __eq__, field defaults), and turns AttackResult into a catch-all bucket. I think a cleaner pattern is adding a single structured field that groups related identifiers:

@dataclass
class AttackLineage: # naming is my weakness
    """Group the identifiers of the components involved in an attack into one structured record."""

    # The only required field; every identifier below defaults to None.
    objective_target_identifier: dict[str, str]
    # One identifier dict per converter; None when there are none to record.
    request_converter_identifiers: Optional[list[dict[str, str]]] = None
    response_converter_identifiers: Optional[list[dict[str, str]]] = None
    # Identifier of the objective scorer, when one is supplied.
    objective_scorer_identifier: Optional[dict[str, str]] = None
    # Identifier of the adversarial chat target, when one is supplied.
    adversarial_chat_target_identifier: Optional[dict[str, str]] = None

Then AttackResult gets just one new field, lineage: Optional[AttackLineage] = None. This keeps the dataclass, groups the related tracking info logically, and is easy to extend later without touching AttackResult again.
Here's how the helper methods would look:

def _get_attack_lineage(
    self,
    *,
    request_converters: Optional[list[PromptConverterConfiguration | PromptConverter]] = None,
    response_converters: Optional[list[PromptConverterConfiguration | PromptConverter]] = None,
    objective_scorer_identifier: Optional[dict[str, str]] = None,
    adversarial_chat_target_identifier: Optional[dict[str, str]] = None,
) -> AttackLineage:
    """
    Build the AttackLineage record for this attack.

    Args:
        request_converters: Converter configurations or converters applied to requests.
        response_converters: Converter configurations or converters applied to responses.
        objective_scorer_identifier: Identifier of the objective scorer, if any.
        adversarial_chat_target_identifier: Identifier of the adversarial chat target, if any.

    Returns:
        AttackLineage: The objective target identifier together with the converter
            identifiers extracted from both lists and the two identifiers passed
            through unchanged.
    """
    target_identifier = self.get_objective_target().get_identifier()
    request_ids = self._extract_converter_identifiers(request_converters)
    response_ids = self._extract_converter_identifiers(response_converters)

    return AttackLineage(
        objective_target_identifier=target_identifier,
        request_converter_identifiers=request_ids,
        response_converter_identifiers=response_ids,
        objective_scorer_identifier=objective_scorer_identifier,
        adversarial_chat_target_identifier=adversarial_chat_target_identifier,
    )


@staticmethod
def _extract_converter_identifiers(
    converters: Optional[list[PromptConverterConfiguration | PromptConverter]],
) -> Optional[list[dict[str, str]]]:
    """
    Flatten converter configurations and converters into a list of identifiers.

    Args:
        converters: A mix of PromptConverterConfiguration objects (each contributing
            every converter in its ``converters`` list) and PromptConverter instances.
            Items of any other type are ignored.

    Returns:
        Optional[list[dict[str, str]]]: One identifier per collected converter, in
            input order, or None when the input is empty or yields no converters.
    """
    flattened: list[PromptConverter] = []

    # A None or empty input simply produces no iterations.
    for entry in converters or ():
        if isinstance(entry, PromptConverterConfiguration):
            flattened.extend(entry.converters)
        elif isinstance(entry, PromptConverter):
            flattened.append(entry)

    identifiers = [converter.get_identifier() for converter in flattened]
    # Collapse "nothing collected" to None rather than returning an empty list.
    return identifiers or None


def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]:
"""
Get the attack scoring configuration used by this strategy.
Expand Down
2 changes: 2 additions & 0 deletions pyrit/executor/attack/multi_turn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TAPAttack,
TAPAttackContext,
TAPAttackResult,
TAPAttackScoringConfig,
TreeOfAttacksWithPruningAttack,
)

Expand All @@ -43,4 +44,5 @@
"TAPAttack",
"TAPAttackResult",
"TAPAttackContext",
"TAPAttackScoringConfig",
]
8 changes: 5 additions & 3 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,20 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
# Determine the outcome
outcome, outcome_reason = self._determine_attack_outcome(score=score)

# Build common metadata for the attack result
metadata = self._get_attack_result_metadata(context=context, request_converters=self._request_converters)

# Create attack result
return AttackResult(
conversation_id=context.session.conversation_id,
objective=context.objective,
attack_identifier=self.get_identifier(),
last_response=response.get_piece() if response else None,
last_score=score,
automated_objective_score=score,
related_conversations=context.related_conversations,
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)},
**metadata,
)

def _determine_attack_outcome(
Expand Down
31 changes: 22 additions & 9 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
SelfAskRefusalScorer,
SelfAskScaleScorer,
)
from pyrit.score.score_utils import normalize_score_to_float

logger = logging.getLogger(__name__)

Expand All @@ -64,10 +65,21 @@ class CrescendoAttackContext(MultiTurnAttackContext[Any]):
backtrack_count: int = 0


@dataclass
class CrescendoAttackResult(AttackResult):
"""Result of the Crescendo attack strategy execution."""

def __init__(self, *, backtrack_count: int = 0, **kwargs: Any) -> None:
"""
Initialize a CrescendoAttackResult.

The backtrack count is not stored as a separate attribute; it is written into
``self.metadata`` under the ``"backtrack_count"`` key after the parent
initializer has run.

Args:
backtrack_count: Number of backtracks performed during the attack. Defaults to 0.
**kwargs: All other arguments, forwarded unchanged to AttackResult.
"""
super().__init__(**kwargs)
# Store in metadata for database serialization
# NOTE(review): this assignment replaces any existing "backtrack_count" entry in
# metadata (e.g. one supplied via kwargs) - confirm that is intended.
self.metadata["backtrack_count"] = backtrack_count

@property
def backtrack_count(self) -> int:
"""
Expand All @@ -84,7 +96,7 @@ def backtrack_count(self, value: int) -> None:
Set the number of backtracks performed during the attack.

Args:
value (int): The number of backtracks to set.
value: The number of backtracks to set.
"""
self.metadata["backtrack_count"] = value

Expand Down Expand Up @@ -387,20 +399,21 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA
else f"Max turns ({self._max_turns}) reached without achieving objective"
)

# Build common metadata for the attack result
metadata = self._get_attack_result_metadata(context=context, request_converters=self._request_converters)

# Prepare the result
result = CrescendoAttackResult(
attack_identifier=self.get_identifier(),
conversation_id=context.session.conversation_id,
objective=context.objective,
outcome=(AttackOutcome.SUCCESS if achieved_objective else AttackOutcome.FAILURE),
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
related_conversations=context.related_conversations, # Use related_conversations here
automated_objective_score=context.last_score,
related_conversations=context.related_conversations,
backtrack_count=context.backtrack_count,
**metadata,
)
# setting metadata for backtrack count
result.backtrack_count = context.backtrack_count
return result

async def _teardown_async(self, *, context: CrescendoAttackContext) -> None:
Expand Down Expand Up @@ -478,7 +491,7 @@ def _build_adversarial_prompt(
last_message_value = context.last_response.get_value()

if last_message_value:
score_value = context.last_score.get_value()
score_value = normalize_score_to_float(context.last_score)
prompt_parts.extend(
[
f"\nThe target responded to the last question with: {last_message_value}",
Expand Down
8 changes: 5 additions & 3 deletions pyrit/executor/attack/multi_turn/multi_prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,18 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
# Determine the outcome
outcome, outcome_reason = self._determine_attack_outcome(response=response, score=score, context=context)

# Build common metadata for the attack result
metadata = self._get_attack_result_metadata(context=context, request_converters=self._request_converters)

result = AttackResult(
conversation_id=context.session.conversation_id,
objective=context.objective,
attack_identifier=self.get_identifier(),
last_response=response.get_piece() if response else None,
last_score=score,
automated_objective_score=score,
related_conversations=context.related_conversations,
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
**metadata,
)

return result
Expand Down
7 changes: 4 additions & 3 deletions pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,16 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
context.executed_turns += 1

# Prepare the result
metadata = self._get_attack_result_metadata(context=context, request_converters=self._request_converters)

return AttackResult(
attack_identifier=self.get_identifier(),
conversation_id=context.session.conversation_id,
objective=context.objective,
outcome=(AttackOutcome.SUCCESS if achieved_objective else AttackOutcome.FAILURE),
executed_turns=context.executed_turns,
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
automated_objective_score=context.last_score,
related_conversations=context.related_conversations,
**metadata,
)

async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None:
Expand Down
Loading