From 5b88eecfc345c3f21b09f59dd6cca5c327282a27 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 16 Apr 2026 08:44:11 -0700 Subject: [PATCH 1/3] MAINT Breaking: Updating content-harms to rapid response --- .../instructions/scenarios.instructions.md | 28 +- doc/code/scenarios/0_scenarios.ipynb | 13 +- doc/code/scenarios/0_scenarios.py | 13 +- .../scenarios/1_scenario_parameters.ipynb | 4 + doc/code/scenarios/1_scenario_parameters.py | 4 + pyrit/scenario/__init__.py | 2 +- pyrit/scenario/core/__init__.py | 10 + .../scenario/core/attack_technique_factory.py | 71 +- pyrit/scenario/core/core_techniques.py | 56 ++ pyrit/scenario/core/scenario.py | 51 ++ pyrit/scenario/core/scenario_strategy.py | 7 + pyrit/scenario/scenarios/airt/__init__.py | 7 +- .../scenario/scenarios/airt/content_harms.py | 357 +------- .../scenario/scenarios/airt/rapid_response.py | 201 +++++ .../scenario/test_attack_technique_factory.py | 20 +- tests/unit/scenario/test_content_harms.py | 808 ------------------ tests/unit/scenario/test_rapid_response.py | 569 ++++++++++++ 17 files changed, 1037 insertions(+), 1184 deletions(-) create mode 100644 pyrit/scenario/core/core_techniques.py create mode 100644 pyrit/scenario/scenarios/airt/rapid_response.py delete mode 100644 tests/unit/scenario/test_content_harms.py create mode 100644 tests/unit/scenario/test_rapid_response.py diff --git a/.github/instructions/scenarios.instructions.md b/.github/instructions/scenarios.instructions.md index ba544465fc..d4db07b744 100644 --- a/.github/instructions/scenarios.instructions.md +++ b/.github/instructions/scenarios.instructions.md @@ -94,25 +94,41 @@ Options: ## Strategy Enum -Strategies should be selectable by an axis. E.g. it could be harm category or and attack type, but likely not both or it gets confusing. +Strategy members should represent **attack techniques** — the *how* of an attack (e.g., prompt sending, role play, TAP). Datasets control *what* is tested (e.g., harm categories, compliance topics). 
Avoid mixing dataset/category selection into the strategy enum; use `DatasetConfiguration` and `--dataset-names` for that axis. ```python class MyStrategy(ScenarioStrategy): - ALL = ("all", {"all"}) # Required aggregate - EASY = ("easy", {"easy"}) + ALL = ("all", {"all"}) # Required aggregate + DEFAULT = ("default", {"default"}) # Recommended default aggregate + SINGLE_TURN = ("single_turn", {"single_turn"}) # Category aggregate - Base64 = ("base64", {"easy", "converter"}) - Crescendo = ("crescendo", {"difficult", "multi_turn"}) + PromptSending = ("prompt_sending", {"single_turn", "default"}) + RolePlay = ("role_play", {"single_turn"}) + ManyShot = ("many_shot", {"multi_turn", "default"}) @classmethod def get_aggregate_tags(cls) -> set[str]: - return {"all", "easy", "difficult"} + return {"all", "default", "single_turn", "multi_turn"} ``` - `ALL` aggregate is always required - Each member: `NAME = ("string_value", {tag_set})` - Aggregates expand to all strategies matching their tag +### `_build_atomic_attack_name()` — Result Grouping + +Override `_build_atomic_attack_name()` on the `Scenario` base class to control how attack results are grouped: + +```python +def _build_atomic_attack_name(self, *, technique_name: str, seed_group_name: str) -> str: + # Default: group by technique name (most common) + return technique_name + + # Override examples: + # Group by dataset/harm category: return seed_group_name + # Cross-product: return f"{technique_name}_{seed_group_name}" +``` + ## AtomicAttack Construction ```python diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index 868cd01394..ab29a07bb1 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -53,8 +53,10 @@ "\n", "### Required Components\n", "\n", - "1. 
**Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available strategies for your scenario.\n", - " - Each enum member is defined as `(value, tags)` where value is a string and tags is a set of strings\n", + "1. **Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available attack techniques for your scenario.\n", + " - Each enum member represents an **attack technique** (the *how* of an attack)\n", + " - Datasets control *what* content is tested; strategies control *how* attacks are run\n", + " - Each member is defined as `(value, tags)` where value is a string and tags is a set of strings\n", " - Include an `ALL` aggregate strategy that expands to all available strategies\n", " - Optionally implement `supports_composition()` and `validate_composition()` for strategy composition rules\n", "\n", @@ -117,8 +119,9 @@ "\n", "class MyStrategy(ScenarioStrategy):\n", " ALL = (\"all\", {\"all\"})\n", - " StrategyA = (\"strategy_a\", {\"tag1\", \"tag2\"})\n", - " StrategyB = (\"strategy_b\", {\"tag1\"})\n", + " # Strategy members represent attack techniques\n", + " PromptSending = (\"prompt_sending\", {\"single_turn\"})\n", + " RolePlay = (\"role_play\", {\"single_turn\"})\n", "\n", "\n", "class MyScenario(Scenario):\n", @@ -178,7 +181,7 @@ " # self._dataset_config is set by the parent class\n", " seed_groups = self._dataset_config.get_all_seed_groups()\n", "\n", - " # Create attack instances based on strategy\n", + " # Create attack instances based on the selected technique\n", " attack = PromptSendingAttack(\n", " objective_target=self._objective_target,\n", " attack_scoring_config=self._scorer_config,\n", diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index 8335c7a248..fca62237f8 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -59,8 +59,10 @@ # # ### Required Components # -# 1. 
**Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available strategies for your scenario. -# - Each enum member is defined as `(value, tags)` where value is a string and tags is a set of strings +# 1. **Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available attack techniques for your scenario. +# - Each enum member represents an **attack technique** (the *how* of an attack) +# - Datasets control *what* content is tested; strategies control *how* attacks are run +# - Each member is defined as `(value, tags)` where value is a string and tags is a set of strings # - Include an `ALL` aggregate strategy that expands to all available strategies # - Optionally implement `supports_composition()` and `validate_composition()` for strategy composition rules # @@ -105,8 +107,9 @@ class MyStrategy(ScenarioStrategy): ALL = ("all", {"all"}) - StrategyA = ("strategy_a", {"tag1", "tag2"}) - StrategyB = ("strategy_b", {"tag1"}) + # Strategy members represent attack techniques + PromptSending = ("prompt_sending", {"single_turn"}) + RolePlay = ("role_play", {"single_turn"}) class MyScenario(Scenario): @@ -166,7 +169,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: # self._dataset_config is set by the parent class seed_groups = self._dataset_config.get_all_seed_groups() - # Create attack instances based on strategy + # Create attack instances based on the selected technique attack = PromptSendingAttack( objective_target=self._objective_target, attack_scoring_config=self._scorer_config, diff --git a/doc/code/scenarios/1_scenario_parameters.ipynb b/doc/code/scenarios/1_scenario_parameters.ipynb index 04d3fd6d55..aba7e8255c 100644 --- a/doc/code/scenarios/1_scenario_parameters.ipynb +++ b/doc/code/scenarios/1_scenario_parameters.ipynb @@ -11,6 +11,10 @@ "strategies, baseline execution, and custom scorers. 
All examples use `RedTeamAgent` but the\n", "patterns apply to any scenario.\n", "\n", + "> **Two selection axes**: *Strategies* select attack techniques (*how* attacks run — e.g., prompt\n", + "> sending, role play, TAP). *Datasets* select objectives (*what* is tested — e.g., harm categories,\n", + "> compliance topics). Use `--dataset-names` on the CLI to filter by content category.\n", + "\n", "> **Running scenarios from the command line?** See the [Scanner documentation](../../scanner/0_scanner.md).\n", "\n", "## Setup\n", diff --git a/doc/code/scenarios/1_scenario_parameters.py b/doc/code/scenarios/1_scenario_parameters.py index f62a9b85a0..8020da46f1 100644 --- a/doc/code/scenarios/1_scenario_parameters.py +++ b/doc/code/scenarios/1_scenario_parameters.py @@ -15,6 +15,10 @@ # strategies, baseline execution, and custom scorers. All examples use `RedTeamAgent` but the # patterns apply to any scenario. # +# > **Two selection axes**: *Strategies* select attack techniques (*how* attacks run — e.g., prompt +# > sending, role play, TAP). *Datasets* select objectives (*what* is tested — e.g., harm categories, +# > compliance topics). Use `--dataset-names` on the CLI to filter by content category. +# # > **Running scenarios from the command line?** See the [Scanner documentation](../../scanner/0_scanner.md). 
# # ## Setup diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index e8ebfb2946..bf758528b7 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -8,7 +8,7 @@ from pyrit.scenario import Scenario, AtomicAttack, ScenarioStrategy Specific scenarios should be imported from their subpackages: - from pyrit.scenario.airt import ContentHarms, Cyber + from pyrit.scenario.airt import RapidResponse, Cyber from pyrit.scenario.garak import Encoding from pyrit.scenario.foundry import RedTeamAgent """ diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 8f40282bef..7affb77c5f 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -6,6 +6,12 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory +from pyrit.scenario.core.core_techniques import ( + many_shot_factory, + prompt_sending_factory, + role_play_factory, + tap_factory, +) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy @@ -19,4 +25,8 @@ "Scenario", "ScenarioCompositeStrategy", "ScenarioStrategy", + "many_shot_factory", + "prompt_sending_factory", + "role_play_factory", + "tap_factory", ] diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index fac94e4932..edf3934faa 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -128,39 +128,76 @@ def create( self, *, objective_target: PromptTarget, - attack_scoring_config: AttackScoringConfig, - attack_adversarial_config: AttackAdversarialConfig | None = None, - attack_converter_config: AttackConverterConfig | None 
= None, + attack_scoring_config_override: AttackScoringConfig | None = None, + attack_adversarial_config_override: AttackAdversarialConfig | None = None, + attack_converter_config_override: AttackConverterConfig | None = None, ) -> AttackTechnique: """ - Create a fresh AttackTechnique bound to the given target and scorer. + Create a fresh AttackTechnique bound to the given target. Each call produces a fully independent attack instance by calling the - real constructor. Config objects are deep-copied to prevent shared - mutable state between instances. + real constructor. Config objects frozen at factory construction time are + deep-copied into every new instance. + + The ``*_override`` parameters let a caller **replace** a config that was + baked into the factory at construction time. When ``None`` (the + default), the factory's original config is kept as-is — so baked-in + converters, adversarial targets, etc. are preserved automatically. + + Override configs are only forwarded when the attack class constructor + declares a matching parameter (without the ``_override`` suffix). + This allows a single call site to safely pass all available overrides + without breaking attacks that don't support them. + + Some attacks (e.g., TAP) create their own scoring config internally + when none is provided. Pass ``None`` (the default) for + ``attack_scoring_config_override`` to let those attacks use their + built-in defaults. Args: - objective_target: The target to attack. - attack_scoring_config: Scoring configuration for the attack. - attack_adversarial_config: Optional adversarial configuration. - Overrides any adversarial config in the frozen kwargs. - attack_converter_config: Optional converter configuration. - Overrides any converter config in the frozen kwargs. + objective_target: The target to attack (always required at create time). + attack_scoring_config_override: When non-None, replaces any scoring + config baked into the factory. 
Only forwarded if the attack + class constructor accepts ``attack_scoring_config``. + attack_adversarial_config_override: When non-None, replaces any + adversarial config baked into the factory. Only forwarded if + the attack class constructor accepts ``attack_adversarial_config``. + attack_converter_config_override: When non-None, replaces any + converter config baked into the factory. Only forwarded if + the attack class constructor accepts ``attack_converter_config``. Returns: A fresh AttackTechnique with a newly-constructed attack strategy. """ kwargs = copy.deepcopy(self._attack_kwargs) kwargs["objective_target"] = objective_target - kwargs["attack_scoring_config"] = attack_scoring_config - if attack_adversarial_config is not None: - kwargs["attack_adversarial_config"] = attack_adversarial_config - if attack_converter_config is not None: - kwargs["attack_converter_config"] = attack_converter_config + + # Only forward overrides when the attack class accepts the underlying param + accepted_params = self._get_accepted_params() + if attack_scoring_config_override is not None and "attack_scoring_config" in accepted_params: + kwargs["attack_scoring_config"] = attack_scoring_config_override + if attack_adversarial_config_override is not None and "attack_adversarial_config" in accepted_params: + kwargs["attack_adversarial_config"] = attack_adversarial_config_override + if attack_converter_config_override is not None and "attack_converter_config" in accepted_params: + kwargs["attack_converter_config"] = attack_converter_config_override attack = self._attack_class(**kwargs) return AttackTechnique(attack=attack, seed_technique=self._seed_technique) + def _get_accepted_params(self) -> set[str]: + """Return the set of keyword parameter names accepted by the attack class constructor.""" + sig = inspect.signature(self._attack_class.__init__) + return { + name + for name, param in sig.parameters.items() + if name != "self" + and param.kind + in ( + 
inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + } + @staticmethod def _serialize_value(value: Any) -> Any: """ diff --git a/pyrit/scenario/core/core_techniques.py b/pyrit/scenario/core/core_techniques.py new file mode 100644 index 0000000000..d8200bbeda --- /dev/null +++ b/pyrit/scenario/core/core_techniques.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared AttackTechniqueFactory builders for common attack techniques. + +These functions return ``AttackTechniqueFactory`` instances that can be +used by any scenario. Each factory captures technique-specific defaults +at registration time; runtime parameters (``objective_target``) and +optional overrides (``attack_scoring_config_override``, etc.) are +provided when ``factory.create()`` is called during scenario execution. + +Scenarios expose available factories via the overridable +``Scenario.get_attack_technique_factories()`` classmethod. +""" + +from pyrit.executor.attack import ( + ManyShotJailbreakAttack, + PromptSendingAttack, + RolePlayAttack, + RolePlayPaths, + TreeOfAttacksWithPruningAttack, +) +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + +def prompt_sending_factory() -> AttackTechniqueFactory: + """Create a factory for ``PromptSendingAttack`` (single-turn, no converter).""" + return AttackTechniqueFactory(attack_class=PromptSendingAttack) + + +def role_play_factory( + *, + role_play_path: str | None = None, +) -> AttackTechniqueFactory: + """ + Create a factory for ``RolePlayAttack`` (single-turn with role-play converter). + + Args: + role_play_path: Path to the role-play YAML definition. + Defaults to ``RolePlayPaths.MOVIE_SCRIPT``. 
+ """ + kwargs: dict[str, object] = { + "role_play_definition_path": role_play_path or RolePlayPaths.MOVIE_SCRIPT.value, + } + return AttackTechniqueFactory(attack_class=RolePlayAttack, attack_kwargs=kwargs) + + +def many_shot_factory() -> AttackTechniqueFactory: + """Create a factory for ``ManyShotJailbreakAttack`` (multi-turn).""" + return AttackTechniqueFactory(attack_class=ManyShotJailbreakAttack) + + +def tap_factory() -> AttackTechniqueFactory: + """Create a factory for ``TreeOfAttacksWithPruningAttack`` (multi-turn).""" + return AttackTechniqueFactory(attack_class=TreeOfAttacksWithPruningAttack) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7e91c53bda..27898b61b3 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -39,6 +39,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory logger = logging.getLogger(__name__) @@ -173,6 +174,56 @@ def default_dataset_config(cls) -> DatasetConfiguration: DatasetConfiguration: The default dataset configuration. """ + @classmethod + def get_attack_technique_factories(cls) -> dict[str, "AttackTechniqueFactory"]: + """ + Return the default attack technique factories for this scenario. + + Each key is a technique name (matching a strategy enum value) and each + value is an ``AttackTechniqueFactory`` that can produce an + ``AttackTechnique`` for that technique. + + The base implementation returns the common set from + ``core_techniques``. Subclasses may override to add, remove, or + replace factories. + + Returns: + dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. 
+ """ + from pyrit.scenario.core.core_techniques import ( + many_shot_factory, + prompt_sending_factory, + role_play_factory, + tap_factory, + ) + + return { + "prompt_sending": prompt_sending_factory(), + "role_play": role_play_factory(), + "many_shot": many_shot_factory(), + "tap": tap_factory(), + } + + def _build_atomic_attack_name(self, *, technique_name: str, seed_group_name: str) -> str: + """ + Build the grouping key for an atomic attack. + + Controls how attacks are grouped for result storage and resume + logic. Override to customize grouping: + + - **By technique** (default): ``return technique_name`` + - **By dataset/category**: ``return seed_group_name`` + - **Cross-product**: ``return f"{technique_name}_{seed_group_name}"`` + + Args: + technique_name: The name of the attack technique. + seed_group_name: The dataset or category name for the seed group. + + Returns: + str: The atomic attack name used as a grouping key. + """ + return technique_name + def _get_default_objective_scorer(self) -> TrueFalseScorer: # Deferred import to avoid circular dependency: from pyrit.setup.initializers.components.scorers import ScorerInitializerTags diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 0252f68415..114ba9b2a7 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -61,6 +61,13 @@ class ScenarioStrategy(Enum, metaclass=_DeprecatedEnumMeta): (like "easy", "moderate", "difficult" or "fast", "medium") that automatically expand to include all strategies with that tag. + **Convention**: Strategy enum members should map 1:1 to selectable **attack techniques** + (e.g., ``PromptSending``, ``RolePlay``, ``TAP``) or to aggregates of techniques + (e.g., ``DEFAULT``, ``SINGLE_TURN``). Datasets control *what* content or objectives + are tested; strategies control *how* attacks are executed. 
Avoid encoding dataset or + category selection into the strategy enum — use ``DatasetConfiguration`` and the + ``--dataset-names`` CLI flag for that axis. + **Tags**: Flexible categorization system where strategies can have multiple tags (e.g., {"easy", "converter"}, {"difficult", "multi_turn"}) diff --git a/pyrit/scenario/scenarios/airt/__init__.py b/pyrit/scenario/scenarios/airt/__init__.py index fb0e504daa..2c4d489db5 100644 --- a/pyrit/scenario/scenarios/airt/__init__.py +++ b/pyrit/scenario/scenarios/airt/__init__.py @@ -11,19 +11,22 @@ from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy from pyrit.scenario.scenarios.airt.leakage import Leakage, LeakageStrategy from pyrit.scenario.scenarios.airt.psychosocial import Psychosocial, PsychosocialStrategy +from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse, RapidResponseStrategy from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy __all__ = [ "ContentHarms", "ContentHarmsStrategy", - "Psychosocial", - "PsychosocialStrategy", "Cyber", "CyberStrategy", "Jailbreak", "JailbreakStrategy", "Leakage", "LeakageStrategy", + "Psychosocial", + "PsychosocialStrategy", + "RapidResponse", + "RapidResponseStrategy", "Scam", "ScamStrategy", ] diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index d22ece85ff..47c399592e 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -1,342 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import logging -import os -from collections.abc import Sequence -from typing import Any, Optional, TypeVar +""" +Deprecated — use ``rapid_response`` instead. 
-from pyrit.auth import get_azure_openai_auth -from pyrit.common import apply_defaults -from pyrit.executor.attack import ( - AttackAdversarialConfig, - AttackScoringConfig, - AttackStrategy, - ManyShotJailbreakAttack, - PromptSendingAttack, - RolePlayAttack, - RolePlayPaths, - TreeOfAttacksWithPruningAttack, -) -from pyrit.models import SeedAttackGroup, SeedGroup -from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget -from pyrit.scenario.core.atomic_attack import AtomicAttack -from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario import Scenario -from pyrit.scenario.core.scenario_strategy import ( - ScenarioCompositeStrategy, - ScenarioStrategy, -) -from pyrit.score import TrueFalseScorer - -logger = logging.getLogger(__name__) - -AttackStrategyT = TypeVar("AttackStrategyT", bound="AttackStrategy[Any, Any]") - - -class ContentHarmsDatasetConfiguration(DatasetConfiguration): - """ - Dataset configuration for content harms that loads seed groups by harm category. - - This subclass overrides the default loading behavior to use harm category pattern - matching instead of exact dataset name matching. When scenario_composites are provided, - it filters datasets to only those matching the selected harm strategies. - """ - - def get_seed_groups(self) -> dict[str, list[SeedGroup]]: - """ - Get seed groups filtered by harm strategies from stored scenario_composites. - - When scenario_composites are set, this filters to only include datasets - matching the selected harm strategies and returns harm strategy names as keys. - - Returns: - Dict[str, List[SeedGroup]]: Dictionary mapping harm strategy names to their - seed groups, filtered by the selected harm strategies. 
- """ - result = super().get_seed_groups() - - if self._scenario_composites is None: - return result - - # Extract selected harm strategies - selected_harms = ScenarioCompositeStrategy.extract_single_strategy_values( - self._scenario_composites, strategy_type=ContentHarmsStrategy - ) - - # Filter to matching datasets and map keys to harm names - mapped_result: dict[str, list[SeedGroup]] = {} - for name, groups in result.items(): - matched_harm = next((harm for harm in selected_harms if harm in name), None) - if matched_harm: - mapped_result[matched_harm] = groups - - return mapped_result - - -class ContentHarmsStrategy(ScenarioStrategy): - """ - ContentHarmsStrategy defines a set of strategies for testing model behavior - across several different harm categories. The scenario is designed to provide quick - feedback on model performance with respect to common harm types with the idea being that - users will dive deeper into specific harm categories based on initial results. - - Each tag represents a different harm category that the model can be tested for. - Specifying the all tag will include a comprehensive test suite covering all harm categories. - Users can define objectives for each harm category via seed datasets or use the default datasets - provided with PyRIT. - - """ - - ALL = ("all", {"all"}) - - Hate = ("hate", set[str]()) - Fairness = ("fairness", set[str]()) - Violence = ("violence", set[str]()) - Sexual = ("sexual", set[str]()) - Harassment = ("harassment", set[str]()) - Misinformation = ("misinformation", set[str]()) - Leakage = ("leakage", set[str]()) - - -class ContentHarms(Scenario): - """ - - Content Harms Scenario implementation for PyRIT. - - This scenario contains various harm-based checks that you can run to get a quick idea about model behavior - with respect to certain harm categories. - """ - - VERSION: int = 1 - - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. 
- - Returns: - Type[ScenarioStrategy]: The ContentHarmsStrategy enum class. - """ - return ContentHarmsStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. +``ContentHarms`` and ``ContentHarmsStrategy`` are thin aliases kept for +backward compatibility. They will be removed in a future release. +""" - Returns: - ScenarioStrategy: ContentHarmsStrategy.ALL - """ - return ContentHarmsStrategy.ALL +import warnings - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - DatasetConfiguration: Configuration with all content harm datasets. - """ - return ContentHarmsDatasetConfiguration( - dataset_names=[ - "airt_hate", - "airt_fairness", - "airt_violence", - "airt_sexual", - "airt_harassment", - "airt_misinformation", - "airt_leakage", - ], - max_dataset_size=4, - ) - - @apply_defaults - def __init__( - self, - *, - adversarial_chat: Optional[PromptChatTarget] = None, - objective_scorer: Optional[TrueFalseScorer] = None, - scenario_result_id: Optional[str] = None, - ): - """ - Initialize the Content Harms Scenario. - - Args: - adversarial_chat (Optional[PromptChatTarget]): Additionally used for scoring defaults. - If not provided, a default OpenAI target will be created using environment variables. - objective_scorer (Optional[TrueFalseScorer]): Scorer to evaluate attack success. - If not provided, creates a default composite scorer using Azure Content Filter - and SelfAsk Refusal scorers. - scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. 
- """ - self._objective_scorer: TrueFalseScorer = ( - objective_scorer if objective_scorer else self._get_default_objective_scorer() - ) - self._adversarial_chat = adversarial_chat if adversarial_chat else self._get_default_adversarial_target() - - super().__init__( - version=self.VERSION, - objective_scorer=self._objective_scorer, - strategy_class=ContentHarmsStrategy, - scenario_result_id=scenario_result_id, - ) - - def _get_default_adversarial_target(self) -> OpenAIChatTarget: - endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") - return OpenAIChatTarget( - endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), - model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), - temperature=1.2, - ) - - def _resolve_seed_groups_by_harm(self) -> dict[str, list[SeedAttackGroup]]: - """ - Resolve seed groups from dataset configuration. - - Returns: - Dict[str, List[SeedAttackGroup]]: Dictionary mapping content harm strategy names to their - seed attack groups. - """ - # Set scenario_composites on the config so get_seed_attack_groups can filter by strategy - self._dataset_config._scenario_composites = self._scenario_composites - return self._dataset_config.get_seed_attack_groups() - - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: - """ - Retrieve the list of AtomicAttack instances for harm strategies. - - Returns: - List[AtomicAttack]: The list of AtomicAttack instances for harm strategies. - """ - seed_groups_by_harm = self._resolve_seed_groups_by_harm() - - atomic_attacks: list[AtomicAttack] = [] - for strategy, seed_groups in seed_groups_by_harm.items(): - atomic_attacks.extend(self._get_strategy_attacks(strategy=strategy, seed_groups=seed_groups)) - return atomic_attacks - - def _get_strategy_attacks( - self, - *, - strategy: str, - seed_groups: Sequence[SeedAttackGroup], - ) -> list[AtomicAttack]: - """ - Create AtomicAttack instances for a given harm strategy. 
- - Args: - strategy (str): The harm strategy name to create attacks for. - seed_groups (Sequence[SeedAttackGroup]): The seed attack groups associated with the harm dataset. - - Returns: - list[AtomicAttack]: The constructed AtomicAttack instances for each attack type. - - Raises: - ValueError: If scenario is not properly initialized. - """ - # objective_target is guaranteed to be non-None by parent class validation - if self._objective_target is None: - raise ValueError( - "Scenario not properly initialized. Call await scenario.initialize_async() before running." - ) - - attacks: list[AtomicAttack] = [ - *self._get_single_turn_attacks(strategy=strategy, seed_groups=seed_groups), - *self._get_multi_turn_attacks(strategy=strategy, seed_groups=seed_groups), - ] - - return attacks - - def _get_single_turn_attacks( - self, - *, - strategy: str, - seed_groups: Sequence[SeedAttackGroup], - ) -> list[AtomicAttack]: - """ - Create single-turn AtomicAttack instances: RolePlayAttack and PromptSendingAttack. - - Args: - strategy (str): The harm strategy name. - seed_groups (Sequence[SeedAttackGroup]): Seed attack groups for this harm category. +from pyrit.scenario.scenarios.airt.rapid_response import ( + RapidResponse, + RapidResponseStrategy, +) - Returns: - list[AtomicAttack]: The single-turn atomic attacks. - """ - prompt_sending_attack = PromptSendingAttack( - objective_target=self._objective_target, - attack_scoring_config=AttackScoringConfig(objective_scorer=self._objective_scorer), - ) - role_play_attack = RolePlayAttack( - objective_target=self._objective_target, - attack_adversarial_config=AttackAdversarialConfig(target=self._adversarial_chat), - role_play_definition_path=RolePlayPaths.MOVIE_SCRIPT.value, +def __getattr__(name: str): + if name == "ContentHarms": + warnings.warn( + "ContentHarms is deprecated. 
Use RapidResponse instead.", + DeprecationWarning, + stacklevel=2, ) - - return [ - AtomicAttack( - atomic_attack_name=strategy, - attack_technique=AttackTechnique(attack=prompt_sending_attack), - seed_groups=list(seed_groups), - adversarial_chat=self._adversarial_chat, - objective_scorer=self._objective_scorer, - memory_labels=self._memory_labels, - ), - AtomicAttack( - atomic_attack_name=strategy, - attack_technique=AttackTechnique(attack=role_play_attack), - seed_groups=list(seed_groups), - adversarial_chat=self._adversarial_chat, - objective_scorer=self._objective_scorer, - memory_labels=self._memory_labels, - ), - ] - - def _get_multi_turn_attacks( - self, - *, - strategy: str, - seed_groups: Sequence[SeedAttackGroup], - ) -> list[AtomicAttack]: - """ - Create multi-turn AtomicAttack instances: ManyShotJailbreakAttack and TreeOfAttacksWithPruningAttack. - - Args: - strategy (str): The harm strategy name. - seed_groups (Sequence[SeedAttackGroup]): Seed attack groups for this harm category. - - Returns: - list[AtomicAttack]: The multi-turn atomic attacks. - """ - many_shot_jailbreak_attack = ManyShotJailbreakAttack( - objective_target=self._objective_target, - attack_scoring_config=AttackScoringConfig(objective_scorer=self._objective_scorer), + return RapidResponse + if name == "ContentHarmsStrategy": + warnings.warn( + "ContentHarmsStrategy is deprecated. 
Use RapidResponseStrategy instead.", + DeprecationWarning, + stacklevel=2, ) + return RapidResponseStrategy + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - tap_attack = TreeOfAttacksWithPruningAttack( - objective_target=self._objective_target, - attack_adversarial_config=AttackAdversarialConfig(target=self._adversarial_chat), - ) - return [ - AtomicAttack( - atomic_attack_name=strategy, - attack_technique=AttackTechnique(attack=many_shot_jailbreak_attack), - seed_groups=list(seed_groups), - adversarial_chat=self._adversarial_chat, - objective_scorer=self._objective_scorer, - memory_labels=self._memory_labels, - ), - AtomicAttack( - atomic_attack_name=strategy, - attack_technique=AttackTechnique(attack=tap_attack), - seed_groups=list(seed_groups), - adversarial_chat=self._adversarial_chat, - objective_scorer=self._objective_scorer, - memory_labels=self._memory_labels, - ), - ] +# Direct aliases for import-from statements +ContentHarms = RapidResponse +ContentHarmsStrategy = RapidResponseStrategy diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py new file mode 100644 index 0000000000..507ae96ddc --- /dev/null +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +RapidResponse scenario — technique-based rapid content-harms testing. + +Strategies select **attack techniques** (PromptSending, RolePlay, +ManyShot, TAP). Datasets select **harm categories** (hate, fairness, +violence, …). Use ``--dataset-names`` to narrow which harm categories +to test. 
+""" + +import logging +import os +from typing import Optional + +from pyrit.auth import get_azure_openai_auth +from pyrit.common import apply_defaults +from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig +from pyrit.models import SeedAttackGroup +from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.scenario.core.atomic_attack import AtomicAttack +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import Scenario +from pyrit.scenario.core.scenario_strategy import ( + ScenarioCompositeStrategy, + ScenarioStrategy, +) +from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +class RapidResponseStrategy(ScenarioStrategy): + """ + Attack-technique strategies for the RapidResponse scenario. + + Each non-aggregate member maps to a single attack technique. + Aggregates (ALL, DEFAULT, SINGLE_TURN, MULTI_TURN) expand to + all techniques that share the corresponding tag. + + ``ScenarioStrategy`` members should map 1:1 to selectable attack + techniques or aggregates of techniques. They are the user-facing + selection API; ``AttackTechniqueFactory`` is the execution + abstraction. + """ + + ALL = ("all", {"all"}) + DEFAULT = ("default", {"default"}) + SINGLE_TURN = ("single_turn", {"single_turn"}) + MULTI_TURN = ("multi_turn", {"multi_turn"}) + + PromptSending = ("prompt_sending", {"single_turn", "default"}) + RolePlay = ("role_play", {"single_turn"}) + ManyShot = ("many_shot", {"multi_turn", "default"}) + TAP = ("tap", {"multi_turn"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all", "default", "single_turn", "multi_turn"} + + +class RapidResponse(Scenario): + """ + Rapid Response scenario for content-harms testing. + + Tests model behaviour across harm categories using selectable attack + techniques. Strategies control *how* prompts are delivered (e.g. + prompt_sending, role_play, many_shot, TAP). 
class RapidResponse(Scenario):
    """
    Rapid Response scenario for content-harms testing.

    Tests model behaviour across harm categories using selectable attack
    techniques. Strategies control *how* prompts are delivered (e.g.
    prompt_sending, role_play, many_shot, TAP). Datasets control *what*
    harm content is tested (e.g. hate, violence, sexual). Use
    ``--dataset-names`` to filter harm categories.
    """

    # Scenario result schema version (v2 replaces the harm-category-strategy
    # ContentHarms scenario, which was VERSION 1).
    VERSION: int = 2

    @classmethod
    def get_strategy_class(cls) -> type[ScenarioStrategy]:
        """Return the strategy enum used to select attack techniques."""
        return RapidResponseStrategy

    @classmethod
    def get_default_strategy(cls) -> ScenarioStrategy:
        """Return the aggregate applied when no strategies are provided."""
        return RapidResponseStrategy.DEFAULT

    @classmethod
    def default_dataset_config(cls) -> DatasetConfiguration:
        """Return the default harm datasets, capped at 4 seeds per dataset."""
        return DatasetConfiguration(
            dataset_names=[
                "airt_hate",
                "airt_fairness",
                "airt_violence",
                "airt_sexual",
                "airt_harassment",
                "airt_misinformation",
                "airt_leakage",
            ],
            max_dataset_size=4,
        )

    @apply_defaults
    def __init__(
        self,
        *,
        adversarial_chat: PromptChatTarget | None = None,
        objective_scorer: TrueFalseScorer | None = None,
        scenario_result_id: str | None = None,
    ) -> None:
        """
        Initialize the Rapid Response scenario.

        Args:
            adversarial_chat: Chat target for multi-turn / adversarial
                techniques (RolePlay, TAP). Defaults to an Azure OpenAI
                target built from environment variables.
            objective_scorer: Scorer for evaluating attack success.
                Defaults to the base class's default objective scorer.
            scenario_result_id: Optional ID of an existing scenario
                result to resume.
        """
        self._objective_scorer: TrueFalseScorer = (
            objective_scorer if objective_scorer else self._get_default_objective_scorer()
        )
        self._adversarial_chat = adversarial_chat if adversarial_chat else self._get_default_adversarial_target()

        super().__init__(
            version=self.VERSION,
            objective_scorer=self._objective_scorer,
            strategy_class=RapidResponseStrategy,
            scenario_result_id=scenario_result_id,
        )

    def _build_atomic_attack_name(self, *, technique_name: str, seed_group_name: str) -> str:
        """Group results by harm category (dataset) rather than technique."""
        return seed_group_name

    def _get_default_adversarial_target(self) -> OpenAIChatTarget:
        """Build the default adversarial chat target from environment variables."""
        # NOTE(review): endpoint is None when the env var is unset — assumes
        # get_azure_openai_auth / OpenAIChatTarget validate that; confirm.
        endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT")
        return OpenAIChatTarget(
            endpoint=endpoint,
            api_key=get_azure_openai_auth(endpoint),
            model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"),
            temperature=1.2,
        )

    async def _get_atomic_attacks_async(self) -> list[AtomicAttack]:
        """
        Build atomic attacks from selected techniques × harm datasets.

        Iterates over every (technique, harm-dataset) pair and creates
        an ``AtomicAttack`` for each. The ``_build_atomic_attack_name``
        override groups results by harm category.

        Returns:
            list[AtomicAttack]: One attack per (technique, dataset) pair.

        Raises:
            ValueError: If the scenario has not been initialized.
        """
        if self._objective_target is None:
            raise ValueError(
                "Scenario not properly initialized. Call await scenario.initialize_async() before running."
            )

        selected_techniques = ScenarioCompositeStrategy.extract_single_strategy_values(
            self._scenario_composites, strategy_type=RapidResponseStrategy
        )

        factories = self.get_attack_technique_factories()
        seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups()

        scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer)
        adversarial_config = AttackAdversarialConfig(target=self._adversarial_chat)

        atomic_attacks: list[AtomicAttack] = []
        for technique_name in selected_techniques:
            factory = factories.get(technique_name)
            if factory is None:
                # Lazy %-style args: message is only formatted if emitted.
                logger.warning("No factory for technique '%s', skipping.", technique_name)
                continue

            # TAP creates its own FloatScaleThresholdScorer internally when no
            # scoring config is provided. Passing the scenario's TrueFalseScorer
            # would fail TAP's type validation.
            scoring_for_technique = None if technique_name == "tap" else scoring_config

            attack_technique = factory.create(
                objective_target=self._objective_target,
                attack_scoring_config_override=scoring_for_technique,
                attack_adversarial_config_override=adversarial_config,
            )

            for dataset_name, seed_groups in seed_groups_by_dataset.items():
                atomic_attacks.append(
                    AtomicAttack(
                        atomic_attack_name=self._build_atomic_attack_name(
                            technique_name=technique_name,
                            seed_group_name=dataset_name,
                        ),
                        attack_technique=attack_technique,
                        seed_groups=list(seed_groups),
                        adversarial_chat=self._adversarial_chat,
                        objective_scorer=self._objective_scorer,
                        memory_labels=self._memory_labels,
                    )
                )

        return atomic_attacks
AttackTechniqueFactory(attack_class=_StubAttack) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) assert isinstance(technique, AttackTechnique) assert isinstance(technique.attack, _StubAttack) @@ -166,7 +166,7 @@ def test_create_passes_frozen_kwargs(self): ) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) assert technique.attack.max_turns == 42 @@ -175,7 +175,7 @@ def test_create_passes_scoring_config(self): target = MagicMock(spec=PromptTarget) scoring = MagicMock(spec=AttackScoringConfig) - technique = factory.create(objective_target=target, attack_scoring_config=scoring) + technique = factory.create(objective_target=target, attack_scoring_config_override=scoring) assert technique.attack.attack_scoring_config is scoring @@ -189,7 +189,7 @@ def test_create_overrides_frozen_scoring_config(self): target = MagicMock(spec=PromptTarget) override_scoring = MagicMock(spec=AttackScoringConfig) - technique = factory.create(objective_target=target, attack_scoring_config=override_scoring) + technique = factory.create(objective_target=target, attack_scoring_config_override=override_scoring) assert technique.attack.attack_scoring_config is override_scoring assert technique.attack.attack_scoring_config is not frozen_scoring @@ -199,7 +199,7 @@ def test_create_preserves_seed_technique(self): factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) assert 
technique.seed_technique is seeds @@ -213,8 +213,8 @@ def test_create_produces_independent_instances(self): target2 = MagicMock(spec=PromptTarget) scoring = self._scoring() - technique1 = factory.create(objective_target=target1, attack_scoring_config=scoring) - technique2 = factory.create(objective_target=target2, attack_scoring_config=scoring) + technique1 = factory.create(objective_target=target1, attack_scoring_config_override=scoring) + technique2 = factory.create(objective_target=target2, attack_scoring_config_override=scoring) assert technique1.attack is not technique2.attack assert technique1.attack.objective_target is target1 @@ -238,11 +238,11 @@ def get_identifier(self): ) target = MagicMock(spec=PromptTarget) - technique1 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique1 = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) # Mutate the source list mutable_list.append(999) - technique2 = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique2 = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) # First create should have the original snapshot assert technique1.attack.items == [1, 2, 3] @@ -271,7 +271,7 @@ def get_identifier(self): factory = AttackTechniqueFactory(attack_class=_SentinelAttack) target = MagicMock(spec=PromptTarget) - technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) + technique = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) assert not technique.attack.adversarial_was_passed assert not technique.attack.converter_was_passed diff --git a/tests/unit/scenario/test_content_harms.py b/tests/unit/scenario/test_content_harms.py deleted file mode 100644 index 1e177e15bd..0000000000 --- a/tests/unit/scenario/test_content_harms.py +++ /dev/null @@ -1,808 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# Licensed under the MIT license. - -"""Tests for the ContentHarms class.""" - -import pathlib -from unittest.mock import MagicMock, patch - -import pytest - -from pyrit.common.path import DATASETS_PATH -from pyrit.identifiers import ComponentIdentifier -from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt -from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.scenario import ScenarioCompositeStrategy -from pyrit.scenario.airt import ( - ContentHarms, - ContentHarmsStrategy, -) -from pyrit.scenario.scenarios.airt.content_harms import ( - ContentHarmsDatasetConfiguration, -) -from pyrit.score import TrueFalseScorer - - -def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ComponentIdentifier: - """Helper to create ComponentIdentifier for tests.""" - return ComponentIdentifier( - class_name=name, - class_module="test", - ) - - -def _mock_target_id(name: str = "MockTarget") -> ComponentIdentifier: - """Helper to create ComponentIdentifier for tests.""" - return ComponentIdentifier( - class_name=name, - class_module="test", - ) - - -@pytest.fixture -def mock_objective_target(): - """Create a mock objective target for testing.""" - mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") - return mock - - -@pytest.fixture -def mock_adversarial_target(): - """Create a mock adversarial target for testing.""" - mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") - return mock - - -@pytest.fixture -def mock_objective_scorer(): - """Create a mock objective scorer for testing.""" - mock = MagicMock(spec=TrueFalseScorer) - mock.get_identifier.return_value = _mock_scorer_id("MockObjectiveScorer") - return mock - - -@pytest.fixture -def sample_objectives(): - """Create sample objectives for testing.""" - return ["objective1", "objective2", "objective3"] - - 
-@pytest.fixture(scope="class") -def mock_seed_groups(): - """Create mock seed groups for testing.""" - - def create_seed_groups_for_strategy(strategy_name: str): - """Helper to create seed groups for a given strategy.""" - return [ - SeedAttackGroup( - seeds=[ - SeedObjective(value=f"{strategy_name} objective 1"), - SeedPrompt(value=f"{strategy_name} prompt 1"), - ] - ), - SeedAttackGroup( - seeds=[ - SeedObjective(value=f"{strategy_name} objective 2"), - SeedPrompt(value=f"{strategy_name} prompt 2"), - ] - ), - ] - - return create_seed_groups_for_strategy - - -@pytest.fixture(scope="class") -def mock_all_harm_objectives(mock_seed_groups): - """Class-scoped fixture for all harm category objectives to reduce test code duplication.""" - return { - "hate": mock_seed_groups("hate"), - "fairness": mock_seed_groups("fairness"), - "violence": mock_seed_groups("violence"), - "sexual": mock_seed_groups("sexual"), - "harassment": mock_seed_groups("harassment"), - "misinformation": mock_seed_groups("misinformation"), - "leakage": mock_seed_groups("leakage"), - } - - -class TestContentHarmsStrategy: - """Tests for the ContentHarmsStrategy enum.""" - - def test_all_harm_categories_exist(self): - """Test that all expected harm categories exist as strategies.""" - expected_categories = ["hate", "fairness", "violence", "sexual", "harassment", "misinformation", "leakage"] - aggregate_values = {"all"} - strategy_values = [s.value for s in ContentHarmsStrategy if s.value not in aggregate_values] - - for category in expected_categories: - assert category in strategy_values, f"Expected harm category '{category}' not found in strategies" - - def test_strategy_tags_are_sets(self): - """Test that all strategy tags are set objects.""" - for strategy in ContentHarmsStrategy: - assert isinstance(strategy.tags, set), f"Tags for {strategy.name} are not a set" - - def test_enum_members_count(self): - """Test that we have the expected number of strategy members.""" - # ALL + 7 harm categories = 
8 total - assert len(list(ContentHarmsStrategy)) == 8 - - def test_all_strategies_can_be_accessed_by_name(self): - """Test that all strategies can be accessed by their name.""" - assert ContentHarmsStrategy["ALL"] == ContentHarmsStrategy.ALL - assert ContentHarmsStrategy.Hate == ContentHarmsStrategy["Hate"] - assert ContentHarmsStrategy.Fairness == ContentHarmsStrategy["Fairness"] - assert ContentHarmsStrategy.Violence == ContentHarmsStrategy["Violence"] - assert ContentHarmsStrategy.Sexual == ContentHarmsStrategy["Sexual"] - assert ContentHarmsStrategy.Harassment == ContentHarmsStrategy["Harassment"] - assert ContentHarmsStrategy.Misinformation == ContentHarmsStrategy["Misinformation"] - assert ContentHarmsStrategy.Leakage == ContentHarmsStrategy["Leakage"] - - def test_all_strategies_can_be_accessed_by_value(self): - """Test that all strategies can be accessed by their value.""" - assert ContentHarmsStrategy("all") == ContentHarmsStrategy.ALL - assert ContentHarmsStrategy("hate") == ContentHarmsStrategy.Hate - assert ContentHarmsStrategy("fairness") == ContentHarmsStrategy.Fairness - assert ContentHarmsStrategy("violence") == ContentHarmsStrategy.Violence - assert ContentHarmsStrategy("sexual") == ContentHarmsStrategy.Sexual - assert ContentHarmsStrategy("harassment") == ContentHarmsStrategy.Harassment - assert ContentHarmsStrategy("misinformation") == ContentHarmsStrategy.Misinformation - assert ContentHarmsStrategy("leakage") == ContentHarmsStrategy.Leakage - - def test_strategies_are_unique(self): - """Test that all strategy values are unique.""" - values = [s.value for s in ContentHarmsStrategy] - assert len(values) == len(set(values)), "Strategy values are not unique" - - def test_strategy_iteration(self): - """Test that we can iterate over all strategies.""" - strategies = list(ContentHarmsStrategy) - assert len(strategies) == 8 - assert ContentHarmsStrategy.ALL in strategies - assert ContentHarmsStrategy.Hate in strategies - - def 
test_strategy_comparison(self): - """Test that strategy comparison works correctly.""" - assert ContentHarmsStrategy.Hate == ContentHarmsStrategy.Hate - assert ContentHarmsStrategy.Hate != ContentHarmsStrategy.Violence - assert ContentHarmsStrategy.Hate != ContentHarmsStrategy.ALL - - def test_strategy_hash(self): - """Test that strategies can be hashed and used in sets/dicts.""" - strategy_set = {ContentHarmsStrategy.Hate, ContentHarmsStrategy.Violence} - assert len(strategy_set) == 2 - assert ContentHarmsStrategy.Hate in strategy_set - - strategy_dict = {ContentHarmsStrategy.Hate: "hate_value"} - assert strategy_dict[ContentHarmsStrategy.Hate] == "hate_value" - - def test_strategy_string_representation(self): - """Test string representation of strategies.""" - assert "Hate" in str(ContentHarmsStrategy.Hate) - assert "ALL" in str(ContentHarmsStrategy.ALL) - - def test_invalid_strategy_value_raises_error(self): - """Test that accessing invalid strategy value raises ValueError.""" - with pytest.raises(ValueError): - ContentHarmsStrategy("invalid_strategy") - - def test_invalid_strategy_name_raises_error(self): - """Test that accessing invalid strategy name raises KeyError.""" - with pytest.raises(KeyError): - ContentHarmsStrategy["InvalidStrategy"] - - def test_get_aggregate_tags_includes_all_aggregates(self): - """Test that get_aggregate_tags includes 'all' tag.""" - aggregate_tags = ContentHarmsStrategy.get_aggregate_tags() - - assert "all" in aggregate_tags - assert isinstance(aggregate_tags, set) - assert len(aggregate_tags) == 1 - - def test_get_aggregate_tags_returns_set(self): - """Test that get_aggregate_tags returns a set.""" - aggregate_tags = ContentHarmsStrategy.get_aggregate_tags() - assert isinstance(aggregate_tags, set) - - def test_get_aggregate_strategies(self): - """Test that ALL aggregate expands to all individual harm strategies.""" - # The ALL strategy should include all individual harm categories - all_strategies = list(ContentHarmsStrategy) - 
assert len(all_strategies) == 8 # ALL + 7 harm categories - - # Non-aggregate strategies should be just the 7 harm categories - non_aggregate = ContentHarmsStrategy.get_all_strategies() - assert len(non_aggregate) == 7 - - -@pytest.mark.usefixtures("patch_central_database") -class TestContentHarmsBasic: - """Basic tests for ContentHarms initialization and properties.""" - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_with_minimal_parameters( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test initialization with only required parameters.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - # Constructor should set adversarial chat and basic metadata - assert scenario._adversarial_chat == mock_adversarial_target - assert scenario.name == "ContentHarms" - assert scenario.VERSION == 1 - - # Initialization populates objective target and scenario composites - await scenario.initialize_async(objective_target=mock_objective_target) - - assert scenario._objective_target == mock_objective_target - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_with_custom_strategies( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_seed_groups, - ): - """Test initialization with custom harm strategies.""" - mock_get_scorer.return_value 
= mock_objective_scorer - mock_get_seed_attack_groups.return_value = { - "hate": mock_seed_groups("hate"), - "fairness": mock_seed_groups("fairness"), - } - - strategies = [ContentHarmsStrategy.Hate, ContentHarmsStrategy.Fairness] - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target, scenario_strategies=strategies) - - # Prepared composites should match provided strategies - assert len(scenario._scenario_composites) == 2 - - def test_initialization_with_custom_scorer( - self, mock_objective_target, mock_adversarial_target, mock_objective_scorer - ): - """Test initialization with custom objective scorer.""" - scenario = ContentHarms( - adversarial_chat=mock_adversarial_target, - objective_scorer=mock_objective_scorer, - ) - - # The scorer is stored in _objective_scorer - assert scenario._objective_scorer == mock_objective_scorer - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_with_custom_max_concurrency( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test initialization with custom max concurrency.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=10) - - assert scenario._max_concurrency == 10 - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def 
test_initialization_with_custom_dataset_path( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test initialization with custom seed dataset prefix.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target) - - # Just verify it initializes without error - assert scenario is not None - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_defaults_to_all_strategy( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test that initialization defaults to ALL strategy when none provided.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target) - - # Should have strategies from the ALL aggregate - assert len(scenario._scenario_composites) > 0 - - def test_get_default_strategy_returns_all(self): - """Test that get_default_strategy returns ALL strategy.""" - assert ContentHarms.get_default_strategy() == ContentHarmsStrategy.ALL - - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch.dict( - "os.environ", - { - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.endpoint", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY": "test_key", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", - }, - 
) - def test_get_default_adversarial_target(self, mock_get_scorer, mock_objective_target, mock_objective_scorer): - """Test default adversarial target creation.""" - mock_get_scorer.return_value = mock_objective_scorer - scenario = ContentHarms() - - assert scenario._adversarial_chat is not None - - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch.dict( - "os.environ", - { - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.endpoint", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY": "test_key", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", - }, - ) - def test_get_default_objective_scorer(self, mock_get_scorer, mock_objective_target, mock_objective_scorer): - """Test default objective scorer is set from base class.""" - mock_get_scorer.return_value = mock_objective_scorer - scenario = ContentHarms() - - assert scenario._objective_scorer == mock_objective_scorer - - def test_scenario_version(self): - """Test that scenario has correct version.""" - assert ContentHarms.VERSION == 1 - - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch.dict( - "os.environ", - { - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.endpoint", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY": "test_key", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", - }, - ) - @pytest.mark.asyncio - async def test_initialize_raises_exception_when_no_datasets_available( - self, mock_get_scorer, mock_objective_target, mock_adversarial_target, mock_objective_scorer - ): - """Test that initialization raises ValueError when datasets are not available in memory.""" - mock_get_scorer.return_value = mock_objective_scorer - # Don't mock _get_objectives_by_harm, let it try to load from empty memory - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) - - 
@pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_with_max_retries( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test initialization with max_retries parameter.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target, max_retries=3) - - assert scenario._max_retries == 3 - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_memory_labels_are_stored( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_all_harm_objectives, - ): - """Test that memory labels are properly stored.""" - mock_get_scorer.return_value = mock_objective_scorer - mock_get_seed_attack_groups.return_value = mock_all_harm_objectives - - memory_labels = {"test_run": "123", "category": "harm"} - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - - await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=memory_labels) - - assert scenario._memory_labels == memory_labels - - @pytest.mark.asyncio - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_initialization_with_all_parameters( - self, - mock_get_seed_attack_groups, - mock_objective_target, - mock_adversarial_target, - 
mock_objective_scorer, - mock_seed_groups, - ): - """Test initialization with all possible parameters.""" - mock_get_seed_attack_groups.return_value = { - "hate": mock_seed_groups("hate"), - "violence": mock_seed_groups("violence"), - } - - memory_labels = {"test": "value"} - strategies = [ContentHarmsStrategy.Hate, ContentHarmsStrategy.Violence] - - scenario = ContentHarms( - adversarial_chat=mock_adversarial_target, - objective_scorer=mock_objective_scorer, - ) - - await scenario.initialize_async( - objective_target=mock_objective_target, - scenario_strategies=strategies, - memory_labels=memory_labels, - max_concurrency=5, - max_retries=2, - ) - - assert scenario._objective_target == mock_objective_target - assert scenario._adversarial_chat == mock_adversarial_target - assert scenario._objective_scorer == mock_objective_scorer - assert scenario._memory_labels == memory_labels - assert scenario._max_concurrency == 5 - assert scenario._max_retries == 2 - - @pytest.mark.parametrize( - "harm_category", ["hate", "fairness", "violence", "sexual", "harassment", "misinformation", "leakage"] - ) - def test_harm_category_prompt_file_exists(self, harm_category): - harm_dataset_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" - file_path = harm_dataset_path / f"{harm_category}.prompt" - assert file_path.exists(), f"Missing file: {file_path}" # Fails if file does not exist - - -class TestContentHarmsDatasetConfiguration: - """Tests for the ContentHarmsDatasetConfiguration class.""" - - def test_get_seed_attack_groups_returns_all_datasets_when_no_composites(self): - """Test that get_seed_attack_groups returns all datasets when scenario_composites is None.""" - # Create mock seed groups for each dataset - mock_groups = { - "airt_hate": [SeedAttackGroup(seeds=[SeedObjective(value="hate obj")])], - "airt_violence": [SeedAttackGroup(seeds=[SeedObjective(value="violence obj")])], - } - - config = ContentHarmsDatasetConfiguration( - dataset_names=["airt_hate", 
"airt_violence"], - ) - - with patch.object(config, "_load_seed_groups_for_dataset") as mock_load: - mock_load.side_effect = lambda dataset_name: mock_groups.get(dataset_name, []) - - result = config.get_seed_attack_groups() - - # Without scenario_composites, returns dataset names as keys - assert "airt_hate" in result - assert "airt_violence" in result - assert len(result) == 2 - - def test_get_seed_attack_groups_filters_by_selected_harm_strategy(self): - """Test that get_seed_attack_groups filters datasets by selected harm strategies.""" - mock_groups = { - "airt_hate": [SeedAttackGroup(seeds=[SeedObjective(value="hate obj")])], - "airt_violence": [SeedAttackGroup(seeds=[SeedObjective(value="violence obj")])], - "airt_sexual": [SeedAttackGroup(seeds=[SeedObjective(value="sexual obj")])], - } - - config = ContentHarmsDatasetConfiguration( - dataset_names=["airt_hate", "airt_violence", "airt_sexual"], - scenario_composites=[ScenarioCompositeStrategy(strategies=[ContentHarmsStrategy.Hate])], - ) - - with patch.object(config, "_load_seed_groups_for_dataset") as mock_load: - mock_load.side_effect = lambda dataset_name: mock_groups.get(dataset_name, []) - - result = config.get_seed_attack_groups() - - # Should only return "hate" key (mapped from "airt_hate") - assert "hate" in result - assert "violence" not in result - assert "sexual" not in result - assert len(result) == 1 - - def test_get_seed_attack_groups_maps_dataset_names_to_harm_names(self): - """Test that dataset names are mapped to harm strategy names.""" - mock_groups = { - "airt_hate": [SeedAttackGroup(seeds=[SeedObjective(value="hate obj")])], - "airt_fairness": [SeedAttackGroup(seeds=[SeedObjective(value="fairness obj")])], - } - - config = ContentHarmsDatasetConfiguration( - dataset_names=["airt_hate", "airt_fairness"], - scenario_composites=[ - ScenarioCompositeStrategy(strategies=[ContentHarmsStrategy.Hate]), - ScenarioCompositeStrategy(strategies=[ContentHarmsStrategy.Fairness]), - ], - ) - - with 
patch.object(config, "_load_seed_groups_for_dataset") as mock_load: - mock_load.side_effect = lambda dataset_name: mock_groups.get(dataset_name, []) - - result = config.get_seed_attack_groups() - - # Keys should be harm names, not dataset names - assert "hate" in result - assert "fairness" in result - assert "airt_hate" not in result - assert "airt_fairness" not in result - - def test_get_seed_attack_groups_with_all_strategy_returns_all_harms(self): - """Test that ALL strategy returns all harm categories.""" - all_datasets = [ - "airt_hate", - "airt_fairness", - "airt_violence", - "airt_sexual", - "airt_harassment", - "airt_misinformation", - "airt_leakage", - ] - mock_groups = {name: [SeedAttackGroup(seeds=[SeedObjective(value=f"{name} obj")])] for name in all_datasets} - - # ALL strategy expands to all individual harm strategies - all_harms = ["hate", "fairness", "violence", "sexual", "harassment", "misinformation", "leakage"] - composites = [ScenarioCompositeStrategy(strategies=[ContentHarmsStrategy(harm)]) for harm in all_harms] - - config = ContentHarmsDatasetConfiguration( - dataset_names=all_datasets, - scenario_composites=composites, - ) - - with patch.object(config, "_load_seed_groups_for_dataset") as mock_load: - mock_load.side_effect = lambda dataset_name: mock_groups.get(dataset_name, []) - - result = config.get_seed_attack_groups() - - # Should have all 7 harm categories - assert len(result) == 7 - for harm in all_harms: - assert harm in result - - def test_get_seed_attack_groups_applies_max_dataset_size(self): - """Test that max_dataset_size is applied per dataset.""" - # Create 5 seed groups for the dataset - mock_groups = { - "airt_hate": [SeedAttackGroup(seeds=[SeedObjective(value=f"hate obj {i}")]) for i in range(5)], - } - - config = ContentHarmsDatasetConfiguration( - dataset_names=["airt_hate"], - max_dataset_size=2, - scenario_composites=[ScenarioCompositeStrategy(strategies=[ContentHarmsStrategy.Hate])], - ) - - with patch.object(config, 
"_load_seed_groups_for_dataset") as mock_load: - mock_load.side_effect = lambda dataset_name: mock_groups.get(dataset_name, []) - - result = config.get_seed_attack_groups() - - # Should have at most 2 seed groups due to max_dataset_size - assert "hate" in result - assert len(result["hate"]) == 2 - - def test_default_dataset_config_has_all_harm_datasets(self): - """Test that default_dataset_config includes all 7 harm category datasets.""" - config = ContentHarms.default_dataset_config() - - assert isinstance(config, ContentHarmsDatasetConfiguration) - dataset_names = config.get_default_dataset_names() - - expected_datasets = [ - "airt_hate", - "airt_fairness", - "airt_violence", - "airt_sexual", - "airt_harassment", - "airt_misinformation", - "airt_leakage", - ] - - for expected in expected_datasets: - assert expected in dataset_names - - assert len(dataset_names) == 7 - - def test_default_dataset_config_has_max_dataset_size(self): - """Test that default_dataset_config has max_dataset_size set to 4.""" - config = ContentHarms.default_dataset_config() - - assert config.max_dataset_size == 4 - - -@pytest.mark.usefixtures("patch_central_database") -class TestContentHarmsAttackGroups: - """Tests for the single-turn and multi-turn attack generation.""" - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_get_single_turn_attacks_returns_prompt_sending_and_role_play( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_seed_groups, - ): - """Test that _get_single_turn_attacks returns PromptSendingAttack and RolePlayAttack.""" - from pyrit.executor.attack import PromptSendingAttack, RolePlayAttack - - mock_get_scorer.return_value = mock_objective_scorer - seed_groups = mock_seed_groups("hate") - 
mock_get_seed_attack_groups.return_value = {"hate": seed_groups} - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - await scenario.initialize_async( - objective_target=mock_objective_target, - scenario_strategies=[ContentHarmsStrategy.Hate], - ) - - attacks = scenario._get_single_turn_attacks(strategy="hate", seed_groups=seed_groups) - - assert len(attacks) == 2 - attack_types = [type(a.attack_technique.attack) for a in attacks] - assert PromptSendingAttack in attack_types - assert RolePlayAttack in attack_types - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_get_multi_turn_attacks_returns_many_shot_and_tap( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_seed_groups, - ): - """Test that _get_multi_turn_attacks returns ManyShotJailbreakAttack and TreeOfAttacksWithPruningAttack.""" - from pyrit.executor.attack import ManyShotJailbreakAttack, TreeOfAttacksWithPruningAttack - - mock_get_scorer.return_value = mock_objective_scorer - seed_groups = mock_seed_groups("hate") - mock_get_seed_attack_groups.return_value = {"hate": seed_groups} - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - await scenario.initialize_async( - objective_target=mock_objective_target, - scenario_strategies=[ContentHarmsStrategy.Hate], - ) - - attacks = scenario._get_multi_turn_attacks(strategy="hate", seed_groups=seed_groups) - - assert len(attacks) == 2 - attack_types = [type(a.attack_technique.attack) for a in attacks] - assert ManyShotJailbreakAttack in attack_types - assert TreeOfAttacksWithPruningAttack in attack_types - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - 
@patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_get_strategy_attacks_includes_all_groups( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_objective_target, - mock_adversarial_target, - mock_objective_scorer, - mock_seed_groups, - ): - """Test that _get_strategy_attacks returns attacks from both single-turn and multi-turn groups.""" - from pyrit.executor.attack import ( - ManyShotJailbreakAttack, - PromptSendingAttack, - RolePlayAttack, - TreeOfAttacksWithPruningAttack, - ) - - mock_get_scorer.return_value = mock_objective_scorer - seed_groups = mock_seed_groups("hate") - mock_get_seed_attack_groups.return_value = {"hate": seed_groups} - - scenario = ContentHarms(adversarial_chat=mock_adversarial_target) - await scenario.initialize_async( - objective_target=mock_objective_target, - scenario_strategies=[ContentHarmsStrategy.Hate], - ) - - attacks = scenario._get_strategy_attacks(strategy="hate", seed_groups=seed_groups) - - # 2 single-turn + 2 multi-turn = 4 - assert len(attacks) == 4 - attack_types = [type(a.attack_technique.attack) for a in attacks] - assert PromptSendingAttack in attack_types - assert RolePlayAttack in attack_types - assert ManyShotJailbreakAttack in attack_types - assert TreeOfAttacksWithPruningAttack in attack_types - - @pytest.mark.asyncio - @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch("pyrit.scenario.scenarios.airt.content_harms.ContentHarmsDatasetConfiguration.get_seed_attack_groups") - async def test_get_strategy_attacks_raises_when_not_initialized( - self, - mock_get_seed_attack_groups, - mock_get_scorer, - mock_adversarial_target, - mock_objective_scorer, - mock_seed_groups, - ): - """Test that _get_strategy_attacks raises ValueError when scenario is not initialized.""" - mock_get_scorer.return_value = mock_objective_scorer - seed_groups = mock_seed_groups("hate") - - scenario = 
ContentHarms(adversarial_chat=mock_adversarial_target) - - with pytest.raises(ValueError, match="Scenario not properly initialized"): - scenario._get_strategy_attacks(strategy="hate", seed_groups=seed_groups) - - def test_aggregate_strategies_only_includes_all(self): - """Test that ALL is the only aggregate strategy.""" - aggregates = ContentHarmsStrategy.get_aggregate_strategies() - aggregate_values = [s.value for s in aggregates] - - assert "all" in aggregate_values - assert len(aggregates) == 1 diff --git a/tests/unit/scenario/test_rapid_response.py b/tests/unit/scenario/test_rapid_response.py new file mode 100644 index 0000000000..f5c942a203 --- /dev/null +++ b/tests/unit/scenario/test_rapid_response.py @@ -0,0 +1,569 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the RapidResponse scenario (refactored from ContentHarms).""" + +import pathlib +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.common.path import DATASETS_PATH +from pyrit.executor.attack import ( + ManyShotJailbreakAttack, + PromptSendingAttack, + RolePlayAttack, + TreeOfAttacksWithPruningAttack, +) +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt +from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.scenario import ScenarioCompositeStrategy +from pyrit.scenario.core.core_techniques import ( + many_shot_factory, + prompt_sending_factory, + role_play_factory, + tap_factory, +) +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.scenarios.airt.rapid_response import ( + RapidResponse, + RapidResponseStrategy, +) +from pyrit.score import TrueFalseScorer + + +# --------------------------------------------------------------------------- +# Synthetic many-shot examples — prevents reading the real JSON during tests +# 
--------------------------------------------------------------------------- +_MOCK_MANY_SHOT_EXAMPLES = [{"question": f"test question {i}", "answer": f"test answer {i}"} for i in range(100)] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_id(name: str) -> ComponentIdentifier: + return ComponentIdentifier(class_name=name, class_module="test") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_objective_target(): + mock = MagicMock(spec=PromptTarget) + mock.get_identifier.return_value = _mock_id("MockObjectiveTarget") + return mock + + +@pytest.fixture +def mock_adversarial_target(): + mock = MagicMock(spec=PromptChatTarget) + mock.get_identifier.return_value = _mock_id("MockAdversarialTarget") + return mock + + +@pytest.fixture +def mock_objective_scorer(): + mock = MagicMock(spec=TrueFalseScorer) + mock.get_identifier.return_value = _mock_id("MockObjectiveScorer") + return mock + + +@pytest.fixture(autouse=True) +def patch_many_shot_load(): + """Prevent ManyShotJailbreakAttack from loading the full bundled dataset.""" + with patch( + "pyrit.executor.attack.single_turn.many_shot_jailbreak.load_many_shot_jailbreaking_dataset", + return_value=_MOCK_MANY_SHOT_EXAMPLES, + ): + yield + + +@pytest.fixture +def mock_runtime_env(): + with patch.dict( + "os.environ", + { + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.openai.azure.com/", + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY": "test-key", + "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", + }, + ): + yield + + +def _make_seed_groups(name: str) -> list[SeedAttackGroup]: + """Create two seed attack groups for a given category.""" + return [ + SeedAttackGroup( + seeds=[SeedObjective(value=f"{name} objective 1"), 
SeedPrompt(value=f"{name} prompt 1")] + ), + SeedAttackGroup( + seeds=[SeedObjective(value=f"{name} objective 2"), SeedPrompt(value=f"{name} prompt 2")] + ), + ] + + +ALL_HARM_CATEGORIES = ["hate", "fairness", "violence", "sexual", "harassment", "misinformation", "leakage"] + +ALL_HARM_SEED_GROUPS = {cat: _make_seed_groups(cat) for cat in ALL_HARM_CATEGORIES} + + +FIXTURES = ["patch_central_database", "mock_runtime_env"] + + +# =========================================================================== +# Strategy enum tests +# =========================================================================== + + +class TestRapidResponseStrategy: + """Tests for the RapidResponseStrategy enum.""" + + def test_technique_members_exist(self): + """All four technique members are accessible.""" + assert RapidResponseStrategy.PromptSending.value == "prompt_sending" + assert RapidResponseStrategy.RolePlay.value == "role_play" + assert RapidResponseStrategy.ManyShot.value == "many_shot" + assert RapidResponseStrategy.TAP.value == "tap" + + def test_aggregate_members_exist(self): + """All four aggregate members are accessible.""" + assert RapidResponseStrategy.ALL.value == "all" + assert RapidResponseStrategy.DEFAULT.value == "default" + assert RapidResponseStrategy.SINGLE_TURN.value == "single_turn" + assert RapidResponseStrategy.MULTI_TURN.value == "multi_turn" + + def test_total_member_count(self): + """4 aggregates + 4 techniques = 8 members.""" + assert len(list(RapidResponseStrategy)) == 8 + + def test_non_aggregate_count(self): + """get_all_strategies returns only the 4 technique members.""" + non_aggregate = RapidResponseStrategy.get_all_strategies() + assert len(non_aggregate) == 4 + + def test_aggregate_tags(self): + tags = RapidResponseStrategy.get_aggregate_tags() + assert tags == {"all", "default", "single_turn", "multi_turn"} + + def test_default_expands_to_prompt_sending_and_many_shot(self): + """DEFAULT aggregate should expand to PromptSending + ManyShot.""" + 
expanded = RapidResponseStrategy.normalize_strategies({RapidResponseStrategy.DEFAULT}) + values = {s.value for s in expanded} + assert values == {"prompt_sending", "many_shot"} + + def test_single_turn_expands_to_prompt_sending_and_role_play(self): + expanded = RapidResponseStrategy.normalize_strategies({RapidResponseStrategy.SINGLE_TURN}) + values = {s.value for s in expanded} + assert values == {"prompt_sending", "role_play"} + + def test_multi_turn_expands_to_many_shot_and_tap(self): + expanded = RapidResponseStrategy.normalize_strategies({RapidResponseStrategy.MULTI_TURN}) + values = {s.value for s in expanded} + assert values == {"many_shot", "tap"} + + def test_all_expands_to_all_techniques(self): + expanded = RapidResponseStrategy.normalize_strategies({RapidResponseStrategy.ALL}) + values = {s.value for s in expanded} + assert values == {"prompt_sending", "role_play", "many_shot", "tap"} + + def test_strategy_values_are_unique(self): + values = [s.value for s in RapidResponseStrategy] + assert len(values) == len(set(values)) + + def test_invalid_strategy_value_raises(self): + with pytest.raises(ValueError): + RapidResponseStrategy("nonexistent") + + def test_invalid_strategy_name_raises(self): + with pytest.raises(KeyError): + RapidResponseStrategy["Nonexistent"] + + +# =========================================================================== +# Initialization / class-level tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestRapidResponseBasic: + """Tests for RapidResponse initialization and class properties.""" + + def test_version_is_2(self): + assert RapidResponse.VERSION == 2 + + def test_get_strategy_class(self): + assert RapidResponse.get_strategy_class() is RapidResponseStrategy + + def test_get_default_strategy_returns_default(self): + assert RapidResponse.get_default_strategy() == RapidResponseStrategy.DEFAULT + + def 
test_default_dataset_config_has_all_harm_datasets(self): + config = RapidResponse.default_dataset_config() + assert isinstance(config, DatasetConfiguration) + names = config.get_default_dataset_names() + expected = [f"airt_{cat}" for cat in ALL_HARM_CATEGORIES] + for name in expected: + assert name in names + assert len(names) == 7 + + def test_default_dataset_config_max_dataset_size(self): + config = RapidResponse.default_dataset_config() + assert config.max_dataset_size == 4 + + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + def test_initialization_minimal(self, mock_get_scorer, mock_adversarial_target, mock_objective_scorer): + mock_get_scorer.return_value = mock_objective_scorer + scenario = RapidResponse(adversarial_chat=mock_adversarial_target) + assert scenario._adversarial_chat == mock_adversarial_target + assert scenario.name == "RapidResponse" + + def test_initialization_with_custom_scorer(self, mock_adversarial_target, mock_objective_scorer): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + assert scenario._objective_scorer == mock_objective_scorer + + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + def test_default_adversarial_target_created(self, mock_get_scorer, mock_objective_scorer): + """With env vars patched, constructor creates an OpenAIChatTarget.""" + mock_get_scorer.return_value = mock_objective_scorer + scenario = RapidResponse() + assert scenario._adversarial_chat is not None + + @pytest.mark.asyncio + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + @patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=ALL_HARM_SEED_GROUPS) + async def test_initialization_defaults_to_default_strategy( + self, + _mock_groups, + mock_get_scorer, + mock_objective_target, + mock_adversarial_target, + mock_objective_scorer, + ): + mock_get_scorer.return_value = 
mock_objective_scorer + scenario = RapidResponse(adversarial_chat=mock_adversarial_target) + await scenario.initialize_async(objective_target=mock_objective_target) + # DEFAULT expands to PromptSending + ManyShot → 2 composites + assert len(scenario._scenario_composites) == 2 + + @pytest.mark.asyncio + async def test_initialize_raises_when_no_datasets( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + """Dataset resolution fails from empty memory.""" + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): + await scenario.initialize_async(objective_target=mock_objective_target) + + @pytest.mark.asyncio + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + @patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=ALL_HARM_SEED_GROUPS) + async def test_memory_labels_stored( + self, + _mock_groups, + mock_get_scorer, + mock_objective_target, + mock_adversarial_target, + mock_objective_scorer, + ): + mock_get_scorer.return_value = mock_objective_scorer + labels = {"test_run": "123"} + scenario = RapidResponse(adversarial_chat=mock_adversarial_target) + await scenario.initialize_async(objective_target=mock_objective_target, memory_labels=labels) + assert scenario._memory_labels == labels + + @pytest.mark.parametrize("harm_category", ALL_HARM_CATEGORIES) + def test_harm_category_prompt_file_exists(self, harm_category): + harm_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" + assert (harm_path / f"{harm_category}.prompt").exists() + + +# =========================================================================== +# Attack generation tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestRapidResponseAttackGeneration: + """Tests for 
_get_atomic_attacks_async with various strategies.""" + + async def _init_and_get_attacks( + self, + *, + mock_objective_target, + mock_adversarial_target, + mock_objective_scorer, + strategies: list[RapidResponseStrategy] | None = None, + seed_groups: dict[str, list[SeedAttackGroup]] | None = None, + ): + """Helper: initialize scenario and return atomic attacks.""" + groups = seed_groups or {"hate": _make_seed_groups("hate")} + with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + init_kwargs = {"objective_target": mock_objective_target} + if strategies: + init_kwargs["scenario_strategies"] = strategies + await scenario.initialize_async(**init_kwargs) + return await scenario._get_atomic_attacks_async() + + @pytest.mark.asyncio + async def test_default_strategy_produces_prompt_sending_and_many_shot( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + ) + technique_classes = {type(a.attack_technique.attack) for a in attacks} + assert technique_classes == {PromptSendingAttack, ManyShotJailbreakAttack} + + @pytest.mark.asyncio + async def test_single_turn_strategy_produces_prompt_sending_and_role_play( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + strategies=[RapidResponseStrategy.SINGLE_TURN], + ) + technique_classes = {type(a.attack_technique.attack) for a in attacks} + assert technique_classes == {PromptSendingAttack, RolePlayAttack} + + @pytest.mark.asyncio + async 
def test_multi_turn_strategy_produces_many_shot_and_tap( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + strategies=[RapidResponseStrategy.MULTI_TURN], + ) + technique_classes = {type(a.attack_technique.attack) for a in attacks} + assert technique_classes == {ManyShotJailbreakAttack, TreeOfAttacksWithPruningAttack} + + @pytest.mark.asyncio + async def test_all_strategy_produces_all_four_techniques( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + strategies=[RapidResponseStrategy.ALL], + ) + technique_classes = {type(a.attack_technique.attack) for a in attacks} + assert technique_classes == { + PromptSendingAttack, + RolePlayAttack, + ManyShotJailbreakAttack, + TreeOfAttacksWithPruningAttack, + } + + @pytest.mark.asyncio + async def test_single_technique_selection( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + strategies=[RapidResponseStrategy.PromptSending], + ) + assert len(attacks) > 0 + for a in attacks: + assert isinstance(a.attack_technique.attack, PromptSendingAttack) + + @pytest.mark.asyncio + async def test_attack_count_is_techniques_times_datasets( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + """With 2 datasets and DEFAULT (2 techniques), expect 4 atomic attacks.""" + two_datasets = { + "hate": _make_seed_groups("hate"), + "violence": 
_make_seed_groups("violence"), + } + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=two_datasets, + ) + # DEFAULT = PromptSending + ManyShot = 2 techniques, 2 datasets → 4 + assert len(attacks) == 4 + + @pytest.mark.asyncio + async def test_atomic_attack_names_group_by_harm_category( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + """_build_atomic_attack_name groups by dataset (harm category), not technique.""" + two_datasets = { + "hate": _make_seed_groups("hate"), + "violence": _make_seed_groups("violence"), + } + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=two_datasets, + ) + names = {a.atomic_attack_name for a in attacks} + assert names == {"hate", "violence"} + + @pytest.mark.asyncio + async def test_raises_when_not_initialized(self, mock_adversarial_target, mock_objective_scorer): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + with pytest.raises(ValueError, match="Scenario not properly initialized"): + await scenario._get_atomic_attacks_async() + + @pytest.mark.asyncio + async def test_unknown_technique_skipped_with_warning( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + """If a technique name has no factory, it's skipped (not an error).""" + groups = {"hate": _make_seed_groups("hate")} + with ( + patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + RapidResponse, + "get_attack_technique_factories", + return_value={"prompt_sending": prompt_sending_factory()}, + ), + ): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + 
objective_scorer=mock_objective_scorer, + ) + # Select ALL which includes role_play, many_shot, tap — none have factories + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=[RapidResponseStrategy.ALL], + ) + attacks = await scenario._get_atomic_attacks_async() + # Only prompt_sending should have produced attacks + assert len(attacks) == 1 + assert isinstance(attacks[0].attack_technique.attack, PromptSendingAttack) + + @pytest.mark.asyncio + async def test_attacks_include_seed_groups( + self, mock_objective_target, mock_adversarial_target, mock_objective_scorer + ): + """Each atomic attack carries the correct seed groups.""" + attacks = await self._init_and_get_attacks( + mock_objective_target=mock_objective_target, + mock_adversarial_target=mock_adversarial_target, + mock_objective_scorer=mock_objective_scorer, + strategies=[RapidResponseStrategy.PromptSending], + ) + for a in attacks: + assert len(a.objectives) > 0 + + +# =========================================================================== +# _build_atomic_attack_name tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestBuildAtomicAttackName: + def test_rapid_response_groups_by_seed_group_name(self, mock_adversarial_target, mock_objective_scorer): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + result = scenario._build_atomic_attack_name(technique_name="prompt_sending", seed_group_name="hate") + assert result == "hate" + + def test_rapid_response_ignores_technique_name(self, mock_adversarial_target, mock_objective_scorer): + scenario = RapidResponse( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + r1 = scenario._build_atomic_attack_name(technique_name="prompt_sending", seed_group_name="hate") + r2 = scenario._build_atomic_attack_name(technique_name="tap", 
seed_group_name="hate") + assert r1 == r2 == "hate" + + +# =========================================================================== +# Core techniques factory tests +# =========================================================================== + + +class TestCoreTechniques: + """Tests for shared AttackTechniqueFactory builders in core_techniques.py.""" + + def test_prompt_sending_factory_attack_class(self): + f = prompt_sending_factory() + assert f.attack_class is PromptSendingAttack + + def test_role_play_factory_attack_class(self): + f = role_play_factory() + assert f.attack_class is RolePlayAttack + + def test_many_shot_factory_attack_class(self): + f = many_shot_factory() + assert f.attack_class is ManyShotJailbreakAttack + + def test_tap_factory_attack_class(self): + f = tap_factory() + assert f.attack_class is TreeOfAttacksWithPruningAttack + + def test_base_class_returns_all_four_factories(self): + factories = RapidResponse.get_attack_technique_factories() + assert set(factories.keys()) == {"prompt_sending", "role_play", "many_shot", "tap"} + assert factories["prompt_sending"].attack_class is PromptSendingAttack + assert factories["role_play"].attack_class is RolePlayAttack + assert factories["many_shot"].attack_class is ManyShotJailbreakAttack + assert factories["tap"].attack_class is TreeOfAttacksWithPruningAttack + + +# =========================================================================== +# Deprecated alias tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestDeprecatedAliases: + """Tests for backward-compatible ContentHarms aliases.""" + + def test_content_harms_is_rapid_response(self): + from pyrit.scenario.scenarios.airt.content_harms import ContentHarms + + assert ContentHarms is RapidResponse + + def test_content_harms_strategy_is_rapid_response_strategy(self): + from pyrit.scenario.scenarios.airt.content_harms import ContentHarmsStrategy + + assert 
ContentHarmsStrategy is RapidResponseStrategy + + def test_content_harms_instance_name_is_rapid_response(self, mock_adversarial_target, mock_objective_scorer): + """ContentHarms() creates a RapidResponse with name 'RapidResponse'.""" + from pyrit.scenario.scenarios.airt.content_harms import ContentHarms + + scenario = ContentHarms( + adversarial_chat=mock_adversarial_target, + objective_scorer=mock_objective_scorer, + ) + assert scenario.name == "RapidResponse" + assert isinstance(scenario, RapidResponse) From 8fac0f11d89d773743bc99d9d5fcf8ff8ac6e397 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 16 Apr 2026 12:24:42 -0700 Subject: [PATCH 2/3] refactor --- .env_example | 5 + doc/code/scenarios/0_scenarios.ipynb | 9 +- doc/code/scenarios/0_scenarios.py | 9 +- pyrit/registry/__init__.py | 2 + .../class_registries/scenario_registry.py | 16 + pyrit/registry/object_registries/__init__.py | 2 + .../attack_technique_registry.py | 61 +++- pyrit/scenario/core/__init__.py | 26 +- .../scenario/core/attack_technique_factory.py | 5 +- pyrit/scenario/core/core_techniques.py | 61 +--- pyrit/scenario/core/scenario.py | 32 +- pyrit/scenario/core/scenario_techniques.py | 181 ++++++++++ .../scenario/scenarios/airt/content_harms.py | 29 +- .../scenario/scenarios/airt/rapid_response.py | 46 +-- .../setup/initializers/components/targets.py | 13 + .../test_attack_technique_registry.py | 25 +- .../scenario/test_attack_technique_factory.py | 15 +- tests/unit/scenario/test_rapid_response.py | 319 ++++++++++++++++-- 18 files changed, 667 insertions(+), 189 deletions(-) create mode 100644 pyrit/scenario/core/scenario_techniques.py diff --git a/.env_example b/.env_example index e8bd5da94b..ef1aae1b4b 100644 --- a/.env_example +++ b/.env_example @@ -73,6 +73,11 @@ AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2="xxxxx" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2="deployment-name" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2="" +# Adversarial chat target (used by scenario attack techniques, e.g. 
role-play, TAP) +ADVERSARIAL_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" +ADVERSARIAL_CHAT_KEY="xxxxx" +ADVERSARIAL_CHAT_MODEL="deployment-name" + AZURE_FOUNDRY_DEEPSEEK_ENDPOINT="https://xxxxx.eastus2.models.ai.azure.com" AZURE_FOUNDRY_DEEPSEEK_KEY="xxxxx" AZURE_FOUNDRY_DEEPSEEK_MODEL="" diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index ab29a07bb1..d27b3e4615 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -55,7 +55,6 @@ "\n", "1. **Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available attack techniques for your scenario.\n", " - Each enum member represents an **attack technique** (the *how* of an attack)\n", - " - Datasets control *what* content is tested; strategies control *how* attacks are run\n", " - Each member is defined as `(value, tags)` where value is a string and tags is a set of strings\n", " - Include an `ALL` aggregate strategy that expands to all available strategies\n", " - Optionally implement `supports_composition()` and `validate_composition()` for strategy composition rules\n", @@ -65,7 +64,11 @@ " - `get_default_strategy()`: Return the default strategy (typically `YourStrategy.ALL`)\n", " - `_get_atomic_attacks_async()`: Build and return a list of `AtomicAttack` instances\n", "\n", - "3. **Constructor**: Use `@apply_defaults` decorator and call `super().__init__()` with scenario metadata:\n", + "3. **Default Dataset**: Implement `default_dataset_config()` to specify the datasets your scenario uses out of the box.\n", + " - Returns a `DatasetConfiguration` with one or more named datasets (e.g., `DatasetConfiguration(dataset_names=[\"my_dataset\"])`)\n", + " - Users can override this at runtime via `--dataset-names` in the CLI or by passing a custom `dataset_config` programmatically\n", + "\n", + "4. 
**Constructor**: Use `@apply_defaults` decorator and call `super().__init__()` with scenario metadata:\n", " - `name`: Descriptive name for your scenario\n", " - `version`: Integer version number\n", " - `strategy_class`: The strategy enum class for this scenario\n", @@ -73,7 +76,7 @@ " - `include_default_baseline`: Whether to include a baseline attack (default: True)\n", " - `scenario_result_id`: Optional ID to resume an existing scenario (optional)\n", "\n", - "4. **Initialization**: Call `await scenario.initialize_async()` to populate atomic attacks:\n", + "5. **Initialization**: Call `await scenario.initialize_async()` to populate atomic attacks:\n", " - `objective_target`: The target system being tested (required)\n", " - `scenario_strategies`: List of strategies to execute (optional, defaults to ALL)\n", " - `max_concurrency`: Number of concurrent operations (default: 1)\n", diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index fca62237f8..3d24fbed48 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -61,7 +61,6 @@ # # 1. **Strategy Enum**: Create a `ScenarioStrategy` enum that defines the available attack techniques for your scenario. # - Each enum member represents an **attack technique** (the *how* of an attack) -# - Datasets control *what* content is tested; strategies control *how* attacks are run # - Each member is defined as `(value, tags)` where value is a string and tags is a set of strings # - Include an `ALL` aggregate strategy that expands to all available strategies # - Optionally implement `supports_composition()` and `validate_composition()` for strategy composition rules @@ -71,7 +70,11 @@ # - `get_default_strategy()`: Return the default strategy (typically `YourStrategy.ALL`) # - `_get_atomic_attacks_async()`: Build and return a list of `AtomicAttack` instances # -# 3. 
**Constructor**: Use `@apply_defaults` decorator and call `super().__init__()` with scenario metadata: +# 3. **Default Dataset**: Implement `default_dataset_config()` to specify the datasets your scenario uses out of the box. +# - Returns a `DatasetConfiguration` with one or more named datasets (e.g., `DatasetConfiguration(dataset_names=["my_dataset"])`) +# - Users can override this at runtime via `--dataset-names` in the CLI or by passing a custom `dataset_config` programmatically +# +# 4. **Constructor**: Use `@apply_defaults` decorator and call `super().__init__()` with scenario metadata: # - `name`: Descriptive name for your scenario # - `version`: Integer version number # - `strategy_class`: The strategy enum class for this scenario @@ -79,7 +82,7 @@ # - `include_default_baseline`: Whether to include a baseline attack (default: True) # - `scenario_result_id`: Optional ID to resume an existing scenario (optional) # -# 4. **Initialization**: Call `await scenario.initialize_async()` to populate atomic attacks: +# 5. 
**Initialization**: Call `await scenario.initialize_async()` to populate atomic attacks: # - `objective_target`: The target system being tested (required) # - `scenario_strategies`: List of strategies to execute (optional, defaults to ALL) # - `max_concurrency`: Number of concurrent operations (default: 1) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 4f8290e993..e5dccc5082 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -25,6 +25,7 @@ RetrievableInstanceRegistry, ScorerRegistry, TargetRegistry, + TechniqueSpec, ) __all__ = [ @@ -45,4 +46,5 @@ "ScenarioRegistry", "ScorerRegistry", "TargetRegistry", + "TechniqueSpec", ] diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index f8b0e3e87f..d34c077ee0 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -118,6 +118,22 @@ def _discover_builtin_scenarios(self) -> None: logger.debug(f"Skipping deprecated alias: {scenario_class.__name__}") continue + # Skip re-exported aliases: if the class was defined in a different + # module than the one being discovered, it's an alias (e.g., + # ContentHarms in content_harms.py is really RapidResponse from + # rapid_response.py). 
+ class_module = getattr(scenario_class, "__module__", "") + expected_module_suffix = registry_name.replace(".", "/") + if not class_module.endswith(registry_name.replace("/", ".")): + # Build the full expected module name for comparison + expected_module = f"pyrit.scenario.scenarios.{registry_name.replace('/', '.')}" + if class_module != expected_module: + logger.debug( + f"Skipping alias '{scenario_class.__name__}' in '{registry_name}' " + f"(defined in {class_module})" + ) + continue + # Check for registry key collision if registry_name in self._class_entries: logger.warning( diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index 0a43a5af2f..f0c85d23c3 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -13,6 +13,7 @@ from pyrit.registry.object_registries.attack_technique_registry import ( AttackTechniqueRegistry, + TechniqueSpec, ) from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, @@ -41,4 +42,5 @@ "ConverterRegistry", "ScorerRegistry", "TargetRegistry", + "TechniqueSpec", ] diff --git a/pyrit/registry/object_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py index 2b68ffd651..110b6fb209 100644 --- a/pyrit/registry/object_registries/attack_technique_registry.py +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -12,7 +12,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, @@ -24,13 +25,39 @@ AttackConverterConfig, AttackScoringConfig, ) - from pyrit.prompt_target import PromptTarget + from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.scenario.core.attack_technique import AttackTechnique from 
pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class TechniqueSpec: + """ + Declarative definition of an attack technique. + + Each spec describes one registrable technique. The registrar converts + specs into ``AttackTechniqueFactory`` instances and registers them. + + Whether a technique receives an ``AttackAdversarialConfig`` is determined + automatically: the registrar inspects the attack class constructor and + injects one when ``attack_adversarial_config`` is an accepted parameter. + + Args: + name: Registry name (must match the strategy enum value). + attack_class: The ``AttackStrategy`` subclass. + tags: Classification tags (e.g. ``["single_turn"]``). + extra_kwargs_builder: Optional callback that returns additional kwargs + for the factory. Receives the resolved adversarial target. + """ + + name: str + attack_class: type + tags: list[str] = field(default_factory=list) + extra_kwargs_builder: Callable[["PromptChatTarget"], dict[str, Any]] | None = None + + class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): """ Singleton registry of reusable attack technique factories. @@ -59,14 +86,23 @@ def register_technique( self.register(factory, name=name, tags=tags) logger.debug(f"Registered attack technique factory: {name} ({factory.attack_class.__name__})") + def get_factories(self) -> dict[str, "AttackTechniqueFactory"]: + """ + Return all registered factories as a name→factory dict. + + Returns: + dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. 
+ """ + return {name: entry.instance for name, entry in self._registry_items.items()} + def create_technique( self, name: str, *, objective_target: PromptTarget, - attack_scoring_config: AttackScoringConfig, - attack_adversarial_config: AttackAdversarialConfig | None = None, - attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config_override: AttackScoringConfig | None = None, + attack_adversarial_config_override: AttackAdversarialConfig | None = None, + attack_converter_config_override: AttackConverterConfig | None = None, ) -> AttackTechnique: """ Retrieve a factory by name and produce a fresh attack technique. @@ -74,9 +110,12 @@ def create_technique( Args: name: The registry name of the technique. objective_target: The target to attack. - attack_scoring_config: Scoring configuration for the attack. - attack_adversarial_config: Optional adversarial configuration override. - attack_converter_config: Optional converter configuration override. + attack_scoring_config_override: When non-None, replaces any scoring + config baked into the factory. + attack_adversarial_config_override: When non-None, replaces any + adversarial config baked into the factory. + attack_converter_config_override: When non-None, replaces any + converter config baked into the factory. Returns: A fresh AttackTechnique with a newly-constructed attack strategy. 
@@ -89,7 +128,7 @@ def create_technique( raise KeyError(f"No technique registered with name '{name}'") return entry.instance.create( objective_target=objective_target, - attack_scoring_config=attack_scoring_config, - attack_adversarial_config=attack_adversarial_config, - attack_converter_config=attack_converter_config, + attack_scoring_config_override=attack_scoring_config_override, + attack_adversarial_config_override=attack_adversarial_config_override, + attack_converter_config_override=attack_converter_config_override, ) diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 7affb77c5f..f464f32ddf 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -6,27 +6,35 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory -from pyrit.scenario.core.core_techniques import ( - many_shot_factory, - prompt_sending_factory, - role_play_factory, - tap_factory, +from pyrit.scenario.core.scenario_techniques import ( + SCENARIO_TECHNIQUES, + ScenarioTechniqueRegistrar, + get_default_adversarial_target, ) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy +# TechniqueSpec lives in the registry module but is re-exported here for convenience +from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec + +# Backward-compatible aliases (old names) +CORE_TECHNIQUES = SCENARIO_TECHNIQUES +CoreTechniqueRegistrar = ScenarioTechniqueRegistrar + __all__ = [ "AtomicAttack", "AttackTechnique", "AttackTechniqueFactory", + "CORE_TECHNIQUES", + "CoreTechniqueRegistrar", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", + "SCENARIO_TECHNIQUES", "Scenario", 
"ScenarioCompositeStrategy", "ScenarioStrategy", - "many_shot_factory", - "prompt_sending_factory", - "role_play_factory", - "tap_factory", + "ScenarioTechniqueRegistrar", + "TechniqueSpec", + "get_default_adversarial_target", ] diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index edf3934faa..8c7aa1142e 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -11,7 +11,6 @@ from __future__ import annotations -import copy import inspect from typing import TYPE_CHECKING, Any @@ -64,7 +63,7 @@ def __init__( ValueError: If ``objective_target`` is included in attack_kwargs. """ self._attack_class = attack_class - self._attack_kwargs = copy.deepcopy(attack_kwargs) if attack_kwargs else {} + self._attack_kwargs = dict(attack_kwargs) if attack_kwargs else {} self._seed_technique = seed_technique self._validate_kwargs() @@ -169,7 +168,7 @@ class constructor accepts ``attack_scoring_config``. Returns: A fresh AttackTechnique with a newly-constructed attack strategy. """ - kwargs = copy.deepcopy(self._attack_kwargs) + kwargs = dict(self._attack_kwargs) kwargs["objective_target"] = objective_target # Only forward overrides when the attack class accepts the underlying param diff --git a/pyrit/scenario/core/core_techniques.py b/pyrit/scenario/core/core_techniques.py index d8200bbeda..dd4b5ecf9e 100644 --- a/pyrit/scenario/core/core_techniques.py +++ b/pyrit/scenario/core/core_techniques.py @@ -2,55 +2,24 @@ # Licensed under the MIT license. """ -Shared AttackTechniqueFactory builders for common attack techniques. +Deprecated — use ``scenario_techniques`` instead. -These functions return ``AttackTechniqueFactory`` instances that can be -used by any scenario. Each factory captures technique-specific defaults -at registration time; runtime parameters (``objective_target``) and -optional overrides (``attack_scoring_config_override``, etc.) 
are -provided when ``factory.create()`` is called during scenario execution. - -Scenarios expose available factories via the overridable -``Scenario.get_attack_technique_factories()`` classmethod. +This module re-exports everything from ``scenario_techniques`` for backward +compatibility. It will be removed in a future release. """ -from pyrit.executor.attack import ( - ManyShotJailbreakAttack, - PromptSendingAttack, - RolePlayAttack, - RolePlayPaths, - TreeOfAttacksWithPruningAttack, +from pyrit.scenario.core.scenario_techniques import ( + SCENARIO_TECHNIQUES as CORE_TECHNIQUES, + ScenarioTechniqueRegistrar as CoreTechniqueRegistrar, + get_default_adversarial_target, ) -from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory - - -def prompt_sending_factory() -> AttackTechniqueFactory: - """Create a factory for ``PromptSendingAttack`` (single-turn, no converter).""" - return AttackTechniqueFactory(attack_class=PromptSendingAttack) - - -def role_play_factory( - *, - role_play_path: str | None = None, -) -> AttackTechniqueFactory: - """ - Create a factory for ``RolePlayAttack`` (single-turn with role-play converter). - - Args: - role_play_path: Path to the role-play YAML definition. - Defaults to ``RolePlayPaths.MOVIE_SCRIPT``. 
- """ - kwargs: dict[str, object] = { - "role_play_definition_path": role_play_path or RolePlayPaths.MOVIE_SCRIPT.value, - } - return AttackTechniqueFactory(attack_class=RolePlayAttack, attack_kwargs=kwargs) - - -def many_shot_factory() -> AttackTechniqueFactory: - """Create a factory for ``ManyShotJailbreakAttack`` (multi-turn).""" - return AttackTechniqueFactory(attack_class=ManyShotJailbreakAttack) +# Re-export TechniqueSpec from its canonical location +from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec -def tap_factory() -> AttackTechniqueFactory: - """Create a factory for ``TreeOfAttacksWithPruningAttack`` (multi-turn).""" - return AttackTechniqueFactory(attack_class=TreeOfAttacksWithPruningAttack) +__all__ = [ + "CORE_TECHNIQUES", + "CoreTechniqueRegistrar", + "TechniqueSpec", + "get_default_adversarial_target", +] diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 27898b61b3..23add6c3fa 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -174,41 +174,35 @@ def default_dataset_config(cls) -> DatasetConfiguration: DatasetConfiguration: The default dataset configuration. """ - @classmethod - def get_attack_technique_factories(cls) -> dict[str, "AttackTechniqueFactory"]: + def get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"]: """ - Return the default attack technique factories for this scenario. + Return the attack technique factories for this scenario. Each key is a technique name (matching a strategy enum value) and each value is an ``AttackTechniqueFactory`` that can produce an ``AttackTechnique`` for that technique. - The base implementation returns the common set from - ``core_techniques``. Subclasses may override to add, remove, or - replace factories. 
+ The base implementation lazily populates the + ``AttackTechniqueRegistry`` singleton with core techniques (via + ``ScenarioTechniqueRegistrar``) and returns all registered factories. + Subclasses may override to add, remove, or replace factories. Returns: dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. """ - from pyrit.scenario.core.core_techniques import ( - many_shot_factory, - prompt_sending_factory, - role_play_factory, - tap_factory, - ) + from pyrit.scenario.core.scenario_techniques import ScenarioTechniqueRegistrar - return { - "prompt_sending": prompt_sending_factory(), - "role_play": role_play_factory(), - "many_shot": many_shot_factory(), - "tap": tap_factory(), - } + ScenarioTechniqueRegistrar().register() + + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry + + return AttackTechniqueRegistry.get_registry_singleton().get_factories() def _build_atomic_attack_name(self, *, technique_name: str, seed_group_name: str) -> str: """ Build the grouping key for an atomic attack. - Controls how attacks are grouped for result storage and resume + Controls how attacks are grouped for result storage, display, and resume logic. Override to customize grouping: - **By technique** (default): ``return technique_name`` diff --git a/pyrit/scenario/core/scenario_techniques.py b/pyrit/scenario/core/scenario_techniques.py new file mode 100644 index 0000000000..cfdb8b7d70 --- /dev/null +++ b/pyrit/scenario/core/scenario_techniques.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario attack technique definitions and registration. + +Provides ``SCENARIO_TECHNIQUES`` (the standard catalog) and +``ScenarioTechniqueRegistrar`` (registers specs into the +``AttackTechniqueRegistry`` singleton). + +To add a new technique, append a ``TechniqueSpec`` to ``SCENARIO_TECHNIQUES``. 
+""" + +from __future__ import annotations + +import inspect +import logging +from typing import Any + +from pyrit.executor.attack import ( + AttackAdversarialConfig, + ManyShotJailbreakAttack, + PromptSendingAttack, + RolePlayAttack, + RolePlayPaths, + TreeOfAttacksWithPruningAttack, +) +from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Scenario technique catalog +# --------------------------------------------------------------------------- + +SCENARIO_TECHNIQUES: list[TechniqueSpec] = [ + TechniqueSpec( + name="prompt_sending", + attack_class=PromptSendingAttack, + tags=["single_turn"], + ), + TechniqueSpec( + name="role_play", + attack_class=RolePlayAttack, + tags=["single_turn"], + extra_kwargs_builder=lambda _adv: { + "role_play_definition_path": RolePlayPaths.MOVIE_SCRIPT.value, + }, + ), + TechniqueSpec( + name="many_shot", + attack_class=ManyShotJailbreakAttack, + tags=["multi_turn"], + ), + TechniqueSpec( + name="tap", + attack_class=TreeOfAttacksWithPruningAttack, + tags=["multi_turn"], + ), +] + + +# --------------------------------------------------------------------------- +# Default adversarial target +# --------------------------------------------------------------------------- + + +def get_default_adversarial_target() -> PromptChatTarget: + """ + Resolve the default adversarial chat target. + + First checks the ``TargetRegistry`` for an ``"adversarial_chat"`` entry + (populated by ``TargetInitializer`` from ``ADVERSARIAL_CHAT_*`` env vars). + Falls back to a plain ``OpenAIChatTarget(temperature=1.2)`` using + ``@apply_defaults`` resolution. 
+ """ + from pyrit.registry import TargetRegistry + + registry = TargetRegistry.get_registry_singleton() + if "adversarial_chat" in registry: + return registry.get("adversarial_chat") + + return OpenAIChatTarget(temperature=1.2) + + +# --------------------------------------------------------------------------- +# Registrar +# --------------------------------------------------------------------------- + + +class ScenarioTechniqueRegistrar: + """ + Registers ``TechniqueSpec`` entries into the ``AttackTechniqueRegistry``. + + Holds shared defaults (e.g. ``adversarial_chat``) so they're set once + and applied to every technique that needs them. + + Typical usage from a scenario:: + + ScenarioTechniqueRegistrar(adversarial_chat=self._adversarial_chat).register() + """ + + def __init__(self, *, adversarial_chat: PromptChatTarget | None = None) -> None: + """ + Args: + adversarial_chat: Shared adversarial chat target for techniques + that require one. Defaults to ``get_default_adversarial_target()``. + """ + self._adversarial_chat = adversarial_chat + + @property + def adversarial_chat(self) -> PromptChatTarget: + """Resolve the adversarial chat target (custom or default).""" + if self._adversarial_chat is None: + self._adversarial_chat = get_default_adversarial_target() + return self._adversarial_chat + + def build_factory(self, spec: TechniqueSpec) -> AttackTechniqueFactory: + """ + Build an ``AttackTechniqueFactory`` from a ``TechniqueSpec``. + + Automatically injects ``AttackAdversarialConfig`` when the attack + class accepts ``attack_adversarial_config`` as a constructor parameter. + + Args: + spec: The technique specification. + + Returns: + AttackTechniqueFactory: A factory ready for registration. 
+ """ + kwargs: dict[str, Any] = {} + + if self._accepts_adversarial(spec.attack_class): + kwargs["attack_adversarial_config"] = AttackAdversarialConfig(target=self.adversarial_chat) + + if spec.extra_kwargs_builder: + kwargs.update(spec.extra_kwargs_builder(self.adversarial_chat)) + + return AttackTechniqueFactory( + attack_class=spec.attack_class, + attack_kwargs=kwargs or None, + ) + + @staticmethod + def _accepts_adversarial(attack_class: type) -> bool: + """Check if an attack class accepts ``attack_adversarial_config``.""" + sig = inspect.signature(attack_class.__init__) + return "attack_adversarial_config" in sig.parameters + + def register( + self, + *, + techniques: list[TechniqueSpec] | None = None, + registry: "AttackTechniqueRegistry | None" = None, + ) -> None: + """ + Register technique specs into the registry. + + Per-name idempotent: existing entries are not overwritten. + + Args: + techniques: Specs to register. Defaults to ``SCENARIO_TECHNIQUES``. + registry: Registry instance. Defaults to the singleton. + """ + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry + + if registry is None: + registry = AttackTechniqueRegistry.get_registry_singleton() + if techniques is None: + techniques = SCENARIO_TECHNIQUES + + for spec in techniques: + if spec.name not in registry: + factory = self.build_factory(spec) + registry.register_technique(name=spec.name, factory=factory, tags=spec.tags) + + logger.debug("Technique registration complete (%d total in registry)", len(registry)) + diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index 47c399592e..7d11a46285 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -8,32 +8,9 @@ backward compatibility. They will be removed in a future release. 
""" -import warnings - from pyrit.scenario.scenarios.airt.rapid_response import ( - RapidResponse, - RapidResponseStrategy, + RapidResponse as ContentHarms, + RapidResponseStrategy as ContentHarmsStrategy, ) - -def __getattr__(name: str): - if name == "ContentHarms": - warnings.warn( - "ContentHarms is deprecated. Use RapidResponse instead.", - DeprecationWarning, - stacklevel=2, - ) - return RapidResponse - if name == "ContentHarmsStrategy": - warnings.warn( - "ContentHarmsStrategy is deprecated. Use RapidResponseStrategy instead.", - DeprecationWarning, - stacklevel=2, - ) - return RapidResponseStrategy - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -# Direct aliases for import-from statements -ContentHarms = RapidResponse -ContentHarmsStrategy = RapidResponseStrategy +__all__ = ["ContentHarms", "ContentHarmsStrategy"] diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py index 507ae96ddc..445e1d532b 100644 --- a/pyrit/scenario/scenarios/airt/rapid_response.py +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -10,15 +10,14 @@ to test. 
""" +from __future__ import annotations + import logging -import os -from typing import Optional +from typing import TYPE_CHECKING -from pyrit.auth import get_azure_openai_auth from pyrit.common import apply_defaults -from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig -from pyrit.models import SeedAttackGroup -from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.executor.attack import AttackScoringConfig +from pyrit.prompt_target import PromptChatTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario import Scenario @@ -28,6 +27,9 @@ ) from pyrit.score import TrueFalseScorer +if TYPE_CHECKING: + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + logger = logging.getLogger(__name__) @@ -109,8 +111,8 @@ def __init__( Args: adversarial_chat: Chat target for multi-turn / adversarial - attacks (RolePlay, TAP). Defaults to an Azure OpenAI - target from environment variables. + attacks (RolePlay, TAP). When provided, overrides the + default adversarial target baked into technique factories. objective_scorer: Scorer for evaluating attack success. Defaults to a composite Azure-Content-Filter + refusal scorer. 
@@ -120,7 +122,7 @@ def __init__( self._objective_scorer: TrueFalseScorer = ( objective_scorer if objective_scorer else self._get_default_objective_scorer() ) - self._adversarial_chat = adversarial_chat if adversarial_chat else self._get_default_adversarial_target() + self._adversarial_chat = adversarial_chat super().__init__( version=self.VERSION, @@ -133,14 +135,15 @@ def _build_atomic_attack_name(self, *, technique_name: str, seed_group_name: str """Group results by harm category (dataset) rather than technique.""" return seed_group_name - def _get_default_adversarial_target(self) -> OpenAIChatTarget: - endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") - return OpenAIChatTarget( - endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), - model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), - temperature=1.2, - ) + def get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"]: + """ + Register core techniques with this scenario's adversarial chat target. + """ + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry + from pyrit.scenario.core.scenario_techniques import ScenarioTechniqueRegistrar + + ScenarioTechniqueRegistrar(adversarial_chat=self._adversarial_chat).register() + return AttackTechniqueRegistry.get_registry_singleton().get_factories() async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ @@ -163,7 +166,11 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer) - adversarial_config = AttackAdversarialConfig(target=self._adversarial_chat) + + # Resolve adversarial_chat for AtomicAttack parameter building. 
+ from pyrit.scenario.core.scenario_techniques import get_default_adversarial_target + + adversarial_chat = self._adversarial_chat or get_default_adversarial_target() atomic_attacks: list[AtomicAttack] = [] for technique_name in selected_techniques: @@ -180,7 +187,6 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: attack_technique = factory.create( objective_target=self._objective_target, attack_scoring_config_override=scoring_for_technique, - attack_adversarial_config_override=adversarial_config, ) for dataset_name, seed_groups in seed_groups_by_dataset.items(): @@ -192,7 +198,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: ), attack_technique=attack_technique, seed_groups=list(seed_groups), - adversarial_chat=self._adversarial_chat, + adversarial_chat=adversarial_chat, objective_scorer=self._objective_scorer, memory_labels=self._memory_labels, ) diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index 4c652aae18..94f30c4f05 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -44,6 +44,7 @@ class TargetInitializerTags(str, Enum): SCORER = "scorer" ALL = "all" DEFAULT_OBJECTIVE_TARGET = "default_objective_target" + ADVERSARIAL = "adversarial" @dataclass @@ -165,6 +166,18 @@ class TargetConfig: model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2", ), + # ============================================ + # Adversarial Chat Target (for scenario attack techniques) + # ============================================ + TargetConfig( + registry_name="adversarial_chat", + target_class=OpenAIChatTarget, + endpoint_var="ADVERSARIAL_CHAT_ENDPOINT", + key_var="ADVERSARIAL_CHAT_KEY", + model_var="ADVERSARIAL_CHAT_MODEL", + temperature=1.2, + tags=[TargetInitializerTags.ALL, TargetInitializerTags.ADVERSARIAL], + ), TargetConfig( 
registry_name="azure_foundry_deepseek", target_class=OpenAIChatTarget, diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index e0d7463b51..18c64be892 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -120,7 +120,7 @@ def test_create_technique_returns_attack_technique(self): target = MagicMock(spec=PromptTarget) scoring = MagicMock(spec=AttackScoringConfig) - technique = self.registry.create_technique("stub", objective_target=target, attack_scoring_config=scoring) + technique = self.registry.create_technique("stub", objective_target=target, attack_scoring_config_override=scoring) assert isinstance(technique, AttackTechnique) assert isinstance(technique.attack, _StubAttack) @@ -141,7 +141,7 @@ def get_identifier(self): scoring = MagicMock(spec=AttackScoringConfig) technique = self.registry.create_technique( - "scoring_stub", objective_target=target, attack_scoring_config=scoring + "scoring_stub", objective_target=target, attack_scoring_config_override=scoring ) assert technique.attack.attack_scoring_config is scoring @@ -151,7 +151,7 @@ def test_create_technique_raises_on_missing_name(self): self.registry.create_technique( "nonexistent", objective_target=MagicMock(spec=PromptTarget), - attack_scoring_config=MagicMock(spec=AttackScoringConfig), + attack_scoring_config_override=MagicMock(spec=AttackScoringConfig), ) def test_create_technique_preserves_frozen_kwargs(self): @@ -163,7 +163,7 @@ def test_create_technique_preserves_frozen_kwargs(self): target = MagicMock(spec=PromptTarget) technique = self.registry.create_technique( - "custom", objective_target=target, attack_scoring_config=MagicMock(spec=AttackScoringConfig) + "custom", objective_target=target, attack_scoring_config_override=MagicMock(spec=AttackScoringConfig) ) assert technique.attack.max_turns == 42 @@ -252,3 +252,20 @@ def 
test_iter_yields_sorted_names(self): self.registry.register_technique(name="a", factory=factory) assert list(self.registry) == ["a", "b"] + + def test_get_factories_returns_dict_mapping(self): + factory_a = AttackTechniqueFactory(attack_class=_StubAttack) + factory_b = AttackTechniqueFactory(attack_class=_StubAttack, attack_kwargs={"max_turns": 5}) + self.registry.register_technique(name="alpha", factory=factory_a) + self.registry.register_technique(name="beta", factory=factory_b) + + result = self.registry.get_factories() + + assert isinstance(result, dict) + assert set(result.keys()) == {"alpha", "beta"} + assert result["alpha"] is factory_a + assert result["beta"] is factory_b + + def test_get_factories_empty_registry(self): + result = self.registry.get_factories() + assert result == {} diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py index 1d2cb1fbff..756c77b4f0 100644 --- a/tests/unit/scenario/test_attack_technique_factory.py +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -220,8 +220,8 @@ def test_create_produces_independent_instances(self): assert technique1.attack.objective_target is target1 assert technique2.attack.objective_target is target2 - def test_create_deepcopies_kwargs(self): - """Mutating the original kwargs dict should not affect future creates.""" + def test_create_shares_kwargs_values(self): + """Factory uses shallow copy — mutable values inside kwargs are shared (by design).""" mutable_list = [1, 2, 3] class _ListAttack: @@ -239,15 +239,12 @@ def get_identifier(self): target = MagicMock(spec=PromptTarget) technique1 = factory.create(objective_target=target, attack_scoring_config_override=self._scoring()) - # Mutate the source list - mutable_list.append(999) + assert technique1.attack.items == [1, 2, 3] + # Mutating the original list is visible to future creates (shallow copy) + mutable_list.append(999) technique2 = factory.create(objective_target=target, 
attack_scoring_config_override=self._scoring()) - - # First create should have the original snapshot - assert technique1.attack.items == [1, 2, 3] - # Second create should also have the original (from deepcopy of stored kwargs) - assert technique2.attack.items == [1, 2, 3] + assert technique2.attack.items == [1, 2, 3, 999] def test_create_without_optional_configs_omits_them(self): """When optional configs are None, adversarial and converter should not be passed.""" diff --git a/tests/unit/scenario/test_rapid_response.py b/tests/unit/scenario/test_rapid_response.py index f5c942a203..e705208387 100644 --- a/tests/unit/scenario/test_rapid_response.py +++ b/tests/unit/scenario/test_rapid_response.py @@ -10,6 +10,7 @@ from pyrit.common.path import DATASETS_PATH from pyrit.executor.attack import ( + AttackAdversarialConfig, ManyShotJailbreakAttack, PromptSendingAttack, RolePlayAttack, @@ -17,15 +18,16 @@ ) from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt -from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target import OpenAIChatTarget, PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry, TechniqueSpec from pyrit.scenario import ScenarioCompositeStrategy -from pyrit.scenario.core.core_techniques import ( - many_shot_factory, - prompt_sending_factory, - role_play_factory, - tap_factory, +from pyrit.scenario.core.scenario_techniques import ( + SCENARIO_TECHNIQUES, + ScenarioTechniqueRegistrar, + get_default_adversarial_target, ) +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.scenarios.airt.rapid_response import ( RapidResponse, @@ -75,6 +77,18 @@ def mock_objective_scorer(): return mock +@pytest.fixture(autouse=True) +def 
reset_technique_registry(): + """Reset the AttackTechniqueRegistry and TargetRegistry singletons between tests.""" + from pyrit.registry import TargetRegistry + + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + + @pytest.fixture(autouse=True) def patch_many_shot_load(): """Prevent ManyShotJailbreakAttack from loading the full bundled dataset.""" @@ -87,12 +101,13 @@ def patch_many_shot_load(): @pytest.fixture def mock_runtime_env(): + """Set minimal env vars needed for OpenAIChatTarget fallback via @apply_defaults.""" with patch.dict( "os.environ", { - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT": "https://test.openai.azure.com/", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY": "test-key", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL": "gpt-4", + "OPENAI_CHAT_ENDPOINT": "https://test.openai.azure.com/", + "OPENAI_CHAT_KEY": "test-key", + "OPENAI_CHAT_MODEL": "gpt-4", }, ): yield @@ -233,11 +248,11 @@ def test_initialization_with_custom_scorer(self, mock_adversarial_target, mock_o assert scenario._objective_scorer == mock_objective_scorer @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - def test_default_adversarial_target_created(self, mock_get_scorer, mock_objective_scorer): - """With env vars patched, constructor creates an OpenAIChatTarget.""" + def test_no_adversarial_chat_stored_when_not_provided(self, mock_get_scorer, mock_objective_scorer): + """When adversarial_chat is not provided, it stays None (factories own the default).""" mock_get_scorer.return_value = mock_objective_scorer scenario = RapidResponse() - assert scenario._adversarial_chat is not None + assert scenario._adversarial_chat is None @pytest.mark.asyncio @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") @@ -443,13 +458,19 @@ async def test_unknown_technique_skipped_with_warning( ): """If a technique name has no factory, it's skipped (not an 
error).""" groups = {"hate": _make_seed_groups("hate")} + + # Register only prompt_sending in the registry — the other techniques + # (role_play, many_shot, tap) won't have factories. + registry = AttackTechniqueRegistry.get_registry_singleton() + registry.register_technique( + name="prompt_sending", + factory=AttackTechniqueFactory(attack_class=PromptSendingAttack), + tags=["single_turn"], + ) + with ( patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), - patch.object( - RapidResponse, - "get_attack_technique_factories", - return_value={"prompt_sending": prompt_sending_factory()}, - ), + patch.object(ScenarioTechniqueRegistrar, "register"), ): scenario = RapidResponse( adversarial_chat=mock_adversarial_target, @@ -510,33 +531,39 @@ def test_rapid_response_ignores_technique_name(self, mock_adversarial_target, mo # =========================================================================== +@pytest.mark.usefixtures(*FIXTURES) class TestCoreTechniques: - """Tests for shared AttackTechniqueFactory builders in core_techniques.py.""" - - def test_prompt_sending_factory_attack_class(self): - f = prompt_sending_factory() - assert f.attack_class is PromptSendingAttack - - def test_role_play_factory_attack_class(self): - f = role_play_factory() - assert f.attack_class is RolePlayAttack - - def test_many_shot_factory_attack_class(self): - f = many_shot_factory() - assert f.attack_class is ManyShotJailbreakAttack - - def test_tap_factory_attack_class(self): - f = tap_factory() - assert f.attack_class is TreeOfAttacksWithPruningAttack + """Tests for shared AttackTechniqueFactory builders in scenario_techniques.py.""" - def test_base_class_returns_all_four_factories(self): - factories = RapidResponse.get_attack_technique_factories() + def test_instance_returns_all_four_factories(self, mock_adversarial_target, mock_objective_scorer): + scenario = RapidResponse(adversarial_chat=mock_adversarial_target, objective_scorer=mock_objective_scorer) + 
factories = scenario.get_attack_technique_factories() assert set(factories.keys()) == {"prompt_sending", "role_play", "many_shot", "tap"} assert factories["prompt_sending"].attack_class is PromptSendingAttack assert factories["role_play"].attack_class is RolePlayAttack assert factories["many_shot"].attack_class is ManyShotJailbreakAttack assert factories["tap"].attack_class is TreeOfAttacksWithPruningAttack + def test_factories_use_default_adversarial_when_none(self, mock_objective_scorer): + """When no adversarial_chat is passed, factories use get_default_adversarial_target.""" + scenario = RapidResponse(objective_scorer=mock_objective_scorer) + factories = scenario.get_attack_technique_factories() + # role_play and tap should have attack_adversarial_config baked in + assert "attack_adversarial_config" in factories["role_play"]._attack_kwargs + assert "attack_adversarial_config" in factories["tap"]._attack_kwargs + + def test_factories_use_custom_adversarial_when_provided(self, mock_adversarial_target, mock_objective_scorer): + """When adversarial_chat is provided, the registrar bakes it into factories.""" + scenario = RapidResponse(adversarial_chat=mock_adversarial_target, objective_scorer=mock_objective_scorer) + factories = scenario.get_attack_technique_factories() + + # The registrar bakes the custom adversarial target directly into factories + rp_kwargs = factories["role_play"]._attack_kwargs + assert rp_kwargs["attack_adversarial_config"].target is mock_adversarial_target + + tap_kwargs = factories["tap"]._attack_kwargs + assert tap_kwargs["attack_adversarial_config"].target is mock_adversarial_target + # =========================================================================== # Deprecated alias tests @@ -567,3 +594,223 @@ def test_content_harms_instance_name_is_rapid_response(self, mock_adversarial_ta ) assert scenario.name == "RapidResponse" assert isinstance(scenario, RapidResponse) + + +# 
=========================================================================== +# Registry integration tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestRegistryIntegration: + """Tests for AttackTechniqueRegistry wiring via ScenarioTechniqueRegistrar.""" + + def test_registrar_populates_registry(self, mock_adversarial_target): + """After calling register(), all 4 techniques are in registry.""" + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = AttackTechniqueRegistry.get_registry_singleton() + names = set(registry.get_names()) + assert names == {"prompt_sending", "role_play", "many_shot", "tap"} + + def test_registrar_idempotent(self, mock_adversarial_target): + """Calling register() twice doesn't duplicate entries.""" + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = AttackTechniqueRegistry.get_registry_singleton() + assert len(registry) == 4 + + def test_registrar_preserves_custom(self, mock_adversarial_target): + """Pre-registered custom techniques aren't overwritten.""" + registry = AttackTechniqueRegistry.get_registry_singleton() + custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) + + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + + # role_play should still be the custom factory + factories = registry.get_factories() + assert factories["role_play"] is custom_factory + # Other 3 should have been registered normally + assert len(factories) == 4 + + def test_get_factories_returns_dict(self, mock_adversarial_target): + """get_factories() returns a dict of name → factory.""" + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = 
AttackTechniqueRegistry.get_registry_singleton() + factories = registry.get_factories() + assert isinstance(factories, dict) + assert set(factories.keys()) == {"prompt_sending", "role_play", "many_shot", "tap"} + assert factories["prompt_sending"].attack_class is PromptSendingAttack + + def test_scenario_base_class_reads_from_registry(self, mock_adversarial_target, mock_objective_scorer): + """Scenario.get_attack_technique_factories() triggers registration and reads from registry.""" + scenario = RapidResponse(adversarial_chat=mock_adversarial_target, objective_scorer=mock_objective_scorer) + factories = scenario.get_attack_technique_factories() + + # Should have all 4 core techniques from the registry + assert set(factories.keys()) == {"prompt_sending", "role_play", "many_shot", "tap"} + + # Registry should also have them + registry = AttackTechniqueRegistry.get_registry_singleton() + assert set(registry.get_names()) == {"prompt_sending", "role_play", "many_shot", "tap"} + + def test_tags_assigned_correctly(self, mock_adversarial_target): + """Core techniques have correct tags (single_turn / multi_turn).""" + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = AttackTechniqueRegistry.get_registry_singleton() + + single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} + multi_turn = {e.name for e in registry.get_by_tag(tag="multi_turn")} + + assert single_turn == {"prompt_sending", "role_play"} + assert multi_turn == {"many_shot", "tap"} + + +# =========================================================================== +# ScenarioTechniqueRegistrar tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestScenarioTechniqueRegistrar: + """Tests for the declarative ScenarioTechniqueRegistrar class.""" + + def test_registrar_populates_all_four_techniques(self): + """Registrar with default adversarial registers all 4 techniques.""" + 
ScenarioTechniqueRegistrar().register() + registry = AttackTechniqueRegistry.get_registry_singleton() + assert set(registry.get_names()) == {"prompt_sending", "role_play", "many_shot", "tap"} + + def test_registrar_with_custom_adversarial(self, mock_adversarial_target): + """Custom adversarial_chat is baked into adversarial-needing factories.""" + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = AttackTechniqueRegistry.get_registry_singleton() + factories = registry.get_factories() + + # role_play and tap should have the mock adversarial target baked in + rp_kwargs = factories["role_play"]._attack_kwargs + assert rp_kwargs["attack_adversarial_config"].target is mock_adversarial_target + + tap_kwargs = factories["tap"]._attack_kwargs + assert tap_kwargs["attack_adversarial_config"].target is mock_adversarial_target + + def test_registrar_idempotent(self, mock_adversarial_target): + """Calling register() twice does not duplicate or overwrite entries.""" + registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) + registrar.register() + registrar.register() + registry = AttackTechniqueRegistry.get_registry_singleton() + assert len(registry) == 4 + + def test_registrar_preserves_custom_preregistered(self, mock_adversarial_target): + """Pre-registered custom techniques are not overwritten by registrar.""" + registry = AttackTechniqueRegistry.get_registry_singleton() + custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) + + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + # role_play should still be the custom factory + assert registry.get_factories()["role_play"] is custom_factory + assert len(registry) == 4 + + def test_registrar_assigns_correct_tags(self, mock_adversarial_target): + """Tags from TechniqueSpec are applied correctly.""" + 
ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + registry = AttackTechniqueRegistry.get_registry_singleton() + + single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} + multi_turn = {e.name for e in registry.get_by_tag(tag="multi_turn")} + assert single_turn == {"prompt_sending", "role_play"} + assert multi_turn == {"many_shot", "tap"} + + def test_registrar_custom_techniques_list(self, mock_adversarial_target): + """Registrar accepts a custom list of TechniqueSpecs.""" + custom_specs = [ + TechniqueSpec(name="custom_attack", attack_class=PromptSendingAttack, tags=["custom"]), + ] + ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register(techniques=custom_specs) + registry = AttackTechniqueRegistry.get_registry_singleton() + assert set(registry.get_names()) == {"custom_attack"} + + def test_registrar_adversarial_lazy_resolution(self): + """Adversarial target is not resolved until register() accesses it.""" + registrar = ScenarioTechniqueRegistrar() + # No env var resolution yet — just creating the registrar + assert registrar._adversarial_chat is None + + def test_get_default_adversarial_target_from_registry(self, mock_adversarial_target): + """get_default_adversarial_target returns registry entry when available.""" + from pyrit.registry import TargetRegistry + + target_registry = TargetRegistry.get_registry_singleton() + target_registry.register(name="adversarial_chat", instance=mock_adversarial_target) + result = get_default_adversarial_target() + assert result is mock_adversarial_target + + def test_get_default_adversarial_target_fallback(self): + """get_default_adversarial_target falls back to OpenAIChatTarget when not in registry.""" + result = get_default_adversarial_target() + assert isinstance(result, OpenAIChatTarget) + assert result._temperature == 1.2 + + +# =========================================================================== +# TechniqueSpec tests +# 
=========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestTechniqueSpec: + """Tests for the TechniqueSpec dataclass.""" + + def test_simple_spec(self): + spec = TechniqueSpec(name="test", attack_class=PromptSendingAttack, tags=["single_turn"]) + assert spec.name == "test" + assert spec.attack_class is PromptSendingAttack + assert spec.tags == ["single_turn"] + assert spec.extra_kwargs_builder is None + + def test_extra_kwargs_builder(self, mock_adversarial_target): + builder = lambda _adv: {"role_play_definition_path": "/custom/path.yaml"} + spec = TechniqueSpec( + name="complex", + attack_class=RolePlayAttack, + tags=["single_turn"], + extra_kwargs_builder=builder, + ) + registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) + factory = registrar.build_factory(spec) + assert factory._attack_kwargs["role_play_definition_path"] == "/custom/path.yaml" + assert "attack_adversarial_config" in factory._attack_kwargs + + def test_build_factory_no_adversarial(self, mock_adversarial_target): + """Non-adversarial spec should not have attack_adversarial_config.""" + spec = TechniqueSpec(name="simple", attack_class=PromptSendingAttack, tags=[]) + registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) + factory = registrar.build_factory(spec) + assert "attack_adversarial_config" not in (factory._attack_kwargs or {}) + + def test_SCENARIO_TECHNIQUES_list_has_four_entries(self): + assert len(SCENARIO_TECHNIQUES) == 4 + names = {s.name for s in SCENARIO_TECHNIQUES} + assert names == {"prompt_sending", "role_play", "many_shot", "tap"} + + def test_frozen_spec(self): + """TechniqueSpec is frozen (immutable).""" + spec = TechniqueSpec(name="test", attack_class=PromptSendingAttack) + with pytest.raises(AttributeError): + spec.name = "modified" + + def test_adversarial_auto_detected_from_signature(self, mock_adversarial_target): + """Adversarial config is injected 
based on attack class signature, not a manual flag.""" + registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) + + # RolePlayAttack accepts attack_adversarial_config → should be injected + rp_spec = TechniqueSpec(name="rp", attack_class=RolePlayAttack, tags=[]) + rp_factory = registrar.build_factory(rp_spec) + assert "attack_adversarial_config" in rp_factory._attack_kwargs + + # PromptSendingAttack does NOT accept it → should not be injected + ps_spec = TechniqueSpec(name="ps", attack_class=PromptSendingAttack, tags=[]) + ps_factory = registrar.build_factory(ps_spec) + assert "attack_adversarial_config" not in (ps_factory._attack_kwargs or {}) From 7ffd7d9f96a6165ced367af26bd40be6bf90b810 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 16 Apr 2026 13:28:06 -0700 Subject: [PATCH 3/3] refactoring more --- .../attack_technique_registry.py | 70 ++++++++++- pyrit/scenario/core/__init__.py | 10 +- pyrit/scenario/core/core_techniques.py | 25 ---- pyrit/scenario/core/scenario.py | 4 +- pyrit/scenario/core/scenario_techniques.py | 111 ++++-------------- .../scenario/scenarios/airt/rapid_response.py | 12 +- .../setup/initializers/components/targets.py | 2 +- tests/unit/scenario/test_rapid_response.py | 105 +++++++++-------- 8 files changed, 160 insertions(+), 179 deletions(-) delete mode 100644 pyrit/scenario/core/core_techniques.py diff --git a/pyrit/registry/object_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py index 110b6fb209..0e3d338b0d 100644 --- a/pyrit/registry/object_registries/attack_technique_registry.py +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -11,6 +11,7 @@ from __future__ import annotations +import inspect import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable @@ -37,11 +38,11 @@ class TechniqueSpec: """ Declarative definition of an attack technique. 
- Each spec describes one registrable technique. The registrar converts + Each spec describes one registrable technique. The registry converts specs into ``AttackTechniqueFactory`` instances and registers them. Whether a technique receives an ``AttackAdversarialConfig`` is determined - automatically: the registrar inspects the attack class constructor and + automatically: the registry inspects the attack class constructor and injects one when ``attack_adversarial_config`` is an accepted parameter. Args: @@ -132,3 +133,68 @@ def create_technique( attack_adversarial_config_override=attack_adversarial_config_override, attack_converter_config_override=attack_converter_config_override, ) + + @staticmethod + def build_factory_from_spec( + spec: TechniqueSpec, + *, + adversarial_chat: "PromptChatTarget | None" = None, + ) -> "AttackTechniqueFactory": + """ + Build an ``AttackTechniqueFactory`` from a ``TechniqueSpec``. + + Automatically injects ``AttackAdversarialConfig`` when the attack + class accepts ``attack_adversarial_config`` as a constructor parameter. + + Args: + spec: The technique specification. + adversarial_chat: Shared adversarial chat target for techniques + that require one. If None, no adversarial config is injected. + + Returns: + AttackTechniqueFactory: A factory ready for registration. 
+ """ + from pyrit.executor.attack import AttackAdversarialConfig + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + + kwargs: dict[str, Any] = {} + + if adversarial_chat is not None and AttackTechniqueRegistry._accepts_adversarial(spec.attack_class): + kwargs["attack_adversarial_config"] = AttackAdversarialConfig(target=adversarial_chat) + + if spec.extra_kwargs_builder: + kwargs.update(spec.extra_kwargs_builder(adversarial_chat)) + + return AttackTechniqueFactory( + attack_class=spec.attack_class, + attack_kwargs=kwargs or None, + ) + + @staticmethod + def _accepts_adversarial(attack_class: type) -> bool: + """Check if an attack class accepts ``attack_adversarial_config``.""" + sig = inspect.signature(attack_class.__init__) + return "attack_adversarial_config" in sig.parameters + + def register_from_specs( + self, + specs: list[TechniqueSpec], + *, + adversarial_chat: "PromptChatTarget | None" = None, + ) -> None: + """ + Build factories from specs and register them. + + Per-name idempotent: existing entries are not overwritten. + + Args: + specs: Technique specifications to register. + adversarial_chat: Shared adversarial chat target for techniques + that require one. 
+ """ + for spec in specs: + if spec.name not in self: + factory = self.build_factory_from_spec(spec, adversarial_chat=adversarial_chat) + self.register_technique(name=spec.name, factory=factory, tags=spec.tags) + + logger.debug("Technique registration complete (%d total in registry)", len(self)) diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index f464f32ddf..a9c3341cad 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -8,8 +8,8 @@ from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.scenario_techniques import ( SCENARIO_TECHNIQUES, - ScenarioTechniqueRegistrar, get_default_adversarial_target, + register_scenario_techniques, ) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import Scenario @@ -18,23 +18,17 @@ # TechniqueSpec lives in the registry module but is re-exported here for convenience from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec -# Backward-compatible aliases (old names) -CORE_TECHNIQUES = SCENARIO_TECHNIQUES -CoreTechniqueRegistrar = ScenarioTechniqueRegistrar - __all__ = [ "AtomicAttack", "AttackTechnique", "AttackTechniqueFactory", - "CORE_TECHNIQUES", - "CoreTechniqueRegistrar", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", "SCENARIO_TECHNIQUES", "Scenario", "ScenarioCompositeStrategy", "ScenarioStrategy", - "ScenarioTechniqueRegistrar", "TechniqueSpec", "get_default_adversarial_target", + "register_scenario_techniques", ] diff --git a/pyrit/scenario/core/core_techniques.py b/pyrit/scenario/core/core_techniques.py deleted file mode 100644 index dd4b5ecf9e..0000000000 --- a/pyrit/scenario/core/core_techniques.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Deprecated — use ``scenario_techniques`` instead. 
- -This module re-exports everything from ``scenario_techniques`` for backward -compatibility. It will be removed in a future release. -""" - -from pyrit.scenario.core.scenario_techniques import ( - SCENARIO_TECHNIQUES as CORE_TECHNIQUES, - ScenarioTechniqueRegistrar as CoreTechniqueRegistrar, - get_default_adversarial_target, -) - -# Re-export TechniqueSpec from its canonical location -from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec - -__all__ = [ - "CORE_TECHNIQUES", - "CoreTechniqueRegistrar", - "TechniqueSpec", - "get_default_adversarial_target", -] diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 23add6c3fa..977f5d5059 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -190,9 +190,9 @@ def get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"]: Returns: dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. """ - from pyrit.scenario.core.scenario_techniques import ScenarioTechniqueRegistrar + from pyrit.scenario.core.scenario_techniques import register_scenario_techniques - ScenarioTechniqueRegistrar().register() + register_scenario_techniques() from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry diff --git a/pyrit/scenario/core/scenario_techniques.py b/pyrit/scenario/core/scenario_techniques.py index cfdb8b7d70..58d466d74e 100644 --- a/pyrit/scenario/core/scenario_techniques.py +++ b/pyrit/scenario/core/scenario_techniques.py @@ -5,7 +5,7 @@ Scenario attack technique definitions and registration. Provides ``SCENARIO_TECHNIQUES`` (the standard catalog) and -``ScenarioTechniqueRegistrar`` (registers specs into the +``register_scenario_techniques`` (registers specs into the ``AttackTechniqueRegistry`` singleton). To add a new technique, append a ``TechniqueSpec`` to ``SCENARIO_TECHNIQUES``. 
@@ -13,12 +13,9 @@ from __future__ import annotations -import inspect import logging -from typing import Any from pyrit.executor.attack import ( - AttackAdversarialConfig, ManyShotJailbreakAttack, PromptSendingAttack, RolePlayAttack, @@ -26,8 +23,8 @@ TreeOfAttacksWithPruningAttack, ) from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName from pyrit.registry.object_registries.attack_technique_registry import TechniqueSpec -from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory logger = logging.getLogger(__name__) @@ -81,101 +78,37 @@ def get_default_adversarial_target() -> PromptChatTarget: registry = TargetRegistry.get_registry_singleton() if "adversarial_chat" in registry: - return registry.get("adversarial_chat") + target = registry.get("adversarial_chat") + if not target.capabilities.includes(capability=CapabilityName.MULTI_TURN): + raise ValueError( + f"Registry entry 'adversarial_chat' must support multi-turn conversations, " + f"but {type(target).__name__} does not." + ) + return target # type: ignore[return-value] return OpenAIChatTarget(temperature=1.2) # --------------------------------------------------------------------------- -# Registrar +# Registration helper # --------------------------------------------------------------------------- -class ScenarioTechniqueRegistrar: +def register_scenario_techniques(*, adversarial_chat: PromptChatTarget | None = None) -> None: """ - Registers ``TechniqueSpec`` entries into the ``AttackTechniqueRegistry``. + Register all ``SCENARIO_TECHNIQUES`` into the ``AttackTechniqueRegistry`` singleton. - Holds shared defaults (e.g. ``adversarial_chat``) so they're set once - and applied to every technique that needs them. + Per-name idempotent: existing entries are not overwritten. 
- Typical usage from a scenario:: - - ScenarioTechniqueRegistrar(adversarial_chat=self._adversarial_chat).register() + Args: + adversarial_chat: Shared adversarial chat target for techniques + that require one. If None, resolved via ``get_default_adversarial_target()``. """ + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry + + if adversarial_chat is None: + adversarial_chat = get_default_adversarial_target() - def __init__(self, *, adversarial_chat: PromptChatTarget | None = None) -> None: - """ - Args: - adversarial_chat: Shared adversarial chat target for techniques - that require one. Defaults to ``get_default_adversarial_target()``. - """ - self._adversarial_chat = adversarial_chat - - @property - def adversarial_chat(self) -> PromptChatTarget: - """Resolve the adversarial chat target (custom or default).""" - if self._adversarial_chat is None: - self._adversarial_chat = get_default_adversarial_target() - return self._adversarial_chat - - def build_factory(self, spec: TechniqueSpec) -> AttackTechniqueFactory: - """ - Build an ``AttackTechniqueFactory`` from a ``TechniqueSpec``. - - Automatically injects ``AttackAdversarialConfig`` when the attack - class accepts ``attack_adversarial_config`` as a constructor parameter. - - Args: - spec: The technique specification. - - Returns: - AttackTechniqueFactory: A factory ready for registration. 
- """ - kwargs: dict[str, Any] = {} - - if self._accepts_adversarial(spec.attack_class): - kwargs["attack_adversarial_config"] = AttackAdversarialConfig(target=self.adversarial_chat) - - if spec.extra_kwargs_builder: - kwargs.update(spec.extra_kwargs_builder(self.adversarial_chat)) - - return AttackTechniqueFactory( - attack_class=spec.attack_class, - attack_kwargs=kwargs or None, - ) - - @staticmethod - def _accepts_adversarial(attack_class: type) -> bool: - """Check if an attack class accepts ``attack_adversarial_config``.""" - sig = inspect.signature(attack_class.__init__) - return "attack_adversarial_config" in sig.parameters - - def register( - self, - *, - techniques: list[TechniqueSpec] | None = None, - registry: "AttackTechniqueRegistry | None" = None, - ) -> None: - """ - Register technique specs into the registry. - - Per-name idempotent: existing entries are not overwritten. - - Args: - techniques: Specs to register. Defaults to ``SCENARIO_TECHNIQUES``. - registry: Registry instance. Defaults to the singleton. 
- """ - from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry - - if registry is None: - registry = AttackTechniqueRegistry.get_registry_singleton() - if techniques is None: - techniques = SCENARIO_TECHNIQUES - - for spec in techniques: - if spec.name not in registry: - factory = self.build_factory(spec) - registry.register_technique(name=spec.name, factory=factory, tags=spec.tags) - - logger.debug("Technique registration complete (%d total in registry)", len(registry)) + registry = AttackTechniqueRegistry.get_registry_singleton() + registry.register_from_specs(SCENARIO_TECHNIQUES, adversarial_chat=adversarial_chat) diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py index 445e1d532b..463d6d9d07 100644 --- a/pyrit/scenario/scenarios/airt/rapid_response.py +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING from pyrit.common import apply_defaults -from pyrit.executor.attack import AttackScoringConfig +from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig from pyrit.prompt_target import PromptChatTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -140,9 +140,9 @@ def get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"]: Register core techniques with this scenario's adversarial chat target. 
""" from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry - from pyrit.scenario.core.scenario_techniques import ScenarioTechniqueRegistrar + from pyrit.scenario.core.scenario_techniques import register_scenario_techniques - ScenarioTechniqueRegistrar(adversarial_chat=self._adversarial_chat).register() + register_scenario_techniques(adversarial_chat=self._adversarial_chat) return AttackTechniqueRegistry.get_registry_singleton().get_factories() async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: @@ -184,9 +184,15 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: # would fail TAP's type validation. scoring_for_technique = None if technique_name == "tap" else scoring_config + # Build adversarial config override if scenario has a custom adversarial target + adversarial_override = None + if self._adversarial_chat is not None: + adversarial_override = AttackAdversarialConfig(target=self._adversarial_chat) + attack_technique = factory.create( objective_target=self._objective_target, attack_scoring_config_override=scoring_for_technique, + attack_adversarial_config_override=adversarial_override, ) for dataset_name, seed_groups in seed_groups_by_dataset.items(): diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index 94f30c4f05..e97209ac70 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -176,7 +176,7 @@ class TargetConfig: key_var="ADVERSARIAL_CHAT_KEY", model_var="ADVERSARIAL_CHAT_MODEL", temperature=1.2, - tags=[TargetInitializerTags.ALL, TargetInitializerTags.ADVERSARIAL], + tags=[TargetInitializerTags.DEFAULT, TargetInitializerTags.ADVERSARIAL], ), TargetConfig( registry_name="azure_foundry_deepseek", diff --git a/tests/unit/scenario/test_rapid_response.py b/tests/unit/scenario/test_rapid_response.py index e705208387..c993e36208 100644 --- 
a/tests/unit/scenario/test_rapid_response.py +++ b/tests/unit/scenario/test_rapid_response.py @@ -24,8 +24,8 @@ from pyrit.scenario import ScenarioCompositeStrategy from pyrit.scenario.core.scenario_techniques import ( SCENARIO_TECHNIQUES, - ScenarioTechniqueRegistrar, get_default_adversarial_target, + register_scenario_techniques, ) from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -470,7 +470,9 @@ async def test_unknown_technique_skipped_with_warning( with ( patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), - patch.object(ScenarioTechniqueRegistrar, "register"), + patch( + "pyrit.scenario.core.scenario_techniques.register_scenario_techniques", + ), ): scenario = RapidResponse( adversarial_chat=mock_adversarial_target, @@ -603,29 +605,29 @@ def test_content_harms_instance_name_is_rapid_response(self, mock_adversarial_ta @pytest.mark.usefixtures(*FIXTURES) class TestRegistryIntegration: - """Tests for AttackTechniqueRegistry wiring via ScenarioTechniqueRegistrar.""" + """Tests for AttackTechniqueRegistry wiring via register_scenario_techniques.""" - def test_registrar_populates_registry(self, mock_adversarial_target): - """After calling register(), all 4 techniques are in registry.""" - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + def test_register_populates_registry(self, mock_adversarial_target): + """After calling register_scenario_techniques(), all 4 techniques are in registry.""" + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() names = set(registry.get_names()) assert names == {"prompt_sending", "role_play", "many_shot", "tap"} - def test_registrar_idempotent(self, mock_adversarial_target): - """Calling register() twice doesn't duplicate entries.""" - 
ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + def test_register_idempotent(self, mock_adversarial_target): + """Calling register_scenario_techniques() twice doesn't duplicate entries.""" + register_scenario_techniques(adversarial_chat=mock_adversarial_target) + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() assert len(registry) == 4 - def test_registrar_preserves_custom(self, mock_adversarial_target): + def test_register_preserves_custom(self, mock_adversarial_target): """Pre-registered custom techniques aren't overwritten.""" registry = AttackTechniqueRegistry.get_registry_singleton() custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) # role_play should still be the custom factory factories = registry.get_factories() @@ -635,7 +637,7 @@ def test_registrar_preserves_custom(self, mock_adversarial_target): def test_get_factories_returns_dict(self, mock_adversarial_target): """get_factories() returns a dict of name → factory.""" - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() factories = registry.get_factories() assert isinstance(factories, dict) @@ -656,7 +658,7 @@ def test_scenario_base_class_reads_from_registry(self, mock_adversarial_target, def test_tags_assigned_correctly(self, mock_adversarial_target): """Core techniques have correct tags (single_turn / multi_turn).""" - 
ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} @@ -667,23 +669,23 @@ def test_tags_assigned_correctly(self, mock_adversarial_target): # =========================================================================== -# ScenarioTechniqueRegistrar tests +# Registration and factory-from-spec tests # =========================================================================== @pytest.mark.usefixtures(*FIXTURES) -class TestScenarioTechniqueRegistrar: - """Tests for the declarative ScenarioTechniqueRegistrar class.""" +class TestRegistrationAndFactoryFromSpec: + """Tests for register_scenario_techniques and AttackTechniqueRegistry.build_factory_from_spec.""" - def test_registrar_populates_all_four_techniques(self): - """Registrar with default adversarial registers all 4 techniques.""" - ScenarioTechniqueRegistrar().register() + def test_register_populates_all_four_techniques(self): + """register_scenario_techniques with default adversarial registers all 4 techniques.""" + register_scenario_techniques() registry = AttackTechniqueRegistry.get_registry_singleton() assert set(registry.get_names()) == {"prompt_sending", "role_play", "many_shot", "tap"} - def test_registrar_with_custom_adversarial(self, mock_adversarial_target): + def test_register_with_custom_adversarial(self, mock_adversarial_target): """Custom adversarial_chat is baked into adversarial-needing factories.""" - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() factories = registry.get_factories() @@ -694,28 +696,27 @@ def test_registrar_with_custom_adversarial(self, mock_adversarial_target): tap_kwargs = 
factories["tap"]._attack_kwargs assert tap_kwargs["attack_adversarial_config"].target is mock_adversarial_target - def test_registrar_idempotent(self, mock_adversarial_target): - """Calling register() twice does not duplicate or overwrite entries.""" - registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) - registrar.register() - registrar.register() + def test_register_idempotent(self, mock_adversarial_target): + """Calling register_scenario_techniques() twice does not duplicate or overwrite entries.""" + register_scenario_techniques(adversarial_chat=mock_adversarial_target) + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() assert len(registry) == 4 - def test_registrar_preserves_custom_preregistered(self, mock_adversarial_target): - """Pre-registered custom techniques are not overwritten by registrar.""" + def test_register_preserves_custom_preregistered(self, mock_adversarial_target): + """Pre-registered custom techniques are not overwritten.""" registry = AttackTechniqueRegistry.get_registry_singleton() custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) # role_play should still be the custom factory assert registry.get_factories()["role_play"] is custom_factory assert len(registry) == 4 - def test_registrar_assigns_correct_tags(self, mock_adversarial_target): + def test_register_assigns_correct_tags(self, mock_adversarial_target): """Tags from TechniqueSpec are applied correctly.""" - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register() + register_scenario_techniques(adversarial_chat=mock_adversarial_target) registry = AttackTechniqueRegistry.get_registry_singleton() 
single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} @@ -723,21 +724,15 @@ def test_registrar_assigns_correct_tags(self, mock_adversarial_target): assert single_turn == {"prompt_sending", "role_play"} assert multi_turn == {"many_shot", "tap"} - def test_registrar_custom_techniques_list(self, mock_adversarial_target): - """Registrar accepts a custom list of TechniqueSpecs.""" + def test_register_from_specs_custom_list(self, mock_adversarial_target): + """register_from_specs accepts a custom list of TechniqueSpecs.""" custom_specs = [ TechniqueSpec(name="custom_attack", attack_class=PromptSendingAttack, tags=["custom"]), ] - ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target).register(techniques=custom_specs) registry = AttackTechniqueRegistry.get_registry_singleton() + registry.register_from_specs(custom_specs, adversarial_chat=mock_adversarial_target) assert set(registry.get_names()) == {"custom_attack"} - def test_registrar_adversarial_lazy_resolution(self): - """Adversarial target is not resolved until register() accesses it.""" - registrar = ScenarioTechniqueRegistrar() - # No env var resolution yet — just creating the registrar - assert registrar._adversarial_chat is None - def test_get_default_adversarial_target_from_registry(self, mock_adversarial_target): """get_default_adversarial_target returns registry entry when available.""" from pyrit.registry import TargetRegistry @@ -753,6 +748,18 @@ def test_get_default_adversarial_target_fallback(self): assert isinstance(result, OpenAIChatTarget) assert result._temperature == 1.2 + def test_get_default_adversarial_target_capability_check(self): + """get_default_adversarial_target rejects targets without multi-turn support.""" + from pyrit.registry import TargetRegistry + + target_registry = TargetRegistry.get_registry_singleton() + # Register a plain PromptTarget (lacks multi-turn capability) + mock_target = MagicMock(spec=PromptTarget) + 
mock_target.capabilities.includes.return_value = False + target_registry.register(name="adversarial_chat", instance=mock_target) + with pytest.raises(ValueError, match="must support multi-turn"): + get_default_adversarial_target() + # =========================================================================== # TechniqueSpec tests @@ -778,16 +785,14 @@ def test_extra_kwargs_builder(self, mock_adversarial_target): tags=["single_turn"], extra_kwargs_builder=builder, ) - registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) - factory = registrar.build_factory(spec) + factory = AttackTechniqueRegistry.build_factory_from_spec(spec, adversarial_chat=mock_adversarial_target) assert factory._attack_kwargs["role_play_definition_path"] == "/custom/path.yaml" assert "attack_adversarial_config" in factory._attack_kwargs def test_build_factory_no_adversarial(self, mock_adversarial_target): """Non-adversarial spec should not have attack_adversarial_config.""" spec = TechniqueSpec(name="simple", attack_class=PromptSendingAttack, tags=[]) - registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) - factory = registrar.build_factory(spec) + factory = AttackTechniqueRegistry.build_factory_from_spec(spec, adversarial_chat=mock_adversarial_target) assert "attack_adversarial_config" not in (factory._attack_kwargs or {}) def test_SCENARIO_TECHNIQUES_list_has_four_entries(self): @@ -803,14 +808,16 @@ def test_frozen_spec(self): def test_adversarial_auto_detected_from_signature(self, mock_adversarial_target): """Adversarial config is injected based on attack class signature, not a manual flag.""" - registrar = ScenarioTechniqueRegistrar(adversarial_chat=mock_adversarial_target) - # RolePlayAttack accepts attack_adversarial_config → should be injected rp_spec = TechniqueSpec(name="rp", attack_class=RolePlayAttack, tags=[]) - rp_factory = registrar.build_factory(rp_spec) + rp_factory = AttackTechniqueRegistry.build_factory_from_spec( + 
rp_spec, adversarial_chat=mock_adversarial_target + ) assert "attack_adversarial_config" in rp_factory._attack_kwargs # PromptSendingAttack does NOT accept it → should not be injected ps_spec = TechniqueSpec(name="ps", attack_class=PromptSendingAttack, tags=[]) - ps_factory = registrar.build_factory(ps_spec) + ps_factory = AttackTechniqueRegistry.build_factory_from_spec( + ps_spec, adversarial_chat=mock_adversarial_target + ) assert "attack_adversarial_config" not in (ps_factory._attack_kwargs or {})