Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import atexit
import logging
import uuid
import warnings
import weakref
from collections.abc import MutableSequence, Sequence
from contextlib import closing
Expand Down Expand Up @@ -1531,6 +1532,11 @@ def get_attack_results(
)

if targeted_harm_categories:
warnings.warn(
"The 'targeted_harm_categories' parameter is deprecated and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
# Use database-specific JSON query method
conditions.append(
self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories)
Expand Down
29 changes: 29 additions & 0 deletions pyrit/models/message_piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import uuid
import warnings
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args
from uuid import uuid4
Expand All @@ -15,6 +16,7 @@
from pyrit.models.score import Score

Originator = Literal["attack", "converter", "undefined", "scorer"]
"""Deprecated: The Originator type alias will be removed in a future release."""


class MessagePiece:
Expand Down Expand Up @@ -135,6 +137,12 @@ def __init__(
)

# Handle scorer_identifier: normalize to ComponentIdentifier (handles dict with deprecation warning)
if scorer_identifier is not None:
warnings.warn(
"The 'scorer_identifier' parameter is deprecated and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
self.scorer_identifier: Optional[ComponentIdentifier] = (
ComponentIdentifier.normalize(scorer_identifier) if scorer_identifier else None
)
Expand All @@ -161,12 +169,33 @@ def __init__(
raise ValueError(f"response_error {response_error} is not a valid response error.")

self.response_error = response_error

if originator != "undefined":
warnings.warn(
"The 'originator' parameter is deprecated and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
self.originator = originator

# Original prompt id defaults to id (assumes that this is the original prompt, not a duplicate)
self.original_prompt_id = original_prompt_id or self.id

if scores is not None:
warnings.warn(
"The 'scores' parameter is deprecated and will be removed in a future release. "
"Scores are now hydrated by the memory layer.",
DeprecationWarning,
stacklevel=2,
)
self.scores = scores if scores else []

if targeted_harm_categories is not None:
warnings.warn(
"The 'targeted_harm_categories' parameter is deprecated and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else []

async def set_sha256_values_async(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1414,3 +1414,20 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance
],
)
assert len(results) == 0


def test_get_attack_results_targeted_harm_categories_emits_deprecation_warning(sqlite_instance: MemoryInterface):
"""Test that passing targeted_harm_categories emits a DeprecationWarning."""
import warnings

message_piece = create_message_piece("conv_1", 1, targeted_harm_categories=["violence"])
sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece])

attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS)
sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result])

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
sqlite_instance.get_attack_results(targeted_harm_categories=["violence"])
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs)
66 changes: 66 additions & 0 deletions tests/unit/models/test_message_piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
import time
import uuid
import warnings
from collections.abc import MutableSequence
from datetime import datetime, timedelta, timezone

Expand Down Expand Up @@ -1043,3 +1044,68 @@ def test_role_setter_sets_simulated_assistant(self):
assert piece.get_role_for_storage() == "simulated_assistant"
assert piece.api_role == "assistant"
assert piece.is_simulated is True


class TestMessagePieceDeprecationWarnings:
"""Tests for deprecation warnings on parameters scheduled for removal."""

def test_scorer_identifier_emits_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(
role="user",
original_value="Hello",
scorer_identifier=ComponentIdentifier(class_name="S", class_module="m"),
)
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("scorer_identifier" in str(m.message) for m in deprecation_msgs)

def test_scorer_identifier_none_no_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello")
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert not any("scorer_identifier" in str(m.message) for m in deprecation_msgs)

def test_originator_non_default_emits_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello", originator="attack")
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("originator" in str(m.message) for m in deprecation_msgs)

def test_originator_default_no_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello")
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert not any("originator" in str(m.message) for m in deprecation_msgs)

def test_scores_emits_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello", scores=[])
# scores=[] is falsy but not None, however the check is `scores is not None`
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("scores" in str(m.message) for m in deprecation_msgs)

def test_scores_none_no_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello")
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert not any("scores" in str(m.message) for m in deprecation_msgs)

def test_targeted_harm_categories_emits_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello", targeted_harm_categories=["violence"])
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs)

def test_targeted_harm_categories_none_no_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
MessagePiece(role="user", original_value="Hello")
deprecation_msgs = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert not any("targeted_harm_categories" in str(m.message) for m in deprecation_msgs)
Loading