Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,11 @@ def attack_result_to_summary(
"""
message_count = stats.message_count
last_preview = stats.last_message_preview
labels = dict(stats.labels) if stats.labels else {}

# Merge attack-result labels with conversation-level labels.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confused about this comment. Aren't attack-result labels conversation-level?

# Conversation labels take precedence on key collision.
labels = dict(ar.labels) if ar.labels else {}
labels.update(stats.labels or {})
created_str = ar.metadata.get("created_at")
updated_str = ar.metadata.get("updated_at")
created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc)
Expand Down
1 change: 1 addition & 0 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
},
labels=labels,
)

# Store in memory
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)},
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)
# setting metadata for backtrack count
result.backtrack_count = context.backtrack_count
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/multi_prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=context.executed_turns,
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac
last_response=context.last_response.get_piece() if context.last_response else None,
last_score=context.last_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)

async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None:
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,7 @@ def _create_attack_result(
last_response=last_response,
last_score=context.best_objective_score,
related_conversations=context.related_conversations,
labels=context.memory_labels,
)

# Set attack-specific metadata using properties
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/single_turn/prompt_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta
outcome=outcome,
outcome_reason=outcome_reason,
executed_turns=1,
labels=context.memory_labels,
)

def _determine_attack_outcome(
Expand Down
1 change: 1 addition & 0 deletions pyrit/executor/attack/single_turn/skeleton_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,5 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex
outcome=AttackOutcome.FAILURE,
outcome_reason="Skeleton key prompt was filtered or failed",
executed_turns=1,
labels=context.memory_labels,
)
1 change: 1 addition & 0 deletions pyrit/executor/benchmark/fairness_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta
atomic_attack_identifier=build_atomic_attack_identifier(
attack_identifier=ComponentIdentifier.of(self),
),
labels=context.memory_labels,
)

return last_attack_result
Expand Down
46 changes: 33 additions & 13 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

from sqlalchemy import and_, create_engine, event, exists, text
from sqlalchemy import and_, create_engine, event, exists, or_, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker
Expand Down Expand Up @@ -448,32 +448,52 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
Get the SQL Azure implementation for filtering AttackResults by labels.

Matches if the labels are found on the AttackResultEntry directly
OR on an associated PromptMemoryEntry (via conversation_id).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of PromptMemoryEntry labels?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd rather have the route to only attack result labels to simplify things. And a data migration path for the databases

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am putting that in a separate PR (deprecate labels on message piece and its memory entries) and then, this OR will be removed too.

Maybe I'm being over-cautious? should I do the flip in one go?


Uses JSON_VALUE() function specific to SQL Azure with parameterized queries.

Args:
labels (dict[str, str]): Dictionary of label key-value pairs to filter by.

Returns:
Any: SQLAlchemy exists subquery condition with bound parameters.
"""
# Build JSON conditions for all labels with parameterized queries
label_conditions = []
bindparams_dict = {}
for key, value in labels.items():
param_name = f"label_{key}"
label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}")
bindparams_dict[param_name] = str(value)
Any: SQLAlchemy condition with bound parameters.
"""
# --- Direct match on AttackResultEntry.labels ---
ar_label_conditions = []
ar_bindparams: dict[str, str] = {}
for i, (key, value) in enumerate(labels.items()):
path_param = f"ar_label_path_{i}"
value_param = f"ar_label_val_{i}"
ar_label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}')
ar_bindparams[path_param] = f"$.{key}"
ar_bindparams[value_param] = str(value)

ar_combined = " AND ".join(ar_label_conditions)
direct_condition = and_(
AttackResultEntry.labels.isnot(None),
text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams),
)

combined_conditions = " AND ".join(label_conditions)
# --- Conversation-level match on PromptMemoryEntry.labels ---
pme_label_conditions = []
pme_bindparams: dict[str, str] = {}
for key, value in labels.items():
param_name = f"pme_label_{key}"
pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}")
pme_bindparams[param_name] = str(value)

return exists().where(
pme_combined = " AND ".join(pme_label_conditions)
conversation_condition = exists().where(
and_(
PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
PromptMemoryEntry.labels.isnot(None),
text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict),
text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams),
)
)

return or_(direct_condition, conversation_condition)

def get_unique_attack_class_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique class_name values from
Expand Down
8 changes: 4 additions & 4 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
Return a database-specific condition for filtering AttackResults by labels
in the associated PromptMemoryEntry records.
stored directly on the AttackResultEntry OR on an associated PromptMemoryEntry (via conversation_id).

Args:
labels: Dictionary of labels that must ALL be present.
Expand Down Expand Up @@ -1480,9 +1480,9 @@ def get_attack_results(
not necessarily one(s) that were found in the response.
By providing a list, this means ALL categories in the list must be present.
Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by.
These labels are associated with the prompts themselves, used for custom tagging and tracking.
Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by.
These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id).
All specified key-value pairs must be present (AND logic). Defaults to None.
Comment on lines +1484 to +1485
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id)
. All specified key-value pairs must be present (AND logic). Defaults to None.
These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id).
All specified key-value pairs must be present (AND logic). Defaults to None.

identifier_filters (Optional[Sequence[IdentifierFilter]], optional):
A sequence of IdentifierFilter objects that allows filtering by various attack identifier
JSON properties. Defaults to None.
Expand Down
4 changes: 4 additions & 0 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ class AttackResultEntry(Base):
outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined.
outcome_reason (str): Optional reason for the outcome, providing additional context.
attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context.
labels (dict[str, str]): Optional labels associated with the attack result entry.
pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack.
adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat.
timestamp (DateTime): The timestamp of the attack result entry.
Expand Down Expand Up @@ -728,6 +729,7 @@ class AttackResultEntry(Base):
)
outcome_reason = mapped_column(String, nullable=True)
attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True)
labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True)
pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
timestamp = mapped_column(DateTime, nullable=False)
Expand Down Expand Up @@ -783,6 +785,7 @@ def __init__(self, *, entry: AttackResult):
self.outcome = entry.outcome.value
self.outcome_reason = entry.outcome_reason
self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata)
self.labels = entry.labels or {}

# Persist conversation references by type
self.pruned_conversation_ids = [
Expand Down Expand Up @@ -894,6 +897,7 @@ def get_attack_result(self) -> AttackResult:
outcome_reason=self.outcome_reason,
related_conversations=related_conversations,
metadata=self.attack_metadata or {},
labels=self.labels or {},
)


Expand Down
24 changes: 13 additions & 11 deletions pyrit/memory/sqlite_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, cast

from sqlalchemy import and_, create_engine, func, or_, text
from sqlalchemy import and_, create_engine, exists, func, or_, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker
Expand Down Expand Up @@ -614,25 +614,27 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
SQLite implementation for filtering AttackResults by labels.
Uses json_extract() function specific to SQLite.

Matches if the labels are found on the AttackResultEntry directly
OR on an associated PromptMemoryEntry (via conversation_id).

Returns:
Any: A SQLAlchemy subquery for filtering by labels.
Any: A SQLAlchemy condition for filtering by labels.
"""
from sqlalchemy import and_, exists, func

from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry
direct_condition = and_(
AttackResultEntry.labels.isnot(None),
*[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()],
)

labels_subquery = exists().where(
conversation_condition = exists().where(
and_(
PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
PromptMemoryEntry.labels.isnot(None),
and_(
*[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()]
),
*[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()],
)
)
return labels_subquery # noqa: RET504

return or_(direct_condition, conversation_condition)

def get_unique_attack_class_names(self) -> list[str]:
"""
Expand Down
3 changes: 3 additions & 0 deletions pyrit/models/attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class AttackResult(StrategyResult):
# Arbitrary metadata
metadata: dict[str, Any] = field(default_factory=dict)

# Labels associated with this attack result
labels: dict[str, str] = field(default_factory=dict)

@property
def attack_identifier(self) -> Optional[ComponentIdentifier]:
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/backend/test_attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def make_attack_result(
"created_at": created.isoformat(),
"updated_at": updated.isoformat(),
},
labels={"test_ar_label": "test_ar_value"},
)


Expand Down Expand Up @@ -321,7 +322,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc
result = await attack_service.list_attacks_async()

assert len(result.items) == 1
assert result.items[0].labels == {"env": "prod", "team": "red"}
assert result.items[0].labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"}

@pytest.mark.asyncio
async def test_list_attacks_filters_by_labels_directly(self, attack_service, mock_memory) -> None:
Expand Down
23 changes: 21 additions & 2 deletions tests/unit/backend/test_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _make_attack_result(
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
},
labels={"test_ar_label": "test_ar_value"},
)


Expand Down Expand Up @@ -175,7 +176,7 @@ def test_labels_are_mapped(self) -> None:

summary = attack_result_to_summary(ar, stats=stats)

assert summary.labels == {"env": "prod", "team": "red"}
assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"}

def test_labels_passed_through_without_normalization(self) -> None:
"""Test that labels are passed through as-is (DB stores canonical keys after migration)."""
Expand All @@ -187,7 +188,24 @@ def test_labels_passed_through_without_normalization(self) -> None:

summary = attack_result_to_summary(ar, stats=stats)

assert summary.labels == {"operator": "alice", "operation": "op_red", "env": "prod"}
assert summary.labels == {
"operator": "alice",
"operation": "op_red",
"env": "prod",
"test_ar_label": "test_ar_value",
}

def test_conversation_labels_take_precedence_on_collision(self) -> None:
"""Test that conversation-level labels override attack-result labels on key collision."""
ar = _make_attack_result()
stats = ConversationStats(
message_count=1,
labels={"test_ar_label": "conversation_wins"},
)

summary = attack_result_to_summary(ar, stats=stats)

assert summary.labels["test_ar_label"] == "conversation_wins"

def test_outcome_success(self) -> None:
"""Test that success outcome is mapped."""
Expand Down Expand Up @@ -249,6 +267,7 @@ def test_converters_extracted_from_identifier(self) -> None:
),
outcome=AttackOutcome.UNDETERMINED,
metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()},
labels={"test_label": "test_value"},
)

summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0))
Expand Down
Loading
Loading