diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 9f10860130..cb4f8af272 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -9,6 +9,7 @@ from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_interface import MemoryInterface @@ -26,4 +27,6 @@ "MemoryExporter", "PromptMemoryEntry", "SeedEntry", + "IdentifierFilter", + "IdentifierType", ] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c9f349c0d9..d9a641ac60 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -12,7 +12,7 @@ from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import TextClause @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQL condition for filtering message pieces by attack ID. - - Uses JSON_VALUE() function specific to SQL Azure to query the attack identifier. - - Args: - attack_id (str): The attack identifier to filter by. - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( - json_id=str(attack_id) - ) - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. @@ -321,6 +305,105 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) """ return self._get_metadata_conditions(prompt_metadata=metadata)[0] + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + case_sensitive: bool = False, + ) -> Any: + """ + Return an Azure SQL DB condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. 
+ """ + uid = self._uid() + table_name = json_column.class_.__tablename__ + column_name = json_column.key + pp_param = f"pp_{uid}" + mv_param = f"mv_{uid}" + json_func = "JSON_VALUE" if case_sensitive else "LOWER(JSON_VALUE)" + operator = "LIKE" if partial_match else "=" + target = value_to_match if case_sensitive else value_to_match.lower() + if partial_match: + escaped = target.replace("%", "\\%").replace("_", "\\_") + target = f"%{escaped}%" + + escape_clause = " ESCAPE '\\'" if partial_match else "" + return text( + f"""ISJSON("{table_name}".{column_name}) = 1 + AND {json_func}("{table_name}".{column_name}, :{pp_param}) {operator} :{mv_param}{escape_clause}""" + ).bindparams( + **{ + pp_param: property_path, + mv_param: target, + } + ) + + def _get_condition_json_array_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + array_element_path: str | None = None, + array_to_match: Sequence[str], + ) -> Any: + """ + Return an Azure SQL DB condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition matches only if the target is also an empty array or None. + + Returns: + Any: A database-specific SQLAlchemy condition. 
+ """ + uid = self._uid() + table_name = json_column.class_.__tablename__ + column_name = json_column.key + pp_param = f"pp_{uid}" + sp_param = f"sp_{uid}" + + if len(array_to_match) == 0: + return text( + f"""("{table_name}".{column_name} IS NULL + OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) IS NULL + OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) = '[]')""" + ).bindparams(**{pp_param: property_path}) + + value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if array_element_path else "LOWER(value)" + + conditions = [] + bindparams_dict: dict[str, str] = {pp_param: property_path} + if array_element_path: + bindparams_dict[sp_param] = array_element_path + + for index, match_value in enumerate(array_to_match): + mv_param = f"mv_{uid}_{index}" + conditions.append( + f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name}, + :{pp_param})) + WHERE {value_expression} = :{mv_param})""" + ) + bindparams_dict[mv_param] = match_value.lower() + + combined = " AND ".join(conditions) + return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) + def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. @@ -388,67 +471,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Azure SQL implementation for filtering AttackResults by attack class. - Uses JSON_VALUE() on the atomic_attack_identifier JSON column. - - Args: - attack_class (str): Exact attack class name to match. - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text( - """ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.class_name') = :attack_class""" - ).bindparams(attack_class=attack_class) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Azure SQL implementation for filtering AttackResults by converter classes. - - Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier - JSON column. - - When converter_classes is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified classes are present - (AND logic, case-insensitive). - - Args: - converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. - - Returns: - Any: SQLAlchemy combined condition with bound parameters. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. 
- return text( - """("AttackResultEntries".atomic_attack_identifier IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') = '[]')""" - ) - - conditions = [] - bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})""" - ) - bindparams_dict[param_name] = cls.lower() - - combined = " AND ".join(conditions) - return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) - def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from @@ -593,44 +615,12 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target endpoint. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - endpoint (str): The endpoint URL substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint""" - ).bindparams(endpoint=f"%{endpoint.lower()}%") - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target model name. 
- - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - model_name (str): The model name substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name""" - ).bindparams(model_name=f"%{model_name.lower()}%") - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. + Args: + message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added. """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py new file mode 100644 index 0000000000..357625bbb6 --- /dev/null +++ b/pyrit/memory/identifier_filters.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from enum import Enum + + +class IdentifierType(Enum): + """Enumeration of supported identifier types for filtering.""" + + ATTACK = "attack" + TARGET = "target" + SCORER = "scorer" + CONVERTER = "converter" + + +@dataclass(frozen=True) +class IdentifierFilter: + """ + Immutable filter definition for matching JSON-backed identifier properties. + + Attributes: + identifier_type: The type of identifier column to filter on. + property_path: The JSON path for the property to match. + array_element_path : An optional JSON path that indicates the property at property_path is an array + and the condition should resolve if the value at array_element_path matches the target + for any element in that array. Cannot be used with partial_match or case_sensitive. + value_to_match: The string value that must match the extracted JSON property value. + partial_match: Whether to perform a substring match. 
Cannot be used with array_element_path or case_sensitive. + case_sensitive: Whether the match should be case-sensitive. + Cannot be used with array_element_path or partial_match. + """ + + identifier_type: IdentifierType + property_path: str + value_to_match: str + array_element_path: str | None = None + partial_match: bool = False + case_sensitive: bool = False + + def __post_init__(self) -> None: + """ + Validate the filter configuration. + + Raises: + ValueError: If the filter configuration is not valid. + """ + if self.array_element_path and (self.partial_match or self.case_sensitive): + raise ValueError("Cannot use array_element_path with partial_match or case_sensitive") + if self.partial_match and self.case_sensitive: + raise ValueError("case_sensitive is not reliably supported with partial_match across all backends") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..b60c5ecaeb 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,6 +19,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -74,6 +75,11 @@ class MemoryInterface(abc.ABC): results_path: str = None engine: Engine = None + @staticmethod + def _uid() -> str: + """Return a short unique suffix for bind-param deduplication.""" + return uuid.uuid4().hex[:8] + def __init__(self, embedding_model: Optional[Any] = None) -> None: """ Initialize the MemoryInterface. 
@@ -113,6 +119,146 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None + def _build_identifier_filter_conditions( + self, + *, + identifier_filters: Sequence[IdentifierFilter], + identifier_column_map: dict[IdentifierType, Any], + caller: str, + ) -> list[Any]: + """ + Build SQLAlchemy conditions from a sequence of IdentifierFilters. + + Args: + identifier_filters (Sequence[IdentifierFilter]): The filters to convert to conditions. + identifier_column_map (dict[IdentifierType, Any]): Mapping from IdentifierType to the + JSON-backed SQLAlchemy column that should be queried for that type. + caller (str): Name of the calling method, used in error messages. + + Returns: + list[Any]: A list of SQLAlchemy conditions. + + Raises: + ValueError: If a filter uses an IdentifierType not in identifier_column_map. + """ + conditions: list[Any] = [] + for identifier_filter in identifier_filters: + column = identifier_column_map.get(identifier_filter.identifier_type) + if column is None: + supported = ", ".join(t.name for t in identifier_column_map) + raise ValueError( + f"{caller} does not support identifier type " + f"{identifier_filter.identifier_type!r}. Supported: {supported}" + ) + conditions.append( + self._get_condition_json_match( + json_column=column, + property_path=identifier_filter.property_path, + array_element_path=identifier_filter.array_element_path, + value_to_match=identifier_filter.value_to_match, + partial_match=identifier_filter.partial_match, + case_sensitive=identifier_filter.case_sensitive, + ) + ) + return conditions + + def _get_condition_json_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + array_element_path: str | None = None, + value_to_match: str, + partial_match: bool = False, + case_sensitive: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object + or within items of a JSON array if array_element_path is provided. 
+ + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + array_element_path (str | None): An optional JSON path that indicates property at property_path is an array + and the condition should resolve if any element in that array matches the value. + Cannot be used with partial_match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + + Raises: + ValueError: If array_element_path is provided together with partial_match or case_sensitive + """ + if array_element_path and (partial_match or case_sensitive): + raise ValueError("Cannot use array_element_path with partial_match or case_sensitive") + if partial_match and case_sensitive: + raise ValueError("case_sensitive is not reliably supported with partial_match across all backends") + if array_element_path: + return self._get_condition_json_array_match( + json_column=json_column, + property_path=property_path, + array_element_path=array_element_path, + array_to_match=[value_to_match], + ) + return self._get_condition_json_property_match( + json_column=json_column, + property_path=property_path, + value_to_match=value_to_match, + partial_match=partial_match, + case_sensitive=case_sensitive, + ) + + @abc.abstractmethod + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + case_sensitive: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. 
+ property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + + @abc.abstractmethod + def _get_condition_json_array_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + array_element_path: str | None = None, + array_to_match: Sequence[str], + ) -> Any: + """ + Return a database-specific condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition matches only if the target is also an empty array or None. + + Returns: + Any: A database-specific SQLAlchemy condition. + """ + @abc.abstractmethod def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ @@ -155,12 +301,6 @@ def _get_message_pieces_prompt_metadata_conditions( list: A list of conditions for filtering memory entries based on prompt metadata. """ - @abc.abstractmethod - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Return a condition to retrieve based on attack ID. - """ - @abc.abstractmethod def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ @@ -289,40 +429,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Database-specific SQLAlchemy condition. 
""" - @abc.abstractmethod - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Return a database-specific condition for filtering AttackResults by attack class - (class_name in the attack_identifier JSON column). - - Args: - attack_class: Exact attack class name to match. - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Return a database-specific condition for filtering AttackResults by converter classes - in the request_converter_identifiers array within attack_identifier JSON column. - - This method is only called when converter filtering is requested (converter_classes - is not None). The caller handles the None-vs-list distinction: - - - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter - class names to be present (AND logic, case-insensitive). - - Args: - converter_classes: Converter class names to require. An empty sequence means - "match only attacks that have no converters". - - Returns: - Database-specific SQLAlchemy condition. - """ - @abc.abstractmethod def get_unique_attack_class_names(self) -> list[str]: """ @@ -377,30 +483,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any Database-specific SQLAlchemy condition. """ - @abc.abstractmethod - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target endpoint. - - Args: - endpoint: Endpoint substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. 
- """ - - @abc.abstractmethod - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target model name. - - Args: - model_name: Model name substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: """ Insert a list of scores into the memory storage. @@ -425,6 +507,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -435,6 +518,8 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. + identifier_filters (Optional[Sequence[IdentifierFilter]]): A sequence of IdentifierFilter objects that + allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: Sequence[Score]: A list of Score objects that match the specified filters. 
@@ -451,6 +536,14 @@ def get_scores( conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) + if identifier_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={IdentifierType.SCORER: ScoreEntry.scorer_class_identifier}, + caller="get_scores", + ) + ) if not conditions: return [] @@ -581,6 +674,7 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -602,6 +696,9 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. + identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + A sequence of IdentifierFilter objects that + allow filtering by various identifier JSON properties. Defaults to None. Returns: Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters. 
@@ -612,7 +709,13 @@ def get_message_pieces( """ conditions = [] if attack_id: - conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path="$.hash", + value_to_match=str(attack_id), + ) + ) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -638,7 +741,18 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) - + if identifier_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={ + IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, + IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, + IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, + }, + caller="get_message_pieces", + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True @@ -1365,6 +1479,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1392,6 +1507,9 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. 
+ identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + A sequence of IdentifierFilter objects that allows filtering by various attack identifier + JSON properties. Defaults to None. Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. @@ -1415,12 +1533,26 @@ def get_attack_results( if attack_class: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack.class_name", + value_to_match=attack_class, + case_sensitive=True, + ) + ) if converter_classes is not None: # converter_classes=[] means "only attacks with no converters" # converter_classes=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_classes_condition(converter_classes=converter_classes)) + conditions.append( + self._get_condition_json_array_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack.children.request_converters", + array_element_path="$.class_name", + array_to_match=converter_classes, + ) + ) if targeted_harm_categories: # Use database-specific JSON query method @@ -1432,6 +1564,15 @@ def get_attack_results( # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) + if identifier_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier}, + caller="get_attack_results", + ) + ) + try: entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*conditions) if conditions else None @@ -1612,6 +1753,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, 
objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1635,6 +1777,9 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. + identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. + Defaults to None. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. @@ -1672,11 +1817,37 @@ def get_scenario_results( if objective_target_endpoint: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path="$.endpoint", + value_to_match=objective_target_endpoint, + partial_match=True, + ) + ) if objective_target_model_name: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path="$.model_name", + value_to_match=objective_target_model_name, + partial_match=True, + ) + ) + + if identifier_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={ + IdentifierType.SCORER: ScenarioResultEntry.objective_scorer_identifier, + IdentifierType.TARGET: ScenarioResultEntry.objective_target_identifier, 
+ }, + caller="get_scenario_results", + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 7bd05b4f82..942c4384d2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -13,7 +13,7 @@ from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker from sqlalchemy.orm.session import Session from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import TextClause @@ -177,15 +177,6 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQLAlchemy filter conditions for filtering by attack ID. - - Returns: - Any: A SQLAlchemy text condition with bound parameters. - """ - return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. @@ -199,6 +190,93 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) # Note: We do NOT convert values to string here, to allow integer comparison in JSON return text(json_conditions).bindparams(**dict(metadata.items())) + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + case_sensitive: bool = False, + ) -> Any: + """ + Return a SQLite DB condition for matching a value at a given path within a JSON object. 
+
+        Args:
+            json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query.
+            property_path (str): The JSON path for the property to match.
+            value_to_match (str): The string value that must match the extracted JSON property value.
+            partial_match (bool): Whether to perform a substring match.
+            case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False.
+
+        Returns:
+            Any: A SQLAlchemy condition for the backend-specific JSON query.
+        """
+        raw = func.json_extract(json_column, property_path)
+        if case_sensitive:
+            extracted_value, target = raw, value_to_match
+        else:
+            extracted_value, target = func.lower(raw), value_to_match.lower()
+
+        if partial_match:
+            # Escape the escape character itself before the LIKE wildcards so a
+            # literal backslash in the search value cannot act as an escape.
+            escaped = target.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+            return extracted_value.like(f"%{escaped}%", escape="\\")
+        return extracted_value == target
+
+    def _get_condition_json_array_match(
+        self,
+        *,
+        json_column: InstrumentedAttribute[Any],
+        property_path: str,
+        array_element_path: str | None = None,
+        array_to_match: Sequence[str],
+    ) -> Any:
+        """
+        Return a SQLite DB condition for matching an array at a given path within a JSON object.
+
+        Args:
+            json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query.
+            property_path (str): The JSON path for the target array.
+            array_element_path (Optional[str]): An optional JSON path applied to each array item before matching.
+            array_to_match (Sequence[str]): The array that must match the extracted JSON array values.
+                For a match, ALL values in this array must be present in the JSON array.
+                If `array_to_match` is empty, the condition matches only if the target is also an empty array or None.
+
+        Returns:
+            Any: A database-specific SQLAlchemy condition.
+ """ + array_expr = func.json_extract(json_column, property_path) + if len(array_to_match) == 0: + return or_( + json_column.is_(None), + array_expr.is_(None), + array_expr == "[]", + ) + + uid = self._uid() + table_name = json_column.class_.__tablename__ + column_name = json_column.key + pp_param = f"property_path_{uid}" + sp_param = f"array_element_path_{uid}" + value_expression = f"LOWER(json_extract(value, :{sp_param}))" if array_element_path else "LOWER(value)" + + conditions = [] + bindparams_dict: dict[str, str] = {pp_param: property_path} + if array_element_path: + bindparams_dict[sp_param] = array_element_path + + for index, match_value in enumerate(array_to_match): + mv_param = f"mv_{uid}_{index}" + conditions.append( + f"""EXISTS(SELECT 1 FROM json_each( + json_extract("{table_name}".{column_name}, :{pp_param})) + WHERE {value_expression} = :{mv_param})""" + ) + bindparams_dict[mv_param] = match_value.lower() + + combined = " AND ".join(conditions) + return text(combined).bindparams(**bindparams_dict) + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. @@ -526,59 +604,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery # noqa: RET504 - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - SQLite implementation for filtering AttackResults by attack class. - Uses json_extract() on the atomic_attack_identifier JSON column. - - Returns: - Any: A SQLAlchemy condition for filtering by attack class. - """ - return ( - func.json_extract(AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name") - == attack_class - ) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by converter classes. 
- - Uses json_extract() on the atomic_attack_identifier JSON column. - - When converter_classes is empty, matches attacks with no converters - (children.attack.children.request_converters is absent or null in the JSON). - When non-empty, uses json_each() to check all specified classes are present - (AND logic, case-insensitive). - - Returns: - Any: A SQLAlchemy condition for filtering by converter classes. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. - converter_json = func.json_extract( - AttackResultEntry.atomic_attack_identifier, - "$.children.attack.children.request_converters", - ) - return or_( - AttackResultEntry.atomic_attack_identifier.is_(None), - converter_json.is_(None), - converter_json == "[]", - ) - - conditions = [] - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - text( - f"""EXISTS(SELECT 1 FROM json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(json_extract(value, '$.class_name')) = :{param_name})""" - ).bindparams(**{param_name: cls.lower()}) - ) - return and_(*conditions) - def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from @@ -710,27 +735,3 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any return and_( *[func.json_extract(ScenarioResultEntry.labels, f"$.{key}") == value for key, value in labels.items()] ) - - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target endpoint. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target endpoint. 
- """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.endpoint")).like( - f"%{endpoint.lower()}%" - ) - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target model name. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target model name. - """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.model_name")).like( - f"%{model_name.lower()}%" - ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 91367c3a1c..03600e3260 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1352,3 +1353,64 @@ def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: M result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter"] + + +def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with hash.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + # Filter by hash of ar1's attack identifier + results = sqlite_instance.get_attack_results( + 
identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value_to_match=ar1.atomic_attack_identifier.hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with class_name.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + # Filter by partial attack class name + results = sqlite_instance.get_attack_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.children.attack.class_name", + value_to_match="Crescendo", + partial_match=True, + ) + ], + ) + assert len(results) == 2 + assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} + + +def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that AttackIdentifierFilter returns empty when nothing matches.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ) + ], + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f87..e85ce02739 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ 
b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( Message, MessagePiece, @@ -1248,3 +1249,205 @@ def test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit with pytest.raises(ValueError, match="The provided request does not have a preceding request \\(sequence < 1\\)."): sqlite_instance.get_request_from_response(response=response_without_request) + + +def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello 1", + attack_identifier=attack1.get_identifier(), + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="assistant", + original_value="Hello 2", + attack_identifier=attack2.get_identifier(), + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by exact attack hash + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value_to_match=attack1.get_identifier().hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].original_value == "Hello 1" + + # No match + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ) + ], + ) + assert len(results) == 0 + + +def 
test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): + target_id_1 = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="AzureChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello OpenAI", + prompt_target_identifier=target_id_1, + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello Azure", + prompt_target_identifier=target_id_2, + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by target hash + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + value_to_match=target_id_1.hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # Filter by endpoint partial match + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.endpoint", + value_to_match="openai", + partial_match=True, + ) + ], + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # No match + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + value_to_match="nonexistent", + partial_match=False, + ) + ], + ) + assert len(results) == 0 + + +def test_get_message_pieces_by_converter_identifier_filter_with_array_element_path(sqlite_instance: MemoryInterface): + converter_a = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter", + ) + 
converter_b = ComponentIdentifier( + class_name="ROT13Converter", + class_module="pyrit.prompt_converter", + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="With Base64", + converter_identifiers=[converter_a], + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="With both converters", + converter_identifiers=[converter_a, converter_b], + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="No converters", + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by converter class_name using array_element_path (array element matching) + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + array_element_path="$.class_name", + value_to_match="Base64Converter", + ) + ], + ) + assert len(results) == 2 + original_values = {r.original_value for r in results} + assert original_values == {"With Base64", "With both converters"} + + # Filter by ROT13Converter — only the entry with both converters + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + array_element_path="$.class_name", + value_to_match="ROT13Converter", + ) + ], + ) + assert len(results) == 1 + assert results[0].original_value == "With both converters" + + # No match + results = sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + array_element_path="$.class_name", + value_to_match="NonexistentConverter", + ) + ], + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index e513e8b873..3696705c5a 100644 --- 
a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( AttackOutcome, AttackResult, @@ -645,3 +646,125 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] + + +def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering scenario results by identifier filter.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by 
target hash + results = sqlite_instance.get_scenario_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + value_to_match=target_id_1.hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): + """Test filtering scenario results by identifier filter with endpoint.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by endpoint partial match + results = sqlite_instance.get_scenario_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.endpoint", + value_to_match="openai", + partial_match=True, + ) + ], + ) + assert len(results) == 1 + assert 
results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that TargetIdentifierFilter returns empty when nothing matches.""" + attack_result1 = create_attack_result("conv_1", "Objective 1") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Test Scenario", scenario_version=1), + objective_target_identifier=ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com"}, + ), + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) + + results = sqlite_instance.get_scenario_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ) + ], + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af1418..1b2f79f47b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,6 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( MessagePiece, Score, @@ -227,3 +228,89 @@ async def test_get_seeds_no_filters(sqlite_instance: MemoryInterface): assert len(result) == 2 assert result[0].value == "prompt1" assert result[1].value == "prompt2" + + +def test_get_scores_by_scorer_identifier_filter( + 
sqlite_instance: MemoryInterface, + sample_conversation_entries: Sequence[PromptMemoryEntry], +): + prompt_id = sample_conversation_entries[0].id + sqlite_instance._insert_entries(entries=sample_conversation_entries) + + score_a = Score( + score_value="0.9", + score_value_description="High", + score_type="float_scale", + score_category=["cat_a"], + score_rationale="Rationale A", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerAlpha"), + message_piece_id=prompt_id, + ) + score_b = Score( + score_value="0.1", + score_value_description="Low", + score_type="float_scale", + score_category=["cat_b"], + score_rationale="Rationale B", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerBeta"), + message_piece_id=prompt_id, + ) + + sqlite_instance.add_scores_to_memory(scores=[score_a, score_b]) + + # Filter by exact class_name match + results = sqlite_instance.get_scores( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + value_to_match="ScorerAlpha", + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # Filter by partial class_name match + results = sqlite_instance.get_scores( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + value_to_match="Scorer", + partial_match=True, + ) + ], + ) + assert len(results) == 2 + + # Filter by hash + scorer_hash = score_a.scorer_class_identifier.hash + results = sqlite_instance.get_scores( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.hash", + value_to_match=scorer_hash, + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # No match + results = sqlite_instance.get_scores( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + 
value_to_match="NonExistent", + partial_match=False, + ) + ], + ) + assert len(results) == 0 diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 5723800396..e0d488a61f 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -326,6 +326,33 @@ def test_update_labels_by_conversation_id(memory_interface: AzureSQLMemory): assert updated_entry.labels["test1"] == "change" +@pytest.mark.parametrize( + "partial_match, expected_value", + [ + (False, "testvalue"), + (True, "%testvalue%"), + ], + ids=["exact_match", "partial_match"], +) +def test_get_condition_json_property_match_bind_params( + memory_interface: AzureSQLMemory, partial_match: bool, expected_value: str +): + condition = memory_interface._get_condition_json_property_match( + json_column=PromptMemoryEntry.labels, + property_path="$.key", + value_to_match="TestValue", + partial_match=partial_match, + ) + # Extract the compiled bind parameters (param names include a random uid suffix) + params = condition.compile().params + pp_params = {k: v for k, v in params.items() if k.startswith("pp_")} + mv_params = {k: v for k, v in params.items() if k.startswith("mv_")} + assert len(pp_params) == 1 + assert list(pp_params.values())[0] == "$.key" + assert len(mv_params) == 1 + assert list(mv_params.values())[0] == expected_value + + def test_update_prompt_metadata_by_conversation_id(memory_interface: AzureSQLMemory): # Insert a test entry entry = PromptMemoryEntry( diff --git a/tests/unit/memory/test_identifier_filters.py b/tests/unit/memory/test_identifier_filters.py new file mode 100644 index 0000000000..8316ef08ba --- /dev/null +++ b/tests/unit/memory/test_identifier_filters.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import pytest + +from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType +from pyrit.memory.memory_models import AttackResultEntry + + +@pytest.mark.parametrize( + "array_element_path, partial_match, case_sensitive", + [ + ("$.class_name", True, False), + ("$.class_name", False, True), + ("$.class_name", True, True), + ], + ids=["array_element_path+partial_match", "array_element_path+case_sensitive", "array_element_path+both"], +) +def test_identifier_filter_array_element_path_with_partial_or_case_sensitive_raises( + array_element_path: str, partial_match: bool, case_sensitive: bool +): + with pytest.raises(ValueError, match="Cannot use array_element_path with partial_match or case_sensitive"): + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.children", + value_to_match="test", + array_element_path=array_element_path, + partial_match=partial_match, + case_sensitive=case_sensitive, + ) + + +def test_identifier_filter_valid_with_array_element_path(): + f = IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + value_to_match="Base64Converter", + array_element_path="$.class_name", + ) + assert f.array_element_path == "$.class_name" + assert not f.partial_match + assert not f.case_sensitive + + +def test_build_identifier_filter_conditions_unsupported_type_raises(sqlite_instance: MemoryInterface): + filters = [ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + value_to_match="MyScorer", + ) + ] + with pytest.raises(ValueError, match="does not support identifier type"): + sqlite_instance._build_identifier_filter_conditions( + identifier_filters=filters, + identifier_column_map={IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier}, + caller="test_caller", + )