diff --git a/doc/code/memory/3_memory_data_types.md b/doc/code/memory/3_memory_data_types.md index b35daaa005..3e02a80531 100644 --- a/doc/code/memory/3_memory_data_types.md +++ b/doc/code/memory/3_memory_data_types.md @@ -25,11 +25,8 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../ - **`converter_identifiers`**: List of converters applied to transform the prompt - **`prompt_target_identifier`**: Information about the target that received this prompt - **`attack_identifier`**: Information about the attack that generated this prompt -- **`scorer_identifier`**: Information about the scorer that evaluated this prompt - **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`) -- **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`) - **`scores`**: List of `Score` objects associated with this piece -- **`targeted_harm_categories`**: Harm categories associated with the prompt - **`timestamp`**: When the piece was created This rich context allows PyRIT to track the full lifecycle of each interaction, including transformations, targeting, scoring, and error handling. diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 207a7f7f98..6952a06271 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -407,43 +407,6 @@ def _get_condition_json_array_match( combined = " AND ".join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. - - Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array. - - Args: - targeted_harm_categories (Sequence[str]): List of harm category strings to filter by. - - Returns: - Any: SQLAlchemy exists subquery condition with bound parameters. - """ - # For SQL Azure, we need to use JSON_QUERY to check if a value exists in a JSON array - # OPENJSON can parse the array and we check if the category exists - # Using parameterized queries for safety - harm_conditions = [] - bindparams_dict = {} - for i, category in enumerate(targeted_harm_categories): - param_name = f"harm_cat_{i}" - # Check if the JSON array contains the category value - harm_conditions.append( - f"EXISTS(SELECT 1 FROM OPENJSON(targeted_harm_categories) WHERE value = :{param_name})" - ) - bindparams_dict[param_name] = category - - combined_conditions = " AND ".join(harm_conditions) - - return exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.targeted_harm_categories.isnot(None), - PromptMemoryEntry.targeted_harm_categories != "", - PromptMemoryEntry.targeted_harm_categories != "[]", - text(f"ISJSON(targeted_harm_categories) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), - ) - ) - def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Get the SQL Azure implementation for filtering AttackResults by labels. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a0abed1476..0d74c038e1 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -402,19 +402,6 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict update_fields (dict): A dictionary of field names and their new values. """ - @abc.abstractmethod - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - Return a database-specific condition for filtering AttackResults by targeted harm categories - in the associated PromptMemoryEntry records. - - Args: - targeted_harm_categories: List of harm categories that must ALL be present. - - Returns: - Database-specific SQLAlchemy condition. - """ - @abc.abstractmethod def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -1453,7 +1440,6 @@ def get_attack_results( outcome: Optional[str] = None, attack_class: Optional[str] = None, converter_classes: Optional[Sequence[str]] = None, - targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[AttackResult]: @@ -1473,13 +1459,6 @@ def get_attack_results( converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. Returns only attacks that used ALL specified converters (AND logic, case-insensitive). Defaults to None. - targeted_harm_categories (Optional[Sequence[str]], optional): - A list of targeted harm categories to filter results by. - These targeted harm categories are associated with the prompts themselves, - meaning they are harm(s) we're trying to elicit with the prompt, - not necessarily one(s) that were found in the response. - By providing a list, this means ALL categories in the list must be present. - Defaults to None. labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. @@ -1530,12 +1509,6 @@ def get_attack_results( ) ) - if targeted_harm_categories: - # Use database-specific JSON query method - conditions.append( - self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) - ) - if labels: # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index b34c906af6..fd4b63ad18 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -156,7 +156,6 @@ class PromptMemoryEntry(Base): Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. - targeted_harm_categories (List[str]): The targeted harm categories for the memory entry. prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. Because memory is how components talk with each other, this can be component specific. e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. @@ -188,7 +187,6 @@ class PromptMemoryEntry(Base): timestamp = mapped_column(DateTime, nullable=False) labels: Mapped[dict[str, str]] = mapped_column(JSON) prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON) - targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON) converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON) prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) @@ -235,7 +233,6 @@ def __init__(self, *, entry: MessagePiece): self.timestamp = entry.timestamp self.labels = entry.labels self.prompt_metadata = entry.prompt_metadata - self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = [ conv.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH) for conv in entry.converter_identifiers ] @@ -303,7 +300,6 @@ def get_message_piece(self) -> MessagePiece: sequence=self.sequence, labels=self.labels, prompt_metadata=self.prompt_metadata, - targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, prompt_target_identifier=target_id, attack_identifier=attack_id, @@ -313,7 +309,7 @@ def get_message_piece(self) -> MessagePiece: original_prompt_id=self.original_prompt_id, timestamp=_ensure_utc(self.timestamp), ) - message_piece.scores = [score.get_score() for score in self.scores] + message_piece._set_scores([score.get_score() for score in self.scores]) return message_piece def __str__(self) -> str: diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bd376d67cd..979ab8f3f0 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -582,35 +582,6 @@ def export_all_tables(self, *, export_type: str = "json") -> None: # Convert to list for exporter compatibility self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore[arg-type] - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by targeted harm categories. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by targeted harm categories. - """ - from sqlalchemy import and_, exists, func - - from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - - targeted_harm_categories_subquery = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - # Exclude empty strings, None, and empty lists - PromptMemoryEntry.targeted_harm_categories.isnot(None), - PromptMemoryEntry.targeted_harm_categories != "", - PromptMemoryEntry.targeted_harm_categories != "[]", - and_( - *[ - func.json_extract(PromptMemoryEntry.targeted_harm_categories, "$").like(f'%"{category}"%') - for category in targeted_harm_categories - ] - ), - ) - ) - return targeted_harm_categories_subquery # noqa: RET504 - def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ SQLite implementation for filtering AttackResults by labels. diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 91d01032bf..2cfcf6f880 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Optional, Union, get_args from uuid import uuid4 from pyrit.identifiers.component_identifier import ComponentIdentifier @@ -14,8 +14,6 @@ if TYPE_CHECKING: from pyrit.models.score import Score -Originator = Literal["attack", "converter", "undefined", "scorer"] - class MessagePiece: """ @@ -42,15 +40,11 @@ def __init__( converter_identifiers: Optional[list[Union[ComponentIdentifier, dict[str, str]]]] = None, prompt_target_identifier: Optional[Union[ComponentIdentifier, dict[str, Any]]] = None, attack_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None, - scorer_identifier: Optional[Union[ComponentIdentifier, dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", converted_value_data_type: Optional[PromptDataType] = None, response_error: PromptResponseError = "none", - originator: Originator = "undefined", original_prompt_id: Optional[uuid.UUID] = None, timestamp: Optional[datetime] = None, - scores: Optional[list[Score]] = None, - targeted_harm_categories: Optional[list[str]] = None, ): """ Initialize a MessagePiece. @@ -74,16 +68,11 @@ def __init__( objects or dicts (deprecated, will be removed in 0.14.0). Defaults to None. prompt_target_identifier: The target identifier for the prompt. Defaults to None. attack_identifier: The attack identifier for the prompt. Defaults to None. - scorer_identifier: The scorer identifier for the prompt. Accepts a ComponentIdentifier. - Defaults to None. original_value_data_type: The data type of the original prompt (text, image). Defaults to "text". converted_value_data_type: The data type of the converted prompt (text, image). Defaults to "text". response_error: The response error type. Defaults to "none". - originator: The originator of the prompt. Defaults to "undefined". original_prompt_id: The original prompt id. It is equal to id unless it is a duplicate. Defaults to None. timestamp: The timestamp of the memory entry. Defaults to None (auto-generated). - scores: The scores associated with the prompt. Defaults to None. - targeted_harm_categories: The harm categories associated with the prompt. Defaults to None. Raises: ValueError: If role, data types, or response error are invalid. @@ -134,11 +123,6 @@ def __init__( ComponentIdentifier.normalize(attack_identifier) if attack_identifier else None ) - # Handle scorer_identifier: normalize to ComponentIdentifier (handles dict with deprecation warning) - self.scorer_identifier: Optional[ComponentIdentifier] = ( - ComponentIdentifier.normalize(scorer_identifier) if scorer_identifier else None - ) - self.original_value = original_value if original_value_data_type not in get_args(PromptDataType): @@ -161,13 +145,21 @@ def __init__( raise ValueError(f"response_error {response_error} is not a valid response error.") self.response_error = response_error - self.originator = originator # Original prompt id defaults to id (assumes that this is the original prompt, not a duplicate) self.original_prompt_id = original_prompt_id or self.id - self.scores = scores if scores else [] - self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else [] + # Scores are not set via constructor. They are hydrated by the memory layer + # via _set_scores() after construction. + self._scores: list[Score] = [] + + @property + def scores(self) -> list[Score]: + """Scores associated with this message piece, hydrated by the memory layer.""" + return list(self._scores) + + def _set_scores(self, scores: list[Score]) -> None: + self._scores = scores async def set_sha256_values_async(self) -> None: """ @@ -280,14 +272,12 @@ def to_dict(self) -> dict[str, object]: "sequence": self.sequence, "timestamp": self.timestamp.isoformat() if self.timestamp else None, "labels": self.labels, - "targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None, "prompt_metadata": self.prompt_metadata, "converter_identifiers": [conv.to_dict() for conv in self.converter_identifiers], "prompt_target_identifier": ( self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None ), "attack_identifier": self.attack_identifier.to_dict() if self.attack_identifier else None, - "scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None, "original_value_data_type": self.original_value_data_type, "original_value": self.original_value, "original_value_sha256": self.original_value_sha256, @@ -295,7 +285,6 @@ def to_dict(self) -> dict[str, object]: "converted_value": self.converted_value, "converted_value_sha256": self.converted_value_sha256, "response_error": self.response_error, - "originator": self.originator, "original_prompt_id": str(self.original_prompt_id), "scores": [score.to_dict() for score in self.scores], } diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index e2b7a5ce13..012a3113fa 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -90,7 +90,6 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non original_value_data_type=original_piece.original_value_data_type, converted_value_data_type=original_piece.converted_value_data_type, response_error=original_piece.response_error, - originator=original_piece.originator, original_prompt_id=( cast("UUID", original_piece.original_prompt_id) if isinstance(original_piece.original_prompt_id, str) diff --git a/tests/integration/memory/test_azure_sql_memory_integration.py b/tests/integration/memory/test_azure_sql_memory_integration.py index 5d167e8e7b..55623ef722 100644 --- a/tests/integration/memory/test_azure_sql_memory_integration.py +++ b/tests/integration/memory/test_azure_sql_memory_integration.py @@ -234,93 +234,6 @@ async def test_get_seeds_with_metadata_filter(azuresql_instance: AzureSQLMemory) assert azuresql_instance.get_seeds(metadata={"key2": value1}, added_by=test_id) == [] -@pytest.mark.asyncio -async def test_get_attack_results_by_harm_categories(azuresql_instance: AzureSQLMemory): - """ - Integration test for SQL Azure JSON filtering on targeted harm categories. - - Tests that harm category filtering requires ALL specified categories to be present - (AND logic, not OR). Verifies both single and multiple category filters work correctly. - """ - # Use unique conversation IDs to avoid test pollution - test_id = generate_test_id() - - conversation_ids = [ - f"conv_harm_1_{test_id}", - f"conv_harm_2_{test_id}", - f"conv_harm_3_{test_id}", - ] - - with cleanup_conversation_data(azuresql_instance, conversation_ids): - # Create message pieces with harm categories - piece1 = MessagePiece( - conversation_id=conversation_ids[0], - role="user", - original_value="Test 1", - converted_value="Test 1", - targeted_harm_categories=["hate", "violence"], - ) - piece2 = MessagePiece( - conversation_id=conversation_ids[1], - role="user", - original_value="Test 2", - converted_value="Test 2", - targeted_harm_categories=["hate"], - ) - piece3 = MessagePiece( - conversation_id=conversation_ids[2], - role="user", - original_value="Test 3", - converted_value="Test 3", - targeted_harm_categories=["violence"], - ) - - azuresql_instance.add_message_pieces_to_memory(message_pieces=[piece1, piece2, piece3]) - - # Create attack results - atomic_id = get_test_atomic_attack_identifier() - result1 = AttackResult( - conversation_id=conversation_ids[0], - objective="Test objective 1", - atomic_attack_identifier=atomic_id, - outcome=AttackOutcome.SUCCESS, - ) - result2 = AttackResult( - conversation_id=conversation_ids[1], - objective="Test objective 2", - atomic_attack_identifier=atomic_id, - outcome=AttackOutcome.SUCCESS, - ) - result3 = AttackResult( - conversation_id=conversation_ids[2], - objective="Test objective 3", - atomic_attack_identifier=atomic_id, - outcome=AttackOutcome.FAILURE, - ) - - azuresql_instance.add_attack_results_to_memory(attack_results=[result1, result2, result3]) - - # Test filtering by single harm category - results = azuresql_instance.get_attack_results(targeted_harm_categories=["hate"]) - # Filter to only results from this test - results = [r for r in results if test_id in r.conversation_id] - assert len(results) == 2 - conv_ids = {r.conversation_id for r in results} - assert conversation_ids[0] in conv_ids - assert conversation_ids[1] in conv_ids - - # Test filtering by multiple harm categories (ALL must be present) - results = azuresql_instance.get_attack_results(targeted_harm_categories=["hate", "violence"]) - results = [r for r in results if test_id in r.conversation_id] - assert len(results) == 1 - assert results[0].conversation_id == conversation_ids[0] - - # Test filtering with no matches - results = azuresql_instance.get_attack_results(targeted_harm_categories=["hate", "self-harm"]) - results = [r for r in results if test_id in r.conversation_id] - assert len(results) == 0 - - @pytest.mark.asyncio async def test_get_attack_results_by_labels(azuresql_instance: AzureSQLMemory): """ diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index c86e741e9c..e52a5d84e8 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -872,8 +872,8 @@ async def test_multipart_message_extracts_scores_from_all_pieces( original_value="Here is the analysis:", original_value_data_type="text", conversation_id=piece_conversation_id, - scores=[score1], # Attach score directly to piece ) + piece1._set_scores([score1]) # Create score for second piece # Also false since prepended conversations only extract false scores @@ -892,8 +892,8 @@ async def test_multipart_message_extracts_scores_from_all_pieces( original_value="chart_image.png", original_value_data_type="image_path", conversation_id=piece_conversation_id, - scores=[score2], # Attach score directly to piece ) + piece2._set_scores([score2]) multipart_response = Message(message_pieces=[piece1, piece2]) context.prepended_conversation = [ @@ -958,16 +958,16 @@ async def test_prepended_conversation_ignores_true_scores( original_value="Simulated success response", original_value_data_type="text", conversation_id=str(uuid.uuid4()), - scores=[true_score], ) + piece_with_true._set_scores([true_score]) piece_with_false = MessagePiece( role="assistant", original_value="Simulated refusal response", original_value_data_type="text", conversation_id=str(uuid.uuid4()), - scores=[false_score], ) + piece_with_false._set_scores([false_score]) # Test with true score only - should get no scores context.prepended_conversation = [ diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 2e30ba368a..b40e640670 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -24,14 +24,13 @@ from collections.abc import Sequence -def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_categories=None, labels=None): - """Helper function to create MessagePiece with optional targeted harm categories and labels.""" +def create_message_piece(conversation_id: str, prompt_num: int, labels=None): + """Helper function to create MessagePiece with optional labels.""" return MessagePiece( role="user", original_value=f"Test prompt {prompt_num}", converted_value=f"Test prompt {prompt_num}", conversation_id=conversation_id, - targeted_harm_categories=targeted_harm_categories, labels=labels, ) @@ -721,62 +720,6 @@ def test_update_attack_result_stale_entry_does_not_overwrite(sqlite_instance: Me assert results[0].related_conversations.pop().conversation_id == "branch-1" -def test_get_attack_results_by_harm_category_single(sqlite_instance: MemoryInterface): - """Test filtering attack results by a single harm category.""" - - # Create message pieces with harm categories using helper function - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["illegal"]) - message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence"]) - - # Add message pieces to memory - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results using helper function - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) - - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) - - violence_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence"]) - assert len(violence_results) == 2 - conversation_ids = {result.conversation_id for result in violence_results} - assert conversation_ids == {"conv_1", "conv_3"} - - illegal_results = sqlite_instance.get_attack_results(targeted_harm_categories=["illegal"]) - assert len(illegal_results) == 2 - conversation_ids = {result.conversation_id for result in illegal_results} - assert conversation_ids == {"conv_1", "conv_2"} - - -def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInterface): - """Test filtering attack results by multiple harm categories (AND logic).""" - - # Create message pieces with different harm category combinations - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal", "hate"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence", "illegal"]) - message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence"]) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.SUCCESS) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.FAILURE) - - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) - - # Test filtering by multiple harm categories - violence_and_illegal_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence", "illegal"]) - assert len(violence_and_illegal_results) == 2 - conversation_ids = {result.conversation_id for result in violence_and_illegal_results} - assert conversation_ids == {"conv_1", "conv_2"} - all_three_results = sqlite_instance.get_attack_results(targeted_harm_categories=["violence", "illegal", "hate"]) - assert len(all_three_results) == 1 - assert all_three_results[0].conversation_id == "conv_1" - - def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" @@ -839,68 +782,6 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) assert conversation_ids == {"conv_1", "conv_2"} -def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): - """Test filtering attack results by both harm categories and labels.""" - - # Create message pieces with both harm categories and labels using helper function - message_piece1 = create_message_piece( - "conv_1", - 1, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "test_op", "operator": "roakey"}, - ) - message_piece2 = create_message_piece( - "conv_2", 2, targeted_harm_categories=["violence"], labels={"operation": "test_op", "operator": "roakey"} - ) - message_piece3 = create_message_piece( - "conv_3", - 3, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "other_op", "operator": "bob"}, - ) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), - ] - - sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) - - # Test filtering by both harm categories and labels - violence_illegal_roakey_results = sqlite_instance.get_attack_results( - targeted_harm_categories=["violence", "illegal"], labels={"operator": "roakey"} - ) - assert len(violence_illegal_roakey_results) == 1 - assert violence_illegal_roakey_results[0].conversation_id == "conv_1" - - # Test filtering by harm category and operation - violence_test_op_results = sqlite_instance.get_attack_results( - targeted_harm_categories=["violence"], labels={"operation": "test_op"} - ) - assert len(violence_test_op_results) == 2 - conversation_ids = {result.conversation_id for result in violence_test_op_results} - assert conversation_ids == {"conv_1", "conv_2"} - - -def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInterface): - """Test filtering by harm category that doesn't exist.""" - - # Create attack result without the harm category we'll search for - message_piece = create_message_piece("conv_1", 1, targeted_harm_categories=["violence"]) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) - - # Search for non-existent harm category - results = sqlite_instance.get_attack_results(targeted_harm_categories=["nonexistent"]) - assert len(results) == 0 - - def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index d13064a321..37c4b86bb8 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -669,7 +669,6 @@ def test_message_piece_to_dict(): conversation_id="test_conversation", sequence=1, labels={"label1": "value1"}, - targeted_harm_categories=["violence", "illegal"], prompt_metadata={"key": "metadata"}, converter_identifiers=[ ComponentIdentifier( @@ -686,36 +685,31 @@ def test_message_piece_to_dict(): class_name="PromptSendingAttack", class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", ), - scorer_identifier=ComponentIdentifier( - class_name="TestScorer", - class_module="pyrit.score.test_scorer", - ), original_value_data_type="text", converted_value_data_type="text", response_error="none", - originator="undefined", original_prompt_id=uuid.uuid4(), timestamp=datetime.now(tz=timezone.utc), - scores=[ - Score( - id=str(uuid.uuid4()), - score_value="false", - score_value_description="true false score", - score_type="true_false", - score_category=["Category1"], - score_rationale="Rationale text", - score_metadata={"key": "value"}, - scorer_class_identifier=ComponentIdentifier( - class_name="Scorer1", - class_module="pyrit.score", - ), - message_piece_id=str(uuid.uuid4()), - timestamp=datetime.now(tz=timezone.utc), - objective="Task1", - ) - ], ) + test_score = Score( + id=str(uuid.uuid4()), + score_value="false", + score_value_description="true false score", + score_type="true_false", + score_category=["Category1"], + score_rationale="Rationale text", + score_metadata={"key": "value"}, + scorer_class_identifier=ComponentIdentifier( + class_name="Scorer1", + class_module="pyrit.score", + ), + message_piece_id=str(uuid.uuid4()), + timestamp=datetime.now(tz=timezone.utc), + objective="Task1", + ) + entry._set_scores([test_score]) + result = entry.to_dict() expected_keys = [ @@ -725,12 +719,10 @@ def test_message_piece_to_dict(): "sequence", "timestamp", "labels", - "targeted_harm_categories", "prompt_metadata", "converter_identifiers", "prompt_target_identifier", "attack_identifier", - "scorer_identifier", "original_value_data_type", "original_value", "original_value_sha256", @@ -738,7 +730,6 @@ def test_message_piece_to_dict(): "converted_value", "converted_value_sha256", "response_error", - "originator", "original_prompt_id", "scores", ] @@ -752,12 +743,10 @@ def test_message_piece_to_dict(): assert result["sequence"] == entry.sequence assert result["timestamp"] == entry.timestamp.isoformat() assert result["labels"] == entry.labels - assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() assert result["attack_identifier"] == entry.attack_identifier.to_dict() - assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value assert result["original_value_sha256"] == entry.original_value_sha256 @@ -765,52 +754,10 @@ def test_message_piece_to_dict(): assert result["converted_value"] == entry.converted_value assert result["converted_value_sha256"] == entry.converted_value_sha256 assert result["response_error"] == entry.response_error - assert result["originator"] == entry.originator assert result["original_prompt_id"] == str(entry.original_prompt_id) assert result["scores"] == [score.to_dict() for score in entry.scores] -def test_message_piece_scorer_identifier_dict_backward_compatibility(): - """Test that passing a dict for scorer_identifier normalizes to ComponentIdentifier.""" - - scorer_dict = { - "class_name": "TestScorer", - "class_module": "pyrit.score.test_scorer", - } - - entry = MessagePiece( - role="user", - original_value="Hello", - scorer_identifier=scorer_dict, - ) - - # Check that scorer_identifier is now a ComponentIdentifier - assert isinstance(entry.scorer_identifier, ComponentIdentifier) - assert entry.scorer_identifier.class_name == "TestScorer" - assert entry.scorer_identifier.class_module == "pyrit.score.test_scorer" - - -def test_message_piece_scorer_identifier_none_default(): - """Test that scorer_identifier defaults to None when not provided.""" - entry = MessagePiece( - role="user", - original_value="Hello", - ) - - assert entry.scorer_identifier is None - - -def test_message_piece_to_dict_scorer_identifier_none(): - """Test that to_dict() returns None for scorer_identifier when not set.""" - entry = MessagePiece( - role="user", - original_value="Hello", - ) - - result = entry.to_dict() - assert result["scorer_identifier"] is None - - def test_construct_response_from_request_combines_metadata(): # Create a message piece with metadata request = MessagePiece( @@ -923,66 +870,6 @@ def test_message_piece_has_error_and_is_blocked_consistency(): assert no_error_entry.has_error() is False -def test_message_piece_harm_categories_none(): - """Test that harm_categories defaults to None.""" - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - ) - assert entry.targeted_harm_categories == [] - - -def test_message_piece_harm_categories_single(): - """Test that harm_categories can be set to a single category.""" - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=["violence"] - ) - assert entry.targeted_harm_categories == ["violence"] - - -def test_message_piece_harm_categories_multiple(): - """Test that harm_categories can be set to multiple categories.""" - harm_categories = ["violence", "illegal", "hate_speech"] - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=harm_categories - ) - assert entry.targeted_harm_categories == harm_categories - - -def test_message_piece_harm_categories_serialization(): - """Test that harm_categories is properly serialized in to_dict().""" - harm_categories = ["violence", "illegal"] - entry = MessagePiece( - role="user", original_value="Hello", converted_value="Hello", targeted_harm_categories=harm_categories - ) - - result = entry.to_dict() - assert "targeted_harm_categories" in result - assert result["targeted_harm_categories"] == harm_categories - - -def test_message_piece_harm_categories_with_labels(): - """Test that harm_categories and labels can coexist.""" - harm_categories = ["violence", "illegal"] - labels = {"operation": "test_op", "researcher": "alice"} - - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - targeted_harm_categories=harm_categories, - labels=labels, - ) - - assert entry.targeted_harm_categories == harm_categories - assert entry.labels == labels - - result = entry.to_dict() - assert result["targeted_harm_categories"] == harm_categories - assert result["labels"] == labels - - class TestSimulatedAssistantRole: """Tests for simulated_assistant role properties."""