Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions doc/code/memory/3_memory_data_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../
- **`converter_identifiers`**: List of converters applied to transform the prompt
- **`prompt_target_identifier`**: Information about the target that received this prompt
- **`attack_identifier`**: Information about the attack that generated this prompt
- **`scorer_identifier`**: Information about the scorer that evaluated this prompt
- **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`)
- **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`)
- **`scores`**: List of `Score` objects associated with this piece
- **`targeted_harm_categories`**: Harm categories associated with the prompt
- **`timestamp`**: When the piece was created

This rich context allows PyRIT to track the full lifecycle of each interaction, including transformations, targeting, scoring, and error handling.
Expand Down
37 changes: 0 additions & 37 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,43 +407,6 @@ def _get_condition_json_array_match(
combined = " AND ".join(conditions)
return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict)

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
    """
    Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.

    Uses OPENJSON() to expand the stored JSON array and emits one EXISTS check per
    requested category; the checks are combined with AND, so every category must be present.

    Args:
        targeted_harm_categories (Sequence[str]): List of harm category strings to filter by.

    Returns:
        Any: SQLAlchemy exists subquery condition with bound parameters.
    """
    # All raw-SQL clauses are collected in one list and joined once, so an empty
    # category list yields just the ISJSON guard instead of a dangling " AND ".
    json_clauses = ["ISJSON(targeted_harm_categories) = 1"]
    bindparams_dict = {}
    for i, category in enumerate(targeted_harm_categories):
        param_name = f"harm_cat_{i}"
        # The category value is passed as a bound parameter, never interpolated into the SQL text.
        json_clauses.append(
            f"EXISTS(SELECT 1 FROM OPENJSON(targeted_harm_categories) WHERE value = :{param_name})"
        )
        bindparams_dict[param_name] = category

    return exists().where(
        and_(
            PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
            # Exclude NULL, empty strings, and empty JSON arrays
            PromptMemoryEntry.targeted_harm_categories.isnot(None),
            PromptMemoryEntry.targeted_harm_categories != "",
            PromptMemoryEntry.targeted_harm_categories != "[]",
            text(" AND ".join(json_clauses)).bindparams(**bindparams_dict),
        )
    )

def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
Get the SQL Azure implementation for filtering AttackResults by labels.
Expand Down
27 changes: 0 additions & 27 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,19 +402,6 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict
update_fields (dict): A dictionary of field names and their new values.
"""

@abc.abstractmethod
def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
"""
Return a database-specific condition for filtering AttackResults by targeted harm categories
in the associated PromptMemoryEntry records.

Args:
targeted_harm_categories: List of harm categories that must ALL be present.

Returns:
Database-specific SQLAlchemy condition.
"""

@abc.abstractmethod
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
Expand Down Expand Up @@ -1453,7 +1440,6 @@ def get_attack_results(
outcome: Optional[str] = None,
attack_class: Optional[str] = None,
converter_classes: Optional[Sequence[str]] = None,
targeted_harm_categories: Optional[Sequence[str]] = None,
labels: Optional[dict[str, str]] = None,
identifier_filters: Optional[Sequence[IdentifierFilter]] = None,
Comment thread
behnam-o marked this conversation as resolved.
) -> Sequence[AttackResult]:
Expand All @@ -1473,13 +1459,6 @@ def get_attack_results(
converter_classes (Optional[Sequence[str]], optional): Filter by converter class names.
Returns only attacks that used ALL specified converters (AND logic, case-insensitive).
Defaults to None.
targeted_harm_categories (Optional[Sequence[str]], optional):
A list of targeted harm categories to filter results by.
These targeted harm categories are associated with the prompts themselves,
meaning they are harm(s) we're trying to elicit with the prompt,
not necessarily one(s) that were found in the response.
By providing a list, this means ALL categories in the list must be present.
Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by.
These labels are associated with the prompts themselves, used for custom tagging and tracking.
Defaults to None.
Expand Down Expand Up @@ -1530,12 +1509,6 @@ def get_attack_results(
)
)

if targeted_harm_categories:
# Use database-specific JSON query method
conditions.append(
self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories)
)

if labels:
# Use database-specific JSON query method
conditions.append(self._get_attack_result_label_condition(labels=labels))
Expand Down
6 changes: 1 addition & 5 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ class PromptMemoryEntry(Base):
Can be the same number for multi-part requests or multi-part responses.
timestamp (DateTime): The timestamp of the memory entry.
labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized.
targeted_harm_categories (List[str]): The targeted harm categories for the memory entry.
prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios.
Because memory is how components talk with each other, this can be component specific.
e.g. the URI from a file uploaded to a blob store, or a document type you want to upload.
Expand Down Expand Up @@ -188,7 +187,6 @@ class PromptMemoryEntry(Base):
timestamp = mapped_column(DateTime, nullable=False)
labels: Mapped[dict[str, str]] = mapped_column(JSON)
prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON)
targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON)
converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON)
prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
Expand Down Expand Up @@ -235,7 +233,6 @@ def __init__(self, *, entry: MessagePiece):
self.timestamp = entry.timestamp
self.labels = entry.labels
self.prompt_metadata = entry.prompt_metadata
self.targeted_harm_categories = entry.targeted_harm_categories
self.converter_identifiers = [
conv.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH) for conv in entry.converter_identifiers
]
Expand Down Expand Up @@ -303,7 +300,6 @@ def get_message_piece(self) -> MessagePiece:
sequence=self.sequence,
labels=self.labels,
prompt_metadata=self.prompt_metadata,
targeted_harm_categories=self.targeted_harm_categories,
converter_identifiers=converter_ids,
prompt_target_identifier=target_id,
attack_identifier=attack_id,
Expand All @@ -313,7 +309,7 @@ def get_message_piece(self) -> MessagePiece:
original_prompt_id=self.original_prompt_id,
timestamp=_ensure_utc(self.timestamp),
)
message_piece.scores = [score.get_score() for score in self.scores]
message_piece._set_scores([score.get_score() for score in self.scores])
return message_piece

def __str__(self) -> str:
Expand Down
29 changes: 0 additions & 29 deletions pyrit/memory/sqlite_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,35 +582,6 @@ def export_all_tables(self, *, export_type: str = "json") -> None:
# Convert to list for exporter compatibility
self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore[arg-type]

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
    """
    SQLite implementation for filtering AttackResults by targeted harm categories.

    Uses json_extract() to obtain the stored JSON array as text and matches each
    requested category as a quoted element with LIKE. All categories must match.

    Args:
        targeted_harm_categories (Sequence[str]): Harm category strings that must all be present.

    Returns:
        Any: A SQLAlchemy subquery for filtering by targeted harm categories.
    """
    from sqlalchemy import and_, exists, func

    from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry

    def _escape_like(value: str) -> str:
        # Neutralise LIKE wildcards so e.g. "hate_speech" does not also match "hate-speech".
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    serialized_categories = func.json_extract(PromptMemoryEntry.targeted_harm_categories, "$")
    category_matches = [
        serialized_categories.like(f'%"{_escape_like(category)}"%', escape="\\")
        for category in targeted_harm_categories
    ]

    return exists().where(
        and_(
            PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id,
            # Exclude empty strings, None, and empty lists
            PromptMemoryEntry.targeted_harm_categories.isnot(None),
            PromptMemoryEntry.targeted_harm_categories != "",
            PromptMemoryEntry.targeted_harm_categories != "[]",
            # Unpacked into the outer and_() so an empty category list adds no empty conjunction.
            *category_matches,
        )
    )

def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
SQLite implementation for filtering AttackResults by labels.
Expand Down
35 changes: 12 additions & 23 deletions pyrit/models/message_piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args
from typing import TYPE_CHECKING, Any, Optional, Union, get_args
from uuid import uuid4

from pyrit.identifiers.component_identifier import ComponentIdentifier
Expand All @@ -14,8 +14,6 @@
if TYPE_CHECKING:
from pyrit.models.score import Score

Originator = Literal["attack", "converter", "undefined", "scorer"]


class MessagePiece:
"""
Expand All @@ -42,15 +40,11 @@ def __init__(
converter_identifiers: Optional[list[Union[ComponentIdentifier, dict[str, str]]]] = None,
prompt_target_identifier: Optional[Union[ComponentIdentifier, dict[str, Any]]] = None,
attack_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None,
scorer_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we get rid of `scorer_identifier` but not the other identifiers?

original_value_data_type: PromptDataType = "text",
converted_value_data_type: Optional[PromptDataType] = None,
response_error: PromptResponseError = "none",
originator: Originator = "undefined",
original_prompt_id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime] = None,
scores: Optional[list[Score]] = None,
targeted_harm_categories: Optional[list[str]] = None,
):
"""
Initialize a MessagePiece.
Expand All @@ -74,16 +68,11 @@ def __init__(
objects or dicts (deprecated, will be removed in 0.14.0). Defaults to None.
prompt_target_identifier: The target identifier for the prompt. Defaults to None.
attack_identifier: The attack identifier for the prompt. Defaults to None.
scorer_identifier: The scorer identifier for the prompt. Accepts a ComponentIdentifier.
Defaults to None.
original_value_data_type: The data type of the original prompt (text, image). Defaults to "text".
converted_value_data_type: The data type of the converted prompt (text, image). Defaults to "text".
response_error: The response error type. Defaults to "none".
originator: The originator of the prompt. Defaults to "undefined".
original_prompt_id: The original prompt id. It is equal to id unless it is a duplicate. Defaults to None.
timestamp: The timestamp of the memory entry. Defaults to None (auto-generated).
scores: The scores associated with the prompt. Defaults to None.
targeted_harm_categories: The harm categories associated with the prompt. Defaults to None.

Raises:
ValueError: If role, data types, or response error are invalid.
Expand Down Expand Up @@ -134,11 +123,6 @@ def __init__(
ComponentIdentifier.normalize(attack_identifier) if attack_identifier else None
)

# Handle scorer_identifier: normalize to ComponentIdentifier (handles dict with deprecation warning)
self.scorer_identifier: Optional[ComponentIdentifier] = (
ComponentIdentifier.normalize(scorer_identifier) if scorer_identifier else None
)

self.original_value = original_value

if original_value_data_type not in get_args(PromptDataType):
Expand All @@ -161,13 +145,21 @@ def __init__(
raise ValueError(f"response_error {response_error} is not a valid response error.")

self.response_error = response_error
self.originator = originator

# Original prompt id defaults to id (assumes that this is the original prompt, not a duplicate)
self.original_prompt_id = original_prompt_id or self.id

self.scores = scores if scores else []
self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else []
# Scores are not set via constructor. They are hydrated by the memory layer
# via _set_scores() after construction.
self._scores: list[Score] = []

@property
def scores(self) -> list[Score]:
    """
    Return the scores associated with this message piece.

    The scores are hydrated by the memory layer. A new list is built on every
    access, so mutating the returned list does not alter the stored scores.
    """
    return [*self._scores]

def _set_scores(self, scores: list[Score]) -> None:
self._scores = scores
Comment thread
behnam-o marked this conversation as resolved.

async def set_sha256_values_async(self) -> None:
"""
Expand Down Expand Up @@ -280,22 +272,19 @@ def to_dict(self) -> dict[str, object]:
"sequence": self.sequence,
"timestamp": self.timestamp.isoformat() if self.timestamp else None,
"labels": self.labels,
"targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None,
"prompt_metadata": self.prompt_metadata,
"converter_identifiers": [conv.to_dict() for conv in self.converter_identifiers],
"prompt_target_identifier": (
self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None
),
"attack_identifier": self.attack_identifier.to_dict() if self.attack_identifier else None,
"scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None,
"original_value_data_type": self.original_value_data_type,
"original_value": self.original_value,
"original_value_sha256": self.original_value_sha256,
"converted_value_data_type": self.converted_value_data_type,
"converted_value": self.converted_value,
"converted_value_sha256": self.converted_value_sha256,
"response_error": self.response_error,
"originator": self.originator,
"original_prompt_id": str(self.original_prompt_id),
"scores": [score.to_dict() for score in self.scores],
}
Expand Down
1 change: 0 additions & 1 deletion pyrit/score/conversation_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non
original_value_data_type=original_piece.original_value_data_type,
converted_value_data_type=original_piece.converted_value_data_type,
response_error=original_piece.response_error,
originator=original_piece.originator,
original_prompt_id=(
cast("UUID", original_piece.original_prompt_id)
if isinstance(original_piece.original_prompt_id, str)
Expand Down
Loading
Loading