From b63361c44359a11a66daabf5e7f202b926817cb4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 15 Apr 2026 16:55:10 -0700 Subject: [PATCH 1/6] Add labels to attack results --- pyrit/backend/mappers/attack_mappers.py | 5 +- pyrit/backend/services/attack_service.py | 1 + .../attack/multi_turn/chunked_request.py | 1 + pyrit/executor/attack/multi_turn/crescendo.py | 1 + .../attack/multi_turn/multi_prompt_sending.py | 1 + .../executor/attack/multi_turn/red_teaming.py | 1 + .../attack/multi_turn/tree_of_attacks.py | 1 + .../attack/single_turn/prompt_sending.py | 1 + .../attack/single_turn/skeleton_key.py | 1 + pyrit/executor/benchmark/fairness_bias.py | 1 + pyrit/memory/azure_sql_memory.py | 28 +++-- pyrit/memory/memory_interface.py | 8 +- pyrit/memory/memory_models.py | 4 + pyrit/memory/sqlite_memory.py | 21 +--- pyrit/models/attack_result.py | 3 + tests/unit/backend/test_attack_service.py | 5 +- tests/unit/backend/test_mappers.py | 22 +++- .../test_interface_attack_results.py | 117 ++++++++---------- tests/unit/scenario/test_scenario.py | 1 + 19 files changed, 122 insertions(+), 101 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 0245e2af12..c37dd77fd9 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -197,8 +197,11 @@ def attack_result_to_summary( """ message_count = stats.message_count last_preview = stats.last_message_preview - labels = dict(stats.labels) if stats.labels else {} + # Merge attack-result labels with conversation-level labels. + # Conversation labels take precedence on key collision. 
+ labels = dict(ar.labels) if ar.labels else {} + labels.update(stats.labels or {}) created_str = ar.metadata.get("created_at") updated_str = ar.metadata.get("updated_at") created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 8852071a66..b7c3635e88 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -308,6 +308,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels=labels, ) # Store in memory diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..ed95c5d226 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -325,6 +325,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac outcome_reason=outcome_reason, executed_turns=context.executed_turns, metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)}, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..f137b322f3 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -402,6 +402,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # setting metadata for backtrack count result.backtrack_count = context.backtrack_count diff --git 
a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..8447737578 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -295,6 +295,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac outcome=outcome, outcome_reason=outcome_reason, executed_turns=context.executed_turns, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index a8778f664a..1feec20586 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -322,6 +322,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e92bd1cf67..f6ccc4ed64 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -2082,6 +2082,7 @@ def _create_attack_result( last_response=last_response, last_score=context.best_objective_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # Set attack-specific metadata using properties diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 07f1d670fa..cdb2d4b619 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -238,6 +238,7 @@ async def 
_perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta outcome=outcome, outcome_reason=outcome_reason, executed_turns=1, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 683614dce5..40cc5cc302 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -181,4 +181,5 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex outcome=AttackOutcome.FAILURE, outcome_reason="Skeleton key prompt was filtered or failed", executed_turns=1, + labels=context.memory_labels, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 05bb424c17..63d33f4639 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -200,6 +200,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta atomic_attack_identifier=build_atomic_attack_identifier( attack_identifier=ComponentIdentifier.of(self), ), + labels=context.memory_labels, ) return last_attack_result diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 207a7f7f98..15586f7152 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -446,7 +446,8 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels. + Get the SQL Azure implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. 
@@ -454,24 +455,27 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: labels (dict[str, str]): Dictionary of label key-value pairs to filter by. Returns: - Any: SQLAlchemy exists subquery condition with bound parameters. + Any: SQLAlchemy condition with bound parameters. """ # Build JSON conditions for all labels with parameterized queries label_conditions = [] bindparams_dict = {} - for key, value in labels.items(): - param_name = f"label_{key}" - label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") - bindparams_dict[param_name] = str(value) + for i, (key, value) in enumerate(labels.items()): + path_param = f"label_path_{i}" + value_param = f"label_val_{i}" + label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) + bindparams_dict[path_param] = f"$.{key}" + bindparams_dict[value_param] = str(value) combined_conditions = " AND ".join(label_conditions) - return exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), - ) + return and_( + AttackResultEntry.labels.isnot(None), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' + ).bindparams(**bindparams_dict), ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 023045a5c3..19c5939cbf 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -420,7 +420,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - in the associated PromptMemoryEntry records. + stored directly on the AttackResultEntry. 
Args: labels: Dictionary of labels that must ALL be present. @@ -1505,9 +1505,9 @@ def get_attack_results( not necessarily one(s) that were found in the response. By providing a list, this means ALL categories in the list must be present. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. - These labels are associated with the prompts themselves, used for custom tagging and tracking. - Defaults to None. + labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. + These labels are stored directly on the AttackResult. All specified key-value pairs + must be present (AND logic). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 9376768bd4..511005174d 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -720,6 +720,7 @@ class AttackResultEntry(Base): outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined. outcome_reason (str): Optional reason for the outcome, providing additional context. attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context. + labels (dict[str, str]): Optional labels associated with the attack result entry. pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack. adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat. timestamp (DateTime): The timestamp of the attack result entry. 
@@ -751,6 +752,7 @@ class AttackResultEntry(Base): ) outcome_reason = mapped_column(String, nullable=True) attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) + labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) @@ -806,6 +808,7 @@ def __init__(self, *, entry: AttackResult): self.outcome = entry.outcome.value self.outcome_reason = entry.outcome_reason self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata) + self.labels = entry.labels or {} # Persist conversation references by type self.pruned_conversation_ids = [ @@ -917,6 +920,7 @@ def get_attack_result(self) -> AttackResult: outcome_reason=self.outcome_reason, related_conversations=related_conversations, metadata=self.attack_metadata or {}, + labels=self.labels or {}, ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bd376d67cd..3c4bb287fd 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -613,26 +613,17 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels. + SQLite implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. Uses json_extract() function specific to SQLite. Returns: - Any: A SQLAlchemy subquery for filtering by labels. + Any: A SQLAlchemy condition for filtering by labels. 
""" - from sqlalchemy import and_, exists, func - - from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - - labels_subquery = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - and_( - *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()] - ), - ) + return and_( + AttackResultEntry.labels.isnot(None), + *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) - return labels_subquery # noqa: RET504 def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index a385ac36e7..5cbdf3c93e 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -87,6 +87,9 @@ class AttackResult(StrategyResult): # Arbitrary metadata metadata: dict[str, Any] = field(default_factory=dict) + # labels associated with this attack result + labels: dict[str, str] = field(default_factory=dict) + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 11da01effe..83dd47bbb2 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,6 +95,9 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, + labels={ + "test_ar_label": "test_ar_value" + }, ) @@ -320,7 +323,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc result = await attack_service.list_attacks_async() assert len(result.items) == 1 - assert result.items[0].labels == {"env": "prod", "team": "red"} + assert result.items[0].labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} @pytest.mark.asyncio async def test_list_attacks_filters_by_labels_directly(self, 
attack_service, mock_memory) -> None: diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 0f483b3f10..f7c2495c71 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,6 +81,9 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels={ + "test_ar_label": "test_ar_value" + }, ) @@ -175,7 +178,7 @@ def test_labels_are_mapped(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"env": "prod", "team": "red"} + assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} def test_labels_passed_through_without_normalization(self) -> None: """Test that labels are passed through as-is (DB stores canonical keys after migration).""" @@ -187,7 +190,21 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"operator": "alice", "operation": "op_red", "env": "prod"} + assert summary.labels == { + "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" + } + + def test_conversation_labels_take_precedence_on_collision(self) -> None: + """Test that conversation-level labels override attack-result labels on key collision.""" + ar = _make_attack_result() + stats = ConversationStats( + message_count=1, + labels={"test_ar_label": "conversation_wins"}, + ) + + summary = attack_result_to_summary(ar, stats=stats) + + assert summary.labels["test_ar_label"] == "conversation_wins" def test_outcome_success(self) -> None: """Test that success outcome is mapped.""" @@ -249,6 +266,7 @@ def test_converters_extracted_from_identifier(self) -> None: ), outcome=AttackOutcome.UNDETERMINED, metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, + labels={"test_label": "test_value"}, ) summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) 
diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 2e30ba368a..d998e7c02a 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -36,12 +36,18 @@ def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_ca ) -def create_attack_result(conversation_id: str, objective_num: int, outcome: AttackOutcome = AttackOutcome.SUCCESS): +def create_attack_result( + conversation_id: str, + objective_num: int, + outcome: AttackOutcome = AttackOutcome.SUCCESS, + labels: dict[str, str] | None = None, +): """Helper function to create AttackResult.""" return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", outcome=outcome, + labels=labels or {}, ) @@ -780,17 +786,14 @@ def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInt def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" - # Create message pieces with labels - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "test_op", "operator": "roakey"}) - message_piece2 = create_message_piece("conv_2", 2, labels={"operation": "test_op"}) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "other_op", "operator": "roakey"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) + # Create attack results with labels + attack_result1 = create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", 
"operator": "roakey"} + ) + attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE, labels={"operation": "test_op"}) + attack_result3 = create_attack_result( + "conv_3", 3, AttackOutcome.SUCCESS, labels={"operation": "other_op", "operator": "roakey"} + ) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) @@ -808,22 +811,20 @@ def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface): """Test filtering attack results by multiple labels (AND logic).""" - # Create message pieces with multiple labels using helper function - message_piece1 = create_message_piece( - "conv_1", 1, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"} - ) - message_piece2 = create_message_piece( - "conv_2", 2, labels={"operation": "test_op", "operator": "roakey", "phase": "final"} - ) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "test_op", "phase": "initial"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results + # Create attack results with multiple labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, + ), + create_attack_result( + "conv_3", 3, AttackOutcome.FAILURE, + labels={"operation": "test_op", "phase": "initial"}, + ), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -842,30 +843,24 @@ def 
test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): """Test filtering attack results by both harm categories and labels.""" - # Create message pieces with both harm categories and labels using helper function - message_piece1 = create_message_piece( - "conv_1", - 1, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "test_op", "operator": "roakey"}, - ) - message_piece2 = create_message_piece( - "conv_2", 2, targeted_harm_categories=["violence"], labels={"operation": "test_op", "operator": "roakey"} - ) - message_piece3 = create_message_piece( - "conv_3", - 3, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "other_op", "operator": "bob"}, - ) + # Create message pieces with harm categories (harm categories still live on PromptMemoryEntry) + message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) + message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence"]) + message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence", "illegal"]) sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - # Create attack results + # Create attack results with labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} + ), ] 
sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -904,11 +899,8 @@ def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInte def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" - # Create attack result without the labels we'll search for - message_piece = create_message_piece("conv_1", 1, labels={"operation": "test_op"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + # Create attack result with labels that don't match the search + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op"}) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) # Search for non-existent labels @@ -920,11 +912,6 @@ def test_get_attack_results_labels_query_on_empty_labels(sqlite_instance: Memory """Test querying for labels when records have no labels at all""" # Create attack results with NO labels - message_piece1 = create_message_piece("conv_1", 1) - message_piece2 = create_message_piece("conv_2", 1) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2]) - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) @@ -944,16 +931,14 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me """Test querying for labels where the key exists but the value doesn't match.""" # Create attack results with specific label values - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "op_exists", "researcher": "roakey"}) - message_piece2 = create_message_piece("conv_2", 1, labels={"operation": "another_op", "researcher": "roakey"}) - message_piece3 = create_message_piece("conv_3", 1, labels={"operation": "test_op"}) 
- - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "op_exists", "researcher": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "another_op", "researcher": "roakey"} + ), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "test_op"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 7f02982015..480adf0543 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -85,6 +85,7 @@ def sample_attack_results(): objective=f"objective{i}", outcome=AttackOutcome.SUCCESS, executed_turns=1, + labels={"test_label": f"value{i}"}, ) for i in range(5) ] From 4ec494a67867d6bccf1da5eb4d21ccc98d5e92db Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:37:10 -0700 Subject: [PATCH 2/6] format --- pyrit/memory/azure_sql_memory.py | 8 ++----- tests/unit/backend/test_attack_service.py | 4 +--- tests/unit/backend/test_mappers.py | 9 +++---- .../test_interface_attack_results.py | 24 +++++++++---------- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 15586f7152..64ff020416 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -463,9 +463,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"label_path_{i}" value_param = f"label_val_{i}" - label_conditions.append( - 
f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) + label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') bindparams_dict[path_param] = f"$.{key}" bindparams_dict[value_param] = str(value) @@ -473,9 +471,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: return and_( AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' - ).bindparams(**bindparams_dict), + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e62ac73d63..c494e29944 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,9 +95,7 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, + labels={"test_ar_label": "test_ar_value"}, ) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index f7c2495c71..ad6ef1a380 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,9 +81,7 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, + labels={"test_ar_label": "test_ar_value"}, ) @@ -191,7 +189,10 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) assert summary.labels == { - "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" + "operator": "alice", + "operation": "op_red", + "env": "prod", + "test_ar_label": "test_ar_value", } def test_conversation_labels_take_precedence_on_collision(self) -> None: diff --git 
a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index d998e7c02a..425b1f1360 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -814,15 +814,21 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) # Create attack results with multiple labels attack_results = [ create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, + "conv_1", + 1, + AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, ), create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, + "conv_2", + 2, + AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, ), create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, + "conv_3", + 3, + AttackOutcome.FAILURE, labels={"operation": "test_op", "phase": "initial"}, ), ] @@ -852,15 +858,9 @@ def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryI # Create attack results with labels attack_results = [ - create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} - ), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) From 5e04ca7e7f515ca8dca695d93b03f8f4a481aef1 Mon 
Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:44:45 -0700 Subject: [PATCH 3/6] label queries OR'ed with old way --- pyrit/memory/azure_sql_memory.py | 56 +++++++++++++++++++++++--------- pyrit/memory/memory_interface.py | 2 +- pyrit/memory/sqlite_memory.py | 21 +++++++++--- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 64ff020416..3023bf8240 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, event, exists, text +from sqlalchemy import and_, create_engine, event, exists, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -446,8 +446,10 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. + Get the SQL Azure implementation for filtering AttackResults by labels. + + Matches if the labels are found on the AttackResultEntry directly + OR on an associated PromptMemoryEntry (via conversation_id). Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -457,23 +459,47 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Returns: Any: SQLAlchemy condition with bound parameters. 
""" - # Build JSON conditions for all labels with parameterized queries - label_conditions = [] - bindparams_dict = {} + # --- Direct match on AttackResultEntry.labels --- + ar_label_conditions = [] + ar_bindparams: dict[str, str] = {} for i, (key, value) in enumerate(labels.items()): - path_param = f"label_path_{i}" - value_param = f"label_val_{i}" - label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') - bindparams_dict[path_param] = f"$.{key}" - bindparams_dict[value_param] = str(value) - - combined_conditions = " AND ".join(label_conditions) + path_param = f"ar_label_path_{i}" + value_param = f"ar_label_val_{i}" + ar_label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) + ar_bindparams[path_param] = f"$.{key}" + ar_bindparams[value_param] = str(value) - return and_( + ar_combined = " AND ".join(ar_label_conditions) + direct_condition = and_( AttackResultEntry.labels.isnot(None), - text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' + ).bindparams(**ar_bindparams), ) + # --- Conversation-level match on PromptMemoryEntry.labels --- + pme_label_conditions = [] + pme_bindparams: dict[str, str] = {} + for i, (key, value) in enumerate(labels.items()): + param_name = f"pme_label_{key}" + pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") + pme_bindparams[param_name] = str(value) + + pme_combined = " AND ".join(pme_label_conditions) + conversation_condition = exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + text( + f"ISJSON(labels) = 1 AND {pme_combined}" + ).bindparams(**pme_bindparams), + ) + ) + + return or_(direct_condition, conversation_condition) + def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL 
implementation: extract unique class_name values from diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5cd9797cdc..c90c978777 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -419,7 +419,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - stored directly on the AttackResultEntry. + stored directly on the AttackResultEntry OR on an associated PromptMemoryEntry (via conversation_id). Args: labels: Dictionary of labels that must ALL be present. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3c4bb287fd..ac8b2319eb 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, func, or_, text +from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -613,18 +613,29 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. - Uses json_extract() function specific to SQLite. + SQLite implementation for filtering AttackResults by labels. + + Matches if the labels are found on the AttackResultEntry directly + OR on an associated PromptMemoryEntry (via conversation_id). Returns: Any: A SQLAlchemy condition for filtering by labels. 
""" - return and_( + direct_condition = and_( AttackResultEntry.labels.isnot(None), *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) + conversation_condition = exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()], + ) + ) + + return or_(direct_condition, conversation_condition) + def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from From fa5a0d76bab4c582a5e4f86ae8f727d08b3400fd Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:57:02 -0700 Subject: [PATCH 4/6] review feedback --- pyrit/memory/memory_interface.py | 4 ++-- .../test_interface_attack_results.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c90c978777..6e7e1a1b15 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1481,8 +1481,8 @@ def get_attack_results( By providing a list, this means ALL categories in the list must be present. Defaults to None. labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. - These labels are stored directly on the AttackResult. All specified key-value pairs - must be present (AND logic). Defaults to None. + These labels are stored on the AttackResult or on an associated PromptMemoryEntry + (via conversation_id). All specified key-value pairs must be present (AND logic). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None.
diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 425b1f1360..a4d87fd6d3 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -981,6 +981,27 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me assert results[0].conversation_id == "conv_1" +def test_get_attack_results_by_labels_falls_back_to_conversation_labels(sqlite_instance: MemoryInterface): + """Test that label filtering matches via PromptMemoryEntry when AttackResult has no labels.""" + + # Attack result with NO labels + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={}) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Conversation message carries the labels instead + message_piece = create_message_piece("conv_1", 1, labels={"operation": "legacy_op"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) + + # Should still find the attack result via the PME fallback path + results = sqlite_instance.get_attack_results(labels={"operation": "legacy_op"}) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + # Non-matching label should return nothing + results = sqlite_instance.get_attack_results(labels={"operation": "missing"}) + assert len(results) == 0 + + # --------------------------------------------------------------------------- # get_unique_attack_labels tests # --------------------------------------------------------------------------- From aafefc1ad8c222f9cb451a497e3f11fe3bb2381d Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 10:02:58 -0700 Subject: [PATCH 5/6] format --- pyrit/memory/azure_sql_memory.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py 
b/pyrit/memory/azure_sql_memory.py index 3023bf8240..e7d3097615 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -465,24 +465,20 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"ar_label_path_{i}" value_param = f"ar_label_val_{i}" - ar_label_conditions.append( - f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) + ar_label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') ar_bindparams[path_param] = f"$.{key}" ar_bindparams[value_param] = str(value) ar_combined = " AND ".join(ar_label_conditions) direct_condition = and_( AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' - ).bindparams(**ar_bindparams), + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams), ) # --- Conversation-level match on PromptMemoryEntry.labels --- pme_label_conditions = [] pme_bindparams: dict[str, str] = {} - for i, (key, value) in enumerate(labels.items()): + for key, value in labels.items(): param_name = f"pme_label_{key}" pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") pme_bindparams[param_name] = str(value) @@ -492,9 +488,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), - text( - f"ISJSON(labels) = 1 AND {pme_combined}" - ).bindparams(**pme_bindparams), + text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams), ) ) From fca41cb3c2a8ac6ca7d6b118708e184fee2b4d6b Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 10:24:05 -0700 Subject: [PATCH 6/6] cc --- tests/unit/memory/test_azure_sql_memory.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git 
a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index acf4420604..75cef6997b 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -226,6 +226,25 @@ def test_get_memories_with_attack_id(memory_interface: AzureSQLMemory): pytest.skip("Test requires Azure SQL-specific JSON functions; covered by integration tests") +def test_get_attack_result_label_condition_single_label(memory_interface: AzureSQLMemory): + """Test that _get_attack_result_label_condition builds a valid condition for a single label.""" + condition = memory_interface._get_attack_result_label_condition(labels={"operation": "test_op"}) + compiled = str(condition.compile(compile_kwargs={"literal_binds": False})) + assert "JSON_VALUE" in compiled + assert "ISJSON" in compiled + + +def test_get_attack_result_label_condition_multiple_labels(memory_interface: AzureSQLMemory): + """Test that _get_attack_result_label_condition builds a valid condition for multiple labels.""" + condition = memory_interface._get_attack_result_label_condition( + labels={"operation": "test_op", "operator": "roakey"} + ) + compiled = str(condition.compile(compile_kwargs={"literal_binds": False})) + # Both AR-direct and PME-conversation branches should appear + assert "AttackResultEntries" in compiled + assert "PromptMemoryEntries" in compiled + + def test_update_entries(memory_interface: AzureSQLMemory): # Insert a test entry entry = PromptMemoryEntry(